dymad.losses.losses

Module Attributes

LOSS_MAP

Mapping of loss names to loss classes.

Functions

vpt_loss(predictions, targets[, gamma])

Exact version of VPT loss (not differentiable)

wmse_loss(predictions, targets[, alpha])

Weighted mean squared error loss (functional form)

Classes

VPTLoss([gamma, scl])

Valid Prediction Time Loss

WMSELoss([alpha])

Weighted Mean Squared Error Loss

dymad.losses.losses.LOSS_MAP = {'mae': <class 'torch.nn.modules.loss.L1Loss'>, 'mse': <class 'torch.nn.modules.loss.MSELoss'>, 'vpt': <class 'dymad.losses.losses.VPTLoss'>, 'wmse': <class 'dymad.losses.losses.WMSELoss'>}

Mapping of loss names to loss classes.

class dymad.losses.losses.VPTLoss(gamma=0.1, scl=10.0)

Bases: Module

Valid Prediction Time Loss

The Valid Prediction Time is the time until the prediction error exceeds a threshold.

Specifically, at step k, for each dimension i, the error is

E_{k,i} = (x_{k,i} - \hat{x}_{k,i})^2 / std(x_i)^2,

where std(x_i) is the standard deviation of the single trajectory in dimension i. The total error at step k is the root mean square over all n dimensions, E_k = \sqrt{\frac{1}{n}\sum_{i=1}^{n} E_{k,i}}. The VPT is defined as the largest step index k such that E_k < gamma.

For training, the exact index k (which is not differentiable) is replaced by a softmax-based estimate. The estimated VPT is averaged over trajectories, and the loss to minimize is 1/VPT.
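The exact definition can be checked numerically with a small NumPy sketch. It is not dymad's implementation, and it does not reproduce the softmax relaxation used in training. Several details are assumptions: the helper name `exact_vpt`, the (n_traj, n_steps, n_dim) array layout, and the population standard deviation (`ddof=0`; note that `torch.std` defaults to the unbiased estimator). The sketch also reads E_k as the square root of the mean of E_{k,i} and returns 0 when no step satisfies E_k < gamma.

```python
import numpy as np


def exact_vpt(predictions: np.ndarray, targets: np.ndarray, gamma: float = 0.1) -> np.ndarray:
    """Per-trajectory VPT; arrays assumed to have shape (n_traj, n_steps, n_dim)."""
    # std(x_i): standard deviation of each target trajectory over time, per dimension
    std = targets.std(axis=1, keepdims=True)            # (n_traj, 1, n_dim)
    # E_{k,i}: squared error normalized by the per-dimension variance
    e_ki = (targets - predictions) ** 2 / std ** 2      # (n_traj, n_steps, n_dim)
    # E_k: square root of the mean over dimensions (assumed reading of "RMSE")
    e_k = np.sqrt(e_ki.mean(axis=2))                    # (n_traj, n_steps)
    below = e_k < gamma                                 # steps whose error is under the threshold
    n_steps = e_k.shape[1]
    # Largest k with E_k < gamma: locate the first True when scanning backwards
    last = n_steps - 1 - np.argmax(below[:, ::-1], axis=1)
    # Hypothetical convention: VPT = 0 when no step qualifies
    return np.where(below.any(axis=1), last, 0)


# Two identical ramps x_k = k (k = 0..4), so std(x) = sqrt(2) in each trajectory
x = np.arange(5.0).reshape(1, 5, 1).repeat(2, axis=0)
err = np.array([[0.0, 0.05, 0.1, 0.5, 1.0],
                [0.0, 0.0, 0.0, 0.05, 0.5]])[:, :, None]
vpt = exact_vpt(x + err, x, gamma=0.1)
print(vpt)               # [2 3]
print(1.0 / vpt.mean())  # training objective 1/VPT with VPT averaged: 0.4
```

Because the definition uses the largest such k, a trajectory whose error exceeds gamma and later falls back below it is assigned the later index. The two descriptions ("time until the error exceeds a threshold" and "largest k with E_k < gamma") agree whenever the error grows monotonically, as in this example.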

forward(predictions, targets)

Compute the training loss 1/VPT between predictions and targets, where the VPT is the softmax-based estimate averaged over trajectories.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

gamma: torch.Tensor
scl: torch.Tensor
training: bool

class dymad.losses.losses.WMSELoss(alpha=0)

Bases: Module

Weighted Mean Squared Error Loss

At step i, the loss is w_i (x_i - \hat{x}_i)^2, where w_i is the weight for step i. The total loss is the sum of these terms over all steps.

Currently, an exponential weighting is used: v_i = \exp(-\alpha i) and w_i = v_i / \sum_j v_j.

When alpha = 0, every step receives the same weight 1/N for N steps, and the loss reduces to the standard MSE. alpha can be positive or negative: alpha > 0 favors early steps, and alpha < 0 favors late steps.
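The following NumPy sketch implements this weighting directly from the formulas above, not from dymad's source. The function name `wmse` and the (..., n_steps, n_dim) layout, with steps on the second-to-last axis, are assumptions. The sketch also assumes that the squared error is averaged over all non-step axes before weighting. That choice is what makes alpha = 0 coincide with the plain MSE.

```python
import numpy as np


def wmse(predictions: np.ndarray, targets: np.ndarray, alpha: float = 0.0) -> float:
    """Exponentially weighted MSE; step axis assumed second-to-last, i.e. (..., n_steps, n_dim)."""
    n_steps = targets.shape[-2]
    v = np.exp(-alpha * np.arange(n_steps))     # v_i = exp(-alpha * i)
    w = v / v.sum()                             # w_i = v_i / sum_j v_j, so sum_i w_i = 1
    sq = (predictions - targets) ** 2
    # Mean squared error of each step, averaged over every other axis
    per_step = np.moveaxis(sq, -2, 0).reshape(n_steps, -1).mean(axis=1)
    return float(np.dot(w, per_step))           # weighted sum over steps


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 10, 3))                 # (batch, n_steps, n_dim)
x_hat = x + 0.1 * rng.normal(size=x.shape)

# alpha = 0 gives uniform weights 1/n_steps, i.e. the plain MSE
print(np.isclose(wmse(x_hat, x, alpha=0.0), np.mean((x_hat - x) ** 2)))  # True

# Error only at the first step: alpha > 0 weights it more than alpha = 0, alpha < 0 less
z = np.zeros((1, 10, 1))
e = z.copy()
e[0, 0, 0] = 1.0
print(wmse(e, z, 1.0) > wmse(e, z, 0.0) > wmse(e, z, -1.0))  # True
```

Normalizing the weights to sum to 1 keeps the loss on the same scale as the MSE for every alpha. Only the relative emphasis between early and late steps changes.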

alpha: torch.Tensor

forward(predictions, targets)

Compute the weighted mean squared error between predictions and targets.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

training: bool

dymad.losses.losses.vpt_loss(predictions, targets, gamma=0.1)

Exact version of VPT loss (not differentiable)

Return type:

tuple[Tensor, Tensor]

dymad.losses.losses.wmse_loss(predictions, targets, alpha=0.0)

Weighted mean squared error loss (functional form)

Return type:

Tensor