dymad.io.checkpoint

Functions

graph_data_prep(data, nnd)

load_model(model_class, checkpoint_path, *)

Load a model from a checkpoint and optionally record the boundary plan.

visualize_model([mdl_class, ...])

Classes

BoundaryLoadTrace(plan, model_ref)

DataInterface([model_class, ...])

Interface for data transforms, possibly with learned autoencoders.

class dymad.io.checkpoint.BoundaryLoadTrace(plan, model_ref)

Bases: object

model_ref: str
plan: PredictionWorkflowPlan

class dymad.io.checkpoint.DataInterface(model_class=None, checkpoint_path=None, config_path=None, config_mod=None, device=None)

Bases: object

Interface for data transforms, possibly with learned autoencoders.

It loads the model (if available) and data, sets up the necessary transformations, and provides methods to encode, decode, and apply observables.

Cases:

  • [Priority] checkpoint_path is given: load the data transforms and the model from the checkpoint. The model may contain autoencoders.

  • [Secondary] config_path and/or config_mod is given: instantiate the data transforms from the config. No model (and therefore no autoencoder) is loaded in this case.
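The priority between the two cases can be sketched as a small stand-alone function. `choose_source` is a hypothetical helper, not part of dymad, and raising an error when no source is given is an assumption made for illustration; the actual constructor may behave differently in that case.

```python
from typing import Any, Optional


def choose_source(checkpoint_path: Optional[str] = None,
                  config_path: Optional[str] = None,
                  config_mod: Any = None) -> str:
    """Report which documented case applies (hypothetical helper, not part of dymad)."""
    if checkpoint_path is not None:
        # [Priority] transforms and model (possibly with autoencoders) come from the
        # checkpoint, even if config_path or config_mod is also supplied.
        return "checkpoint"
    if config_path is not None or config_mod is not None:
        # [Secondary] transforms are built from the config; no model is loaded.
        return "config"
    # Assumption: behavior with no source at all is not documented; raise for illustration.
    raise ValueError("give checkpoint_path, or config_path and/or config_mod")
```

For example, `choose_source(checkpoint_path="run.ckpt", config_path="cfg.yaml")` returns `"checkpoint"`, because the checkpoint case takes priority.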

apply_obs(fobs)

Apply a generic observable to the raw data.

Parameters:

fobs (Callable) – Observable function. It should accept a 2D array with each row as one step and return a 1D array whose i-th entry corresponds to the i-th step.

Return type:

ndarray
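As a minimal sketch of the `fobs` contract, the hypothetical observable `energy_obs` below returns the squared Euclidean norm of each step, and the hypothetical `check_obs` verifies that a candidate function maps an array of shape `(n_steps, n_states)` to one of shape `(n_steps,)`. Neither function is part of dymad.

```python
import numpy as np


def energy_obs(X: np.ndarray) -> np.ndarray:
    """Example observable: squared Euclidean norm of each step (row) of X."""
    # X has shape (n_steps, n_states); reduce over the state axis.
    return np.sum(X**2, axis=1)


def check_obs(fobs, n_steps: int = 5, n_states: int = 3) -> np.ndarray:
    """Hypothetical helper: verify that fobs maps (n_steps, n_states) -> (n_steps,)."""
    X = np.arange(n_steps * n_states, dtype=float).reshape(n_steps, n_states)
    y = np.asarray(fobs(X))
    if y.shape != (n_steps,):
        raise ValueError(f"expected shape ({n_steps},), got {y.shape}")
    return y
```

With a constructed `DataInterface` instance `di`, such a function would be passed as `di.apply_obs(energy_obs)`.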

decode(X, rng=None)

Decode trajectory data from the observer space.

Return type:

ndarray

encode(X, rng=None)

Encode new trajectory data into the observer space.

Return type:

ndarray
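For a purely invertible transform, `decode` undoes `encode`; a learned autoencoder generally reconstructs the data only approximately. The sketch below uses a hypothetical standardizing transform, `ScalerTransform`, as a stand-in for dymad's data transforms to illustrate the round trip on a single trajectory of shape `(n_steps, n_states)`. It is not dymad's implementation, and it omits the `rng` argument, whose role is not documented here.

```python
import numpy as np


class ScalerTransform:
    """Hypothetical stand-in for an invertible data transform (not dymad's class)."""

    def __init__(self, X_ref: np.ndarray):
        # Per-state statistics computed over all steps of the reference data.
        self.mean = X_ref.mean(axis=0)
        self.std = X_ref.std(axis=0)
        self.std[self.std == 0] = 1.0  # avoid division by zero for constant states

    def encode(self, X: np.ndarray) -> np.ndarray:
        # Map raw trajectory data (n_steps, n_states) to the "observer" space.
        return (X - self.mean) / self.std

    def decode(self, Z: np.ndarray) -> np.ndarray:
        # Inverse map from the observer space back to the raw space.
        return Z * self.std + self.mean
```

Because this transform is exactly invertible, `decode(encode(X))` matches `X` up to floating-point error.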

get_backward_modes(ref=None, rng=None, **kwargs)

Return type:

ndarray

get_forward_modes(ref=None, rng=None, **kwargs)

Return type:

ndarray

dymad.io.checkpoint.graph_data_prep(data, nnd)

dymad.io.checkpoint.load_model(model_class, checkpoint_path, *, context=None, horizon=1, has_control=False, has_graph=False, return_trace=False)

Load a model from a checkpoint and optionally record the boundary plan.
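The optional trace can be pictured with the stand-in below. `TraceSketch` mirrors only the two documented fields of `BoundaryLoadTrace` (`plan` and `model_ref`), and `load_model_sketch` is hypothetical; in particular, returning a `(model, trace)` pair when `return_trace=True` is an assumption, not documented behavior of `load_model`.

```python
from dataclasses import dataclass
from typing import Any, Tuple, Union


@dataclass
class TraceSketch:
    """Hypothetical mirror of BoundaryLoadTrace's two documented fields."""
    plan: Any       # stands in for PredictionWorkflowPlan
    model_ref: str


def load_model_sketch(checkpoint_path: str, *, return_trace: bool = False
                      ) -> Union[Any, Tuple[Any, TraceSketch]]:
    """Hypothetical loader illustrating an optional-trace return pattern."""
    model = {"source": checkpoint_path}  # placeholder for the restored model
    if not return_trace:
        return model
    # Assumption: the trace is returned alongside the model when requested.
    trace = TraceSketch(plan=None, model_ref=checkpoint_path)  # plan is a placeholder
    return model, trace
```

Keeping `return_trace` keyword-only, as the documented signature of `load_model` does with its `*` marker, prevents the flag from being passed positionally by mistake.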

dymad.io.checkpoint.visualize_model(mdl_class=None, checkpoint_path=None, model=None, prd_func=None, ref_data=None, depth=1, device='cpu', ifsave=False, show_all_paths=False)