dymad.losses.losses
Module Attributes
LOSS_MAP | Mapping of loss names to loss classes.
Functions
vpt_loss(predictions, targets[, gamma]) | Exact version of VPT loss (not differentiable)
wmse_loss(predictions, targets[, alpha])
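Classes
VPTLoss([gamma, scl]) | Valid Prediction Time Loss
WMSELoss([alpha]) | Weighted Mean Squared Error Loss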
- dymad.losses.losses.LOSS_MAP = {'mae': <class 'torch.nn.modules.loss.L1Loss'>, 'mse': <class 'torch.nn.modules.loss.MSELoss'>, 'vpt': <class 'dymad.losses.losses.VPTLoss'>, 'wmse': <class 'dymad.losses.losses.WMSELoss'>}
Mapping of loss names to loss classes.
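LOSS_MAP lets a loss be chosen by name, for example from a configuration string. The sketch below shows that lookup pattern. To stay runnable without dymad installed, it uses a hypothetical stand-in dict `LOSS_MAP_SKETCH` holding only the two plain PyTorch entries (the real map also has 'vpt' and 'wmse'), and `make_loss` is a hypothetical helper, not part of dymad.

```python
import torch
from torch import nn

# Hypothetical stand-in for dymad.losses.losses.LOSS_MAP: only the entries that
# are plain PyTorch classes ('vpt' and 'wmse' would map to dymad's own classes).
LOSS_MAP_SKETCH = {
    "mae": nn.L1Loss,
    "mse": nn.MSELoss,
}


def make_loss(name: str, **kwargs) -> nn.Module:
    """Hypothetical helper: instantiate a loss class looked up by name."""
    if name not in LOSS_MAP_SKETCH:
        raise KeyError(f"unknown loss {name!r}; expected one of {sorted(LOSS_MAP_SKETCH)}")
    # Keyword arguments are forwarded to the class constructor.
    return LOSS_MAP_SKETCH[name](**kwargs)


loss_fn = make_loss("mse")
pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])
print(loss_fn(pred, target))  # mean of [0, 0, 4] -> tensor(1.3333)
```

With the real map, the same pattern applies to the dymad entries, e.g. `LOSS_MAP['wmse'](alpha=0.5)` would construct a WMSELoss.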
- class dymad.losses.losses.VPTLoss(gamma=0.1, scl=10.0)
Bases: torch.nn.Module
Valid Prediction Time Loss
The Valid Prediction Time is the time until the prediction error exceeds a threshold.
Specifically, at step k, for each dimension i, the error is
$$E_{k,i} = \frac{(x_{k,i} - \hat{x}_{k,i})^2}{\mathrm{std}(x_i)^2},$$
where $\mathrm{std}(x_i)$ is the standard deviation of dimension $i$, computed over that single trajectory. The total error at step $k$ is the root mean square over all $d$ dimensions,
$$E_k = \sqrt{\frac{1}{d}\sum_{i=1}^{d} E_{k,i}},$$
and the VPT is the largest step index $k$ such that $E_k < \gamma$.
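The definition can be computed directly for one trajectory. This is a minimal sketch, not dymad's `vpt_loss`: `exact_vpt` is a hypothetical name, inputs are assumed to have shape (n_steps, n_dims), "RMSE of E_{k,i}" is read as the square root of the mean of E_{k,i} over dimensions (an assumption), and returning -1 when no step qualifies is an arbitrary choice.

```python
import math

import torch


def exact_vpt(pred: torch.Tensor, target: torch.Tensor, gamma: float = 0.1) -> int:
    """Hypothetical helper: exact VPT of one trajectory of shape (n_steps, n_dims)."""
    # std(x_i): standard deviation of the true trajectory over time, per dimension.
    std = target.std(dim=0)
    # E_{k,i}: squared error normalized by the per-dimension variance.
    e_ki = (target - pred) ** 2 / std**2
    # E_k: square root of the mean of E_{k,i} over dimensions (assumed reading of "RMSE").
    e_k = e_ki.mean(dim=1).sqrt()
    # Largest step index k with E_k < gamma; -1 if no step qualifies.
    below = torch.nonzero(e_k < gamma).flatten()
    return int(below.max()) if below.numel() > 0 else -1


t = torch.linspace(0.0, 1.0, 11)
target = torch.stack([torch.sin(2 * math.pi * t), torch.cos(2 * math.pi * t)], dim=1)
pred = target.clone()
pred[6:] += 1.0  # exact up to step 5, then off by 1 in every dimension
print(exact_vpt(pred, target))  # 5
```

Because the result is a hard index, it provides no useful gradient with respect to the prediction, which is why training relies on a smooth estimate instead.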
For training, we estimate $k$ with a softmax (the exact index is not differentiable), average the resulting VPT over trajectories, and minimize the loss $1/\mathrm{VPT}$.
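One way to make the VPT trainable is sketched below, and it is explicitly not dymad's method: VPTLoss estimates $k$ with a softmax whose exact form is not documented here, whereas this sketch swaps in a sigmoid soft threshold combined with a cumulative product, which counts the leading steps that stay below $\gamma$ (the "time until the error exceeds a threshold" reading). `soft_vpt_loss`, `sharpness`, and `eps` are hypothetical names; `sharpness` does not have the same meaning as dymad's `scl`. Inputs are assumed to be shaped (n_traj, n_steps, n_dims).

```python
import torch


def soft_vpt_loss(pred, target, gamma=0.1, sharpness=200.0, eps=1e-12):
    """Hypothetical sigmoid-based VPT surrogate; inputs have shape (n_traj, n_steps, n_dims)."""
    # Per-trajectory, per-dimension standard deviation over time: (n_traj, 1, n_dims).
    std = target.std(dim=1, keepdim=True)
    # E_k for every trajectory and step; eps keeps sqrt differentiable at zero error.
    e_k = ((target - pred) ** 2 / (std**2 + eps)).mean(dim=-1).add(eps).sqrt()
    # Soft indicator that step k is below the threshold; sharpness sets the slope.
    below = torch.sigmoid(sharpness * (gamma - e_k))
    # Cumulative product ~ "every step so far is below gamma"; its sum is a soft step count.
    vpt = torch.cumprod(below, dim=1).sum(dim=1)
    # Average over trajectories and minimize 1/VPT (clamped to avoid division by zero).
    return 1.0 / vpt.mean().clamp(min=eps)


torch.manual_seed(0)
target = torch.randn(4, 20, 3)
pred = (target + 0.05 * torch.randn_like(target)).requires_grad_(True)
loss = soft_vpt_loss(pred, target)
loss.backward()  # gradients flow back to pred through the soft count
print(loss.item(), pred.grad.abs().sum().item())
```

A sharper threshold tracks the exact VPT more closely but yields vanishing gradients for steps far from $\gamma$; the `scl` parameter of VPTLoss presumably plays a similar role for its softmax, though that is an inference from the signature.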
- forward(predictions, targets)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type:
Tensor
- gamma: torch.Tensor
- scl: torch.Tensor
- training: bool
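The note on forward applies to both VPTLoss and WMSELoss: call the loss object itself rather than its forward method. The sketch below demonstrates the difference with torch.nn.MSELoss so that it runs without dymad; `register_forward_hook` is standard PyTorch, and the hook that records outputs is purely illustrative.

```python
import torch
from torch import nn

loss_fn = nn.MSELoss()
recorded = []
# Forward hooks run only when the module is invoked through __call__.
handle = loss_fn.register_forward_hook(
    lambda module, inputs, output: recorded.append(output.item())
)

pred = torch.tensor([0.0, 1.0])
target = torch.tensor([0.0, 3.0])

loss_fn(pred, target)          # hook runs: recorded becomes [2.0]
loss_fn.forward(pred, target)  # hook is silently skipped
handle.remove()
print(recorded)  # [2.0]
```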
- class dymad.losses.losses.WMSELoss(alpha=0)
Bases: torch.nn.Module
Weighted Mean Squared Error Loss
At step $i$, the loss is $w_i (x_i - \hat{x}_i)^2$, where $w_i$ is the weight for step $i$.
Currently, an exponential weighting is used: with $v_i = \exp(-\alpha i)$, the weights are $w_i = v_i / \sum_j v_j$.
When $\alpha = 0$, all steps receive equal weight and the loss reduces to the standard MSE loss. $\alpha$ can be positive or negative: $\alpha > 0$ favors early steps and $\alpha < 0$ favors late steps.
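The weighting scheme can be sketched in a few lines. This is a minimal illustration under stated assumptions, not dymad's `wmse_loss`: `wmse_sketch` is a hypothetical name, the step axis is assumed to be the second-to-last dimension (inputs shaped (..., n_steps, n_dims)), and the squared error at each step is averaged over the feature axis. `torch.softmax(-alpha * i)` is used because it equals $v_i / \sum_j v_j$ while avoiding overflow for large $|\alpha|$.

```python
import torch


def wmse_sketch(pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.0) -> torch.Tensor:
    """Hypothetical exponentially weighted MSE; the step axis is assumed to be dim -2."""
    n_steps = pred.shape[-2]
    steps = torch.arange(n_steps, dtype=pred.dtype, device=pred.device)
    # w_i = exp(-alpha * i) / sum_j exp(-alpha * j), computed stably via softmax.
    w = torch.softmax(-alpha * steps, dim=0)
    # Squared error at each step, averaged over the last (feature) axis: (..., n_steps).
    per_step = ((pred - target) ** 2).mean(dim=-1)
    # Weighted sum over steps, then mean over any leading (batch) axes.
    return (per_step * w).sum(dim=-1).mean()


torch.manual_seed(0)
pred, target = torch.randn(2, 10, 3), torch.randn(2, 10, 3)
# With alpha = 0 every step has weight 1/n_steps, recovering the ordinary MSE.
print(torch.allclose(wmse_sketch(pred, target), torch.nn.functional.mse_loss(pred, target)))  # True
```

For an error confined to step 0, the loss grows as $\alpha$ increases, consistent with positive $\alpha$ favoring early steps.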
- alpha: torch.Tensor
- forward(predictions, targets)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type:
Tensor
- training: bool
- dymad.losses.losses.vpt_loss(predictions, targets, gamma=0.1)
Exact version of the VPT loss (not differentiable).
- Return type:
tuple[Tensor, Tensor]
- dymad.losses.losses.wmse_loss(predictions, targets, alpha=0.0)
- Return type:
Tensor