dymad.training.ls_update
Module Attributes
SOL_MAP | Mapping of linear solver methods
Functions
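check_linear_impl(model) | Check whether the model implements the linear feature and evaluation methods.
check_linear_solve(model) | Check whether the model implements the linear_solve method.
get_batch_ct(dataloader, model, dt, **kwargs)
get_batch_dt(dataloader, model, dt, **kwargs)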
Classes
LSUpdater(method, model, dt=None, params=None, **kwargs) | Update linear weights by least squares.
- class dymad.training.ls_update.LSUpdater(method, model, dt=None, params=None, **kwargs)
Bases: object
Update linear weights by least squares.
- eval_batch(model, batch, criterion)
Process a batch and return predictions and ground-truth states.
Only used for evaluation in this Trainer.
- Return type:
Tensor
- update(model, dataloader)
Train the model for one epoch.
- Return type:
tuple[float, Any]
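The class docstring says only that linear weights are updated by least squares. As a rough illustration of that technique (a sketch, not dymad's actual implementation), the snippet below fits a weight matrix W minimizing ||Phi W - Y||_F for a feature matrix Phi and targets Y using numpy.linalg.lstsq. The helper name `ls_weights` and the (samples, features) / (samples, outputs) array layout are assumptions for illustration.

```python
import numpy as np


def ls_weights(phi: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return W minimizing ||phi @ W - y||_F (hypothetical helper, not dymad API).

    phi: (n_samples, n_features) feature matrix (assumed layout).
    y:   (n_samples, n_outputs) regression targets (assumed layout).
    """
    # lstsq returns (solution, residuals, rank, singular_values); keep the solution.
    # rcond=None uses a machine-precision-based cutoff for small singular values.
    w, *_ = np.linalg.lstsq(phi, y, rcond=None)
    return w


# Usage: recover a known linear map from noise-free data.
rng = np.random.default_rng(0)
phi = rng.standard_normal((50, 3))
w_true = np.array([[1.0, 0.0], [2.0, -1.0], [0.5, 3.0]])
y = phi @ w_true
w_est = ls_weights(phi, y)
print(np.allclose(w_est, w_true))  # True
```

Solving with `lstsq` (an SVD-based solver) rather than forming the normal equations (Phi^T Phi)^{-1} Phi^T Y avoids squaring the condition number of Phi.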
- dymad.training.ls_update.SOL_MAP = {
    'ct_full': <function _ct_full_der>,
    'ct_full_log': <function _ct_full_log>,
    'ct_raw': <function _ct_raw>,
    'ct_sako_log': <function _ct_sako_log>,
    'ct_truncated': <function _ct_truncated_der>,
    'ct_truncated_log': <function _ct_truncated_log>,
    'dt_full': <function _dt_full>,
    'dt_raw': <function _dt_raw>,
    'dt_sako': <function _dt_sako>,
    'dt_truncated': <function _dt_truncated>,
  }
Mapping from linear solver method names to their solver functions.
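SOL_MAP pairs each method name with a solver function, so a method string selects its solver by dictionary lookup. The sketch below mimics that dispatch pattern with two hypothetical discrete-time solvers for y ≈ A x: a full pseudoinverse solve and a rank-truncated solve via truncated SVD. The names `SOLVERS`, `solve`, `_full`, `_truncated`, the `rank` keyword, and the reading of "truncated" as SVD rank truncation are all assumptions, not dymad's code.

```python
from typing import Callable, Dict, Optional

import numpy as np


def _full(x: np.ndarray, y: np.ndarray, rank: Optional[int] = None) -> np.ndarray:
    """Fit A in y ≈ A @ x with the Moore-Penrose pseudoinverse (rank is ignored)."""
    return y @ np.linalg.pinv(x)


def _truncated(x: np.ndarray, y: np.ndarray, rank: Optional[int] = None) -> np.ndarray:
    """Fit A in y ≈ A @ x keeping only the leading `rank` singular triplets of x."""
    u, s, vh = np.linalg.svd(x, full_matrices=False)
    r = len(s) if rank is None else rank
    # Pseudoinverse of the rank-r approximation of x: V_r diag(1/s_r) U_r^T.
    x_pinv = vh[:r].T @ np.diag(1.0 / s[:r]) @ u[:, :r].T
    return y @ x_pinv


# Hypothetical registry in the spirit of SOL_MAP: method name -> solver function.
SOLVERS: Dict[str, Callable[..., np.ndarray]] = {
    "full": _full,
    "truncated": _truncated,
}


def solve(method: str, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
    """Look up the solver by name and apply it; unknown names raise ValueError."""
    try:
        solver = SOLVERS[method]
    except KeyError:
        raise ValueError(
            f"unknown method {method!r}; expected one of {sorted(SOLVERS)}"
        ) from None
    return solver(x, y, **kwargs)


# Usage: columns of x are states, columns of y are the states one step later.
rng = np.random.default_rng(1)
a_true = np.array([[0.9, 0.1], [0.0, 0.8]])
x = rng.standard_normal((2, 20))
y = a_true @ x
print(np.allclose(solve("full", x, y), a_true))               # True
print(np.allclose(solve("truncated", x, y, rank=2), a_true))  # True
```

A registry like this keeps the caller independent of individual solvers: adding a method means adding one dictionary entry.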
- dymad.training.ls_update.check_linear_impl(model)
Check whether the model implements the linear feature and evaluation methods.
Strictly, linear_eval and set_linear_weights should also be checked; this function does not currently verify them.
- Return type:
bool
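A check like this is typically a duck-typing test that the required methods exist and are callable. The minimal sketch below shows that idea with a hypothetical `has_methods` helper; the method name `linear_features` is an assumption, while `linear_eval` and `set_linear_weights` are the names mentioned in the docstring above. The example classes are likewise hypothetical.

```python
from typing import Iterable


def has_methods(obj: object, names: Iterable[str]) -> bool:
    """Return True if `obj` has a callable attribute for every name in `names`."""
    return all(callable(getattr(obj, name, None)) for name in names)


# Assumed method names; only linear_eval and set_linear_weights appear in the docs.
REQUIRED = ("linear_features", "linear_eval", "set_linear_weights")


class WithLinear:
    """Hypothetical model exposing all required methods."""

    def linear_features(self, x):
        return x

    def linear_eval(self, z):
        return z

    def set_linear_weights(self, w):
        self.w = w


class WithoutLinear:
    """Hypothetical model missing linear_eval and set_linear_weights."""

    def linear_features(self, x):
        return x


print(has_methods(WithLinear(), REQUIRED))     # True
print(has_methods(WithoutLinear(), REQUIRED))  # False
```

The same helper would express a linear_solve check with `names=("linear_solve",)`.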
- dymad.training.ls_update.check_linear_solve(model)
Check whether the model implements the linear_solve method.
- Return type:
bool
- dymad.training.ls_update.get_batch_ct(dataloader, model, dt, **kwargs)
- Return type:
tuple[ndarray, ndarray]
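get_batch_ct has no docstring. Its name, its `dt` argument, and the `ct_*_der` entries in SOL_MAP suggest that it pairs states with time-derivative estimates for continuous-time fitting, but that is an inference, not documented behavior. The sketch below shows one common way to build such pairs from a single uniformly sampled trajectory using numpy.gradient; the helper `derivative_pairs` is hypothetical and ignores dataloaders and models entirely.

```python
import numpy as np


def derivative_pairs(traj: np.ndarray, dt: float) -> tuple[np.ndarray, np.ndarray]:
    """Return (states, time derivatives) for a uniformly sampled trajectory.

    traj: (n_steps, n_states) array, assumed sampled every `dt` time units.
    """
    # Second-order central differences in the interior, one-sided at the ends.
    dxdt = np.gradient(traj, dt, axis=0)
    return traj, dxdt


# Usage: a trajectory that is linear in time has a constant derivative.
dt = 0.1
t = dt * np.arange(10)
traj = np.stack([1.0 + 2.0 * t, -3.0 * t], axis=1)
x, dxdt = derivative_pairs(traj, dt)
print(np.allclose(dxdt, [2.0, -3.0]))  # True
```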
- dymad.training.ls_update.get_batch_dt(dataloader, model, dt, **kwargs)
- Return type:
tuple[ndarray, ndarray]
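get_batch_dt is also undocumented. For discrete-time (`dt_*`) least-squares solvers, the usual training data are snapshot pairs (x_k, x_{k+1}), so the sketch below assumes that is roughly what the two returned arrays hold. The helper `shift_pairs` is hypothetical; it builds pairs trajectory by trajectory so that no pair straddles the boundary between two separate trajectories.

```python
import numpy as np


def shift_pairs(trajs: list[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Return stacked (x_k, x_{k+1}) pairs from a list of (n_steps, n_states) arrays."""
    xs, ys = [], []
    for traj in trajs:
        # Pair each state with its successor within the same trajectory only.
        xs.append(traj[:-1])
        ys.append(traj[1:])
    return np.concatenate(xs), np.concatenate(ys)


def simulate(a: np.ndarray, x0: np.ndarray, n_steps: int) -> np.ndarray:
    """Roll out x_{k+1} = a @ x_k, returning an (n_steps + 1, n_states) array."""
    states = [x0]
    for _ in range(n_steps):
        states.append(a @ states[-1])
    return np.stack(states)


# Usage: two trajectories of a linear discrete-time system.
a = np.array([[0.5, 0.0], [0.1, 0.9]])
trajs = [simulate(a, np.array([1.0, 1.0]), 5), simulate(a, np.array([-2.0, 0.5]), 3)]
x, y = shift_pairs(trajs)
print(x.shape, y.shape)          # (8, 2) (8, 2)
print(np.allclose(y, x @ a.T))   # True
```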