dymad.modules.misc

Classes

TakeFirst(m)

Pass-through layer that returns the first m entries along the last axis.

TakeFirstGraph(m)

Graph version of TakeFirst.

class dymad.modules.misc.TakeFirst(m)

Bases: Module

Pass-through layer that returns the first m entries along the last axis.

Parameters:

m (int) – Number of entries to take from the last axis.
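The slicing behaviour described above can be illustrated with a minimal PyTorch sketch. It assumes that TakeFirst is equivalent to `x[..., :m]`. The class name `TakeFirstSketch` is hypothetical and is not dymad's implementation.

```python
import torch
from torch import nn


class TakeFirstSketch(nn.Module):
    """Hypothetical stand-in for dymad.modules.misc.TakeFirst (not the library's code)."""

    def __init__(self, m: int) -> None:
        super().__init__()
        self.m = m  # number of leading entries kept along the last axis

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The Ellipsis keeps every leading axis; only the last axis is truncated.
        return x[..., : self.m]


layer = TakeFirstSketch(m=2)
x = torch.arange(12.0).reshape(2, 2, 3)  # shape (2, 2, 3)
y = layer(x)
print(y.shape)  # torch.Size([2, 2, 2])
```

The layer has no learnable parameters, which fits the "pass-through" description: it only selects a view of its input.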

diagnostic_info()
Return type:

str

forward(x)
Return type:

Tensor

class dymad.modules.misc.TakeFirstGraph(m)

Bases: TakeFirst

Graph version of TakeFirst.

Input: (…, n_nodes, n_features)
Output: (…, n_nodes*m)
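The shape contract above is consistent with keeping the first m features of each node and then flattening the node and feature axes. The following is a sketch under that assumption; the node-major flattening order is a guess rather than something the documentation states. `TakeFirstGraphSketch` is a hypothetical name. Unlike the documented signature, the graph arguments here default to `None` for convenience, and the sketch ignores them.

```python
import torch
from torch import nn


class TakeFirstGraphSketch(nn.Module):
    """Hypothetical stand-in for TakeFirstGraph; the flattening order is an assumption."""

    def __init__(self, m: int) -> None:
        super().__init__()
        self.m = m

    def forward(self, x, edge_index=None, edge_weights=None, edge_attr=None, **kwargs):
        # Graph arguments are accepted for interface compatibility but unused here.
        first = x[..., : self.m]            # (..., n_nodes, m)
        return first.flatten(start_dim=-2)  # (..., n_nodes*m), node-major order


layer = TakeFirstGraphSketch(m=2)
x = torch.arange(24.0).reshape(2, 3, 4)   # batch of 2 graphs, 3 nodes, 4 features
edge_index = torch.tensor([[0, 1], [1, 2]])  # dummy edges, ignored by the sketch
out = layer(x, edge_index, None, None)
print(out.shape)  # torch.Size([2, 6])
```

With this ordering, the first m entries of the output belong to node 0, the next m to node 1, and so on.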

forward(x, edge_index, edge_weights, edge_attr, **kwargs)
Return type:

Tensor