dymad.training.ls_update

Module Attributes

SOL_MAP

Mapping of linear solver methods

Functions

check_linear_impl(model)

Check if the model implements linear features and eval methods.

check_linear_solve(model)

Check if the model implements the linear_solve method.

get_batch_ct(dataloader, model, dt, **kwargs)

get_batch_dt(dataloader, model, dt, **kwargs)

Classes

LSUpdater(method, model[, dt, params])

Update linear weights by least squares.

class dymad.training.ls_update.LSUpdater(method, model, dt=None, params=None, **kwargs)

Bases: object

Update linear weights by least squares.
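As a rough illustration of what updating linear weights by least squares involves, the sketch below solves for a weight matrix `W` that minimizes the Frobenius norm of `X @ W - Y` using `numpy.linalg.lstsq`. The names `X` (features), `Y` (targets) and the helper `ls_weights` are hypothetical. This is not the LSUpdater implementation, and it ignores the `method`, `dt` and `params` arguments.

```python
import numpy as np


def ls_weights(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return W minimizing ||X @ W - Y||_F.

    X has shape (n_samples, n_features); Y has shape (n_samples, n_targets).
    """
    # lstsq returns (solution, residuals, rank, singular_values); keep the solution.
    W, *_ = np.linalg.lstsq(X, Y, rcond=None)
    return W


# Usage: recover a known linear map from noise-free data.
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
W_true = np.array([[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]])
Y = X @ W_true
W_hat = ls_weights(X, Y)
```

With noise-free, full-rank data the solve recovers `W_true` up to floating-point error.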

eval_batch(model, batch, criterion)

Process a batch and return predictions and ground truth states.

Only used during evaluation.

Return type:

Tensor

update(model, dataloader)

Train the model for one epoch.

Return type:

tuple[float, Any]
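One way a single least-squares "epoch" can consume a data loader is to accumulate the normal equations batch by batch and solve once at the end. The sketch below does that and returns a `(loss, weights)` pair. The helper name `ls_epoch_update`, the `(X, Y)` batch layout, and the reading of the return value as mean squared residual plus weights are all assumptions for illustration, not the documented behavior of `update`.

```python
from typing import Any, Iterable, Tuple

import numpy as np


def ls_epoch_update(batches: Iterable[Tuple[np.ndarray, np.ndarray]]) -> Tuple[float, Any]:
    """Hypothetical one-epoch LS update over (X, Y) batches.

    Accumulates X^T X and X^T Y across batches, solves for W once,
    then reports the mean squared residual over all target entries.
    """
    batches = list(batches)  # materialize so we can make a second pass for the loss
    XtX = sum(X.T @ X for X, _ in batches)
    XtY = sum(X.T @ Y for X, Y in batches)
    # Solve (X^T X) W = X^T Y; lstsq also tolerates a singular Gram matrix.
    W, *_ = np.linalg.lstsq(XtX, XtY, rcond=None)
    sq_err = sum(float(np.sum((X @ W - Y) ** 2)) for X, Y in batches)
    n = sum(Y.size for _, Y in batches)
    return sq_err / n, W


# Usage: four batches generated by the same linear map.
rng = np.random.default_rng(1)
W_true = np.array([[2.0], [-1.0]])
batches = []
for _ in range(4):
    Xb = rng.standard_normal((10, 2))
    batches.append((Xb, Xb @ W_true))
loss, W_est = ls_epoch_update(batches)
```

Accumulating normal equations keeps memory independent of the number of batches, at the cost of squaring the condition number compared with solving on the stacked data directly.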

dymad.training.ls_update.SOL_MAP = {
    'ct_full': <function _ct_full_der>,
    'ct_full_log': <function _ct_full_log>,
    'ct_raw': <function _ct_raw>,
    'ct_sako_log': <function _ct_sako_log>,
    'ct_truncated': <function _ct_truncated_der>,
    'ct_truncated_log': <function _ct_truncated_log>,
    'dt_full': <function _dt_full>,
    'dt_raw': <function _dt_raw>,
    'dt_sako': <function _dt_sako>,
    'dt_truncated': <function _dt_truncated>
}

Mapping from linear solver method names to the corresponding solver functions.
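A name-to-function dictionary like SOL_MAP is typically used as a dispatch table: the `method` string selects a solver, and unknown names are rejected. The sketch below shows that pattern with two stand-in solvers. `DEMO_SOL_MAP`, `resolve_method` and the demo functions are hypothetical; the real entries are private dymad functions. Reading the `ct_`/`dt_` prefixes as continuous-time versus discrete-time (matching `get_batch_ct` and `get_batch_dt`) is an assumption based on the naming.

```python
from typing import Callable, Dict, Tuple


# Hypothetical stand-in solvers; they only report which branch was selected.
def _dt_demo(*args) -> str:
    return "discrete-time solve"


def _ct_demo(*args) -> str:
    return "continuous-time solve"


DEMO_SOL_MAP: Dict[str, Callable[..., str]] = {
    "dt_demo": _dt_demo,
    "ct_demo": _ct_demo,
}


def resolve_method(method: str) -> Tuple[str, Callable[..., str]]:
    """Look up a solver by name and return (time_mode_prefix, solver)."""
    if method not in DEMO_SOL_MAP:
        raise ValueError(f"Unknown method {method!r}; choose from {sorted(DEMO_SOL_MAP)}")
    mode = method.split("_", 1)[0]  # "ct" or "dt", following the SOL_MAP key pattern
    return mode, DEMO_SOL_MAP[method]


# Usage
mode, solver = resolve_method("ct_demo")
result = solver()
```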

dymad.training.ls_update.check_linear_impl(model)

Check if the model implements linear features and eval methods.

Note: a complete check would also verify linear_eval and set_linear_weights; this function does not currently check them.

Return type:

bool

dymad.training.ls_update.check_linear_solve(model)

Check if the model implements the linear_solve method.

Return type:

bool
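Capability checks such as check_linear_impl and check_linear_solve can be expressed as "does the model expose these callables". The sketch below uses `getattr` and `callable` for that. The helper `implements` and `DemoModel` are hypothetical, and the method names are taken from the docstrings above (`linear_eval`, `set_linear_weights`, `linear_solve`); the actual functions may inspect different attributes, for example whatever method provides the linear features.

```python
class DemoModel:
    """Hypothetical model exposing two of the methods named in the docstrings."""

    def linear_eval(self, x):
        return x

    def set_linear_weights(self, w):
        self.w = w


def implements(model: object, *method_names: str) -> bool:
    """Return True if every named attribute exists on `model` and is callable."""
    return all(callable(getattr(model, name, None)) for name in method_names)


# Usage
m = DemoModel()
has_linear = implements(m, "linear_eval", "set_linear_weights")
has_solve = implements(m, "linear_solve")
```

Checking `callable` rather than plain `hasattr` avoids accepting a model that merely has a data attribute with the expected name.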

dymad.training.ls_update.get_batch_ct(dataloader, model, dt, **kwargs)

Return type:

tuple[ndarray, ndarray]
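For a continuous-time fit, the least-squares targets are time derivatives, which must be estimated from sampled states using the step `dt`. The sketch below builds such `(states, derivatives)` pairs with a forward difference. The helper `ct_pairs`, the `(T, n)` trajectory layout, and the choice of forward differencing are assumptions; get_batch_ct may lift the states through the model and use a different derivative estimate.

```python
from typing import Tuple

import numpy as np


def ct_pairs(Z: np.ndarray, dt: float) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical: pair states with forward-difference time derivatives.

    Z has shape (T, n): T snapshots sampled every dt.
    Returns (X, dXdt), each of shape (T - 1, n).
    """
    X = Z[:-1]
    dXdt = (Z[1:] - Z[:-1]) / dt
    return X, dXdt


# Usage: x(t) = t sampled at dt = 0.5 has derivative 1 everywhere.
t = np.arange(5) * 0.5
Z = t[:, None]
X, dXdt = ct_pairs(Z, 0.5)
```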

dymad.training.ls_update.get_batch_dt(dataloader, model, dt, **kwargs)

Return type:

tuple[ndarray, ndarray]
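For a discrete-time fit, the natural least-squares pairs are each snapshot and its successor, so that a linear map `A` satisfies `Y ≈ X @ A`. The sketch below forms those shifted pairs and recovers a known decay factor. The helper `dt_pairs` and the `(T, n)` layout are hypothetical; get_batch_dt also receives `model` and `dt`, which this sketch does not use.

```python
from typing import Tuple

import numpy as np


def dt_pairs(Z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical: pair each snapshot z_k with its successor z_{k+1}.

    Z has shape (T, n); returns (X, Y), each of shape (T - 1, n).
    """
    return Z[:-1], Z[1:]


# Usage: a trajectory of x_{k+1} = 0.5 * x_k; least squares recovers the factor 0.5.
Z = (0.5 ** np.arange(6))[:, None]
X, Y = dt_pairs(Z)
A, *_ = np.linalg.lstsq(X, Y, rcond=None)
```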