dymad.training.phase_runtime

Functions

build_initial_trainer_state(config, *, ...)

Classes

ArtifactRegistry([_artifacts])

Typed intermediate artifacts shared across phases.

EvaluationArtifact([metrics, split, ...])

ExportArtifact([outputs])

LinearSolveRecord(phase_name, method, loss[, ...])

LinearSolveReportArtifact([records])

ModelArtifact(model, config, train_md, ...)

OptimizerStateArtifact(optimizer[, ...])

PhaseContext([train_set, valid_set, ...])

Live phase context for one run.

PhaseRecord(name, kind, started_epoch, ...)

PhaseResult(name, kind, trainer_state, ...)

Typed phase outcome.

TrainerState(config[, execution_services, ...])

Checkpointable training state.

TrainingHistoryArtifact([hist, crit, ...])

Exceptions

TrainingCheckpointError

Raised when a typed training checkpoint cannot be loaded.

class dymad.training.phase_runtime.ArtifactRegistry(_artifacts=<factory>)

Bases: object

Typed intermediate artifacts shared across phases.

checkpoint_payload()
Return type:

dict[str, Any]

classmethod from_checkpoint_payload(payload)
Return type:

ArtifactRegistry

get(key, default=None)
Return type:

Any

keys()
Return type:

Iterable[str]

put(key, artifact)
Return type:

Any

require(key, expected_type=None)
Return type:

Any

class dymad.training.phase_runtime.EvaluationArtifact(metrics=<factory>, split='valid', criterion_name='total')

Bases: object

criterion_name: str = 'total'
metrics: dict[str, float]
split: str = 'valid'
class dymad.training.phase_runtime.ExportArtifact(outputs=<factory>)

Bases: object

outputs: dict[str, str]
class dymad.training.phase_runtime.LinearSolveRecord(phase_name, method, loss, updated_parameters=<factory>)

Bases: object

loss: float
method: str
phase_name: str
updated_parameters: list[str]
class dymad.training.phase_runtime.LinearSolveReportArtifact(records=<factory>)

Bases: object

records: list[LinearSolveRecord]
class dymad.training.phase_runtime.ModelArtifact(model, config, train_md, valid_md, dtype)

Bases: object

config: dict[str, Any]
dtype: dtype
model: Module
train_md: dict[str, Any]
valid_md: dict[str, Any]
class dymad.training.phase_runtime.OptimizerStateArtifact(optimizer, schedulers=<factory>, criteria=<factory>, criteria_weights=<factory>, criteria_names=<factory>, owner_phase='', _weak_C=None, _weak_D=None, _weak_N=None, _weak_dN=None, _linear_updater=None, _one_step_dt=None, _one_step_kwargs=<factory>)

Bases: object

criteria: list[Module]
criteria_names: list[str]
criteria_weights: list[float]
optimizer: Optimizer
owner_phase: str = ''
schedulers: list[Any]
class dymad.training.phase_runtime.PhaseContext(train_set=None, valid_set=None, train_loader=None, valid_loader=None, train_md=None, valid_md=None)

Bases: object

Live phase context for one run.

train_loader: Optional[DataLoader[TypeAliasType]] = None
train_md: dict[str, Any] | None = None
train_set: list[TypeAliasType] | None = None
valid_loader: Optional[DataLoader[TypeAliasType]] = None
valid_md: dict[str, Any] | None = None
valid_set: list[TypeAliasType] | None = None
class dymad.training.phase_runtime.PhaseRecord(name, kind, started_epoch, completed_epoch, metrics=<factory>, artifact_keys=<factory>)

Bases: object

artifact_keys: list[str]
completed_epoch: int
kind: str
metrics: dict[str, float]
name: str
started_epoch: int
class dymad.training.phase_runtime.PhaseResult(name, kind, trainer_state, phase_context, artifacts, metrics=<factory>, record=None)

Bases: object

Typed phase outcome.

artifacts: ArtifactRegistry
get_metric(metric_name)
Return type:

float

kind: str
metrics: dict[str, float]
name: str
phase_context: PhaseContext
record: PhaseRecord | None = None
trainer_state: TrainerState
class dymad.training.phase_runtime.TrainerState(config, execution_services=None, device=None, epoch=0, best_loss=<factory>, converged=False, convergence_epoch=None, phase_cursor=0, phase_records=<factory>)

Bases: object

Checkpointable training state.

best_loss: dict[str, float]
checkpoint_payload()
Return type:

dict[str, Any]

config: dict[str, Any] | None
converged: bool = False
convergence_epoch: int | None = None
device: device | None = None
epoch: int = 0
execution_services: ExecutionServices | None = None
classmethod from_checkpoint_payload(payload, *, execution_services=None)
Return type:

TrainerState

phase_cursor: int = 0
phase_records: list[PhaseRecord]
exception dymad.training.phase_runtime.TrainingCheckpointError

Bases: ValueError

Raised when a typed training checkpoint cannot be loaded.

class dymad.training.phase_runtime.TrainingHistoryArtifact(hist=<factory>, crit=<factory>, epoch_times=<factory>, best_loss=<factory>, best_model_state_dict=None, convergence_epoch=None)

Bases: object

best_loss: dict[str, float]
best_model_state_dict: dict[str, Any] | None = None
convergence_epoch: int | None = None
crit: list[Any]
epoch_times: list[float]
hist: list[Any]
dymad.training.phase_runtime.build_initial_trainer_state(config, *, execution_services)
Return type:

TrainerState