orbiter.train.trainer
Base Trainer ABC for agent fine-tuning with multi-phase lifecycle.
Base Trainer ABC for agent fine-tuning with multi-phase lifecycle.
from orbiter.train.trainer import (
TrainConfig,
Trainer,
TrainerError,
TrainerState,
TrainMetrics,
)TrainerError
class TrainerError(Exception)Error during training operations.
TrainerState
class TrainerState(StrEnum)Trainer lifecycle state.
| Value | Description |
|---|---|
CREATED = "created" | Trainer instantiated, not yet validated |
VALIDATED = "validated" | All validation checks passed |
TRAINING = "training" | Training in progress |
COMPLETED = "completed" | Training completed successfully |
FAILED = "failed" | Training failed |
TrainMetrics
@dataclass(slots=True)
class TrainMetricsMetrics produced by training or evaluation.
| Field | Type | Default | Description |
|---|---|---|---|
loss | float | 0.0 | Training loss |
accuracy | float | 0.0 | Accuracy metric |
steps | int | 0 | Number of training steps |
extra | dict[str, Any] | {} | Backend-specific metrics |
TrainConfig
@dataclass(slots=True)
class TrainConfigMinimal training configuration. Subclass or extend via extra for backend-specific settings.
| Field | Type | Default | Description |
|---|---|---|---|
epochs | int | 1 | Number of training epochs |
batch_size | int | 8 | Batch size |
learning_rate | float | 1e-5 | Learning rate |
output_dir | str | "" | Output directory for checkpoints |
extra | dict[str, Any] | {} | Backend-specific settings |
Trainer
class Trainer(ABC)(config: TrainConfig | None = None)Abstract base class for agent training frameworks.
The lifecycle follows a strict phase ordering:
- Call
check_agent,check_dataset,check_reward,check_config(in any order) to validate all inputs. - Call
mark_validated()when all checks pass. - Call
train()to execute training. - Call
evaluate()to run evaluation on a test set.
Constructor parameters
| Name | Type | Default | Description |
|---|---|---|---|
config | TrainConfig | None | None | Training configuration. Defaults to TrainConfig() |
Properties
| Property | Type | Description |
|---|---|---|
state | TrainerState | Current lifecycle state |
config | TrainConfig | Training configuration |
Abstract methods
check_agent
def check_agent(self, agent: Any) -> NoneValidate that agent meets training requirements.
Raises: TrainerError — If validation fails.
check_dataset
def check_dataset(
self,
train_data: Any,
test_data: Any | None = None,
) -> NoneValidate training (and optional test) data.
Raises: TrainerError — If validation fails.
check_reward
def check_reward(self, reward_fn: Any | None = None) -> NoneValidate reward function or scoring mechanism.
Raises: TrainerError — If validation fails.
check_config
def check_config(
self,
config: TrainConfig | dict[str, Any] | None = None,
) -> NoneValidate and optionally update training configuration.
Raises: TrainerError — If configuration is invalid.
train
async def train(self) -> TrainMetricsExecute the training loop. Must be called after mark_validated().
Returns: Training metrics.
Raises: TrainerError — If training fails or trainer not validated.
evaluate
async def evaluate(self, test_data: Any | None = None) -> TrainMetricsRun evaluation on test data.
| Name | Type | Default | Description |
|---|---|---|---|
test_data | Any | None | None | Optional test dataset. If None, use data from check_dataset |
Returns: Evaluation metrics.
Raises: TrainerError — If evaluation fails.
Concrete methods
mark_validated
def mark_validated(self) -> NoneTransition to VALIDATED state after all checks pass.
Raises: TrainerError — If not in CREATED state.
Example
from orbiter.train import TrainConfig
# See VeRLTrainer for a concrete implementation
config = TrainConfig(epochs=3, batch_size=16, learning_rate=2e-5)