dymad.losses

class dymad.losses.VPTLoss(gamma=0.1, scl=10.0)

Bases: Module

Valid Prediction Time Loss

The Valid Prediction Time (VPT) is the time until the prediction error exceeds a threshold.

Specifically, at step k, for each dimension i, the error is

E_{k,i} = \frac{(x_{k,i} - \hat{x}_{k,i})^2}{\mathrm{std}(x_i)^2},

where std(x_i) is the standard deviation of the single trajectory in dimension i. The total error at step k, E_k, is the root mean square of the normalized errors over all dimensions, i.e. the square root of the mean of E_{k,i} over i. The VPT is defined as the largest step index k such that E_k < gamma.
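The exact definition above can be sketched in NumPy as follows. This is an illustration, not dymad's implementation: the function name `vpt_exact`, the array layout (T time steps by n dimensions), the use of the true trajectory for std(x_i), the aggregation E_k = sqrt(mean_i E_{k,i}), and returning -1 when no step qualifies are all assumptions.

```python
import numpy as np

def vpt_exact(x: np.ndarray, x_hat: np.ndarray, gamma: float = 0.1) -> int:
    """Exact VPT for one trajectory (hypothetical sketch, not dymad's code).

    x, x_hat: arrays of shape (T, n), i.e. T time steps and n state dimensions.
    Returns the largest step index k with E_k < gamma, or -1 if there is none.
    """
    # Per-dimension standard deviation of the true trajectory over time.
    std = x.std(axis=0)
    # Normalized squared error E_{k,i}, shape (T, n).
    e_ki = (x - x_hat) ** 2 / std ** 2
    # Assumed aggregation over dimensions: E_k = sqrt(mean_i E_{k,i}).
    e_k = np.sqrt(e_ki.mean(axis=1))
    # Indices of all steps whose total error is below the threshold.
    below = np.nonzero(e_k < gamma)[0]
    return int(below[-1]) if below.size > 0 else -1

# Usage: a prediction that is exact until step 60 and then jumps away.
t = np.linspace(0, 10, 101)
x = np.stack([np.sin(t), np.cos(t)], axis=1)
x_hat = x.copy()
x_hat[60:] += 1.0
print(vpt_exact(x, x_hat))  # 59
```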

For training, the step index k is estimated with a softmax so that the loss is differentiable; the resulting VPT is averaged over trajectories, and the loss to minimize is 1/VPT.
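The softmax estimator itself, and the exact role of `scl`, are not spelled out here. The sketch below therefore swaps in a different smooth surrogate (a sigmoid soft indicator with a cumulative product) purely to show how a differentiable VPT estimate and the 1/VPT objective fit together. The names `soft_vpt` and `inverse_vpt_loss` are hypothetical, and the parameter `scl` in this sketch is our own sharpness knob, which may not match its meaning in `VPTLoss`.

```python
import numpy as np

def soft_vpt(e_k: np.ndarray, gamma: float = 0.1, scl: float = 10.0) -> np.ndarray:
    """Smooth VPT surrogate (sigmoid-based; NOT dymad's softmax estimator).

    e_k: array of shape (B, T), the total error E_k for each trajectory and step.
    Returns one smooth VPT value per trajectory, shape (B,).
    """
    # Soft indicator of E_k < gamma, computed as a numerically stable sigmoid:
    # sigmoid(z) = 0.5 * (1 + tanh(z / 2)), with z measured relative to gamma.
    z = scl * (1.0 - e_k / gamma)
    s = 0.5 * (1.0 + np.tanh(0.5 * z))
    # The cumulative product stays near 1 only while all earlier steps are below gamma,
    # so the sum approximates the number of leading steps with E_k < gamma.
    return np.cumprod(s, axis=1).sum(axis=1)

def inverse_vpt_loss(e_k: np.ndarray, gamma: float = 0.1, scl: float = 10.0) -> float:
    # Average the smooth VPT over trajectories, then take the reciprocal as the loss.
    return float(1.0 / soft_vpt(e_k, gamma, scl).mean())

# Usage: one trajectory fails after 5 steps, the other stays accurate for all 10.
e_k = np.array([[0.0] * 5 + [1.0] * 5, [0.0] * 10])
print(soft_vpt(e_k))          # close to [5, 10]
print(inverse_vpt_loss(e_k))  # close to 1 / 7.5
```

Note that the surrogate counts leading steps, so it approximates k + 1 rather than the step index k itself; an offset like this does not change where 1/VPT is minimized.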

forward(predictions, targets)

Compute the VPT loss from predictions and targets.

Note

Although the forward computation is defined in this method, call the module instance itself (e.g. loss_fn(predictions, targets)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.

Return type:

Tensor

gamma: torch.Tensor
scl: torch.Tensor

class dymad.losses.WMSELoss(alpha=0)

Bases: Module

Weighted Mean Squared Error Loss

At step i, the loss term is w_i (x_i - \hat{x}_i)^2, where w_i is the weight for step i.

Currently, an exponential weighting is used: with v_i = \exp(-\alpha i), the weights are w_i = v_i / \sum_j v_j.

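When alpha=0, all weights are equal and this reduces to the standard MSE loss. The parameter alpha may be positive or negative: a positive alpha gives more weight to early steps, and a negative alpha gives more weight to late steps.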
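The weighting above can be sketched in NumPy. This is an illustration, not dymad's implementation: the name `wmse`, the (T, n) array layout, and averaging the squared error over the n dimensions at each step are assumptions. Since w_i = v_i / \sum_j v_j with v_i = \exp(-\alpha i) is a softmax over the logits -\alpha i, the sketch subtracts the largest logit before exponentiating, which leaves the weights unchanged but avoids overflow for large negative alpha.

```python
import numpy as np

def wmse(x: np.ndarray, x_hat: np.ndarray, alpha: float = 0.0) -> float:
    """Exponentially weighted MSE over time steps (hypothetical sketch).

    x, x_hat: arrays of shape (T, n); the weights act along the first (time) axis.
    """
    steps = np.arange(x.shape[0])
    logits = -alpha * steps                 # log v_i = -alpha * i
    v = np.exp(logits - logits.max())       # shift by the max for numerical stability
    w = v / v.sum()                         # w_i = v_i / sum_j v_j
    # Squared error at each step, averaged over dimensions (assumption).
    sq_err = ((x - x_hat) ** 2).mean(axis=1)
    return float((w * sq_err).sum())

# Usage: errors only in the last 5 of 20 steps.
rng = np.random.default_rng(0)
x = rng.normal(size=(20, 3))
x_hat = x.copy()
x_hat[-5:] += 1.0
for a in (1.0, 0.0, -1.0):
    print(a, wmse(x, x_hat, alpha=a))
```

With these late-step errors, alpha=1.0 gives a much smaller loss than alpha=0.0, and alpha=-1.0 a larger one, matching the early-versus-late behavior described above.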

alpha: torch.Tensor

forward(predictions, targets)

Compute the weighted MSE loss from predictions and targets.

Note

Although the forward computation is defined in this method, call the module instance itself (e.g. loss_fn(predictions, targets)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.

Return type:

Tensor

dymad.losses.vpt_loss(predictions, targets, gamma=0.1)

Exact version of the VPT loss (not differentiable).

Return type:

tuple[Tensor, Tensor]

dymad.losses.wmse_loss(predictions, targets, alpha=0.0)

Return type:

Tensor
