Gate

class e3nn.nn.Activation(irreps_in, acts)

Bases: Module

Scalar activation function.

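Odd scalar inputs (0o) require activation functions with a defined parity (odd or even): an odd activation keeps the output odd, while an even activation makes it even. Even scalar inputs (0e) can use any activation and remain even.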
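The parity rule can be checked numerically. The plain-Python sketch below classifies a function as even, odd, or neither by comparing f(x) with f(-x) on sample points, then derives the parity of act(x) for a scalar input. The helper names act_parity and output_parity, the sampling grid, and the tolerance are illustrative assumptions, not e3nn's API or its exact internal check.

```python
import math
from typing import Callable


# Hypothetical helper: classify a scalar function as even (+1), odd (-1) or neither (0)
# by comparing f(x) and f(-x) on a grid of sample points.
def act_parity(f: Callable[[float], float], tol: float = 1e-9) -> int:
    xs = [0.1 * i for i in range(1, 101)]  # sample points in (0, 10]
    if all(abs(f(x) - f(-x)) < tol for x in xs):
        return 1   # even: f(-x) == f(x)
    if all(abs(f(x) + f(-x)) < tol for x in xs):
        return -1  # odd: f(-x) == -f(x)
    return 0       # no defined parity


# Hypothetical helper: parity of act(x) for a scalar input of parity p_in (+1 = 0e, -1 = 0o).
def output_parity(p_in: int, f: Callable[[float], float]) -> int:
    if p_in == 1:
        return 1  # an even scalar is unchanged by inversion, so f(x) is too
    p_act = act_parity(f)
    if p_act == 0:
        raise ValueError("an odd scalar input needs an odd or even activation")
    return p_act  # odd input: the output inherits the activation's parity


print(output_parity(-1, abs))        # 1, i.e. 0o -> 0e, as with torch.abs in the example
print(output_parity(-1, math.tanh))  # -1, i.e. 0o stays 0o
```

A function such as math.exp has no parity, so output_parity(-1, math.exp) raises, while output_parity(1, math.exp) is simply 1.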

Parameters:
  • irreps_in (e3nn.o3.Irreps) – representation of the input

  • acts (list of function or None) – one activation function per irrep group in irreps_in; use None for non-scalar groups or for the identity

Examples

>>> a = Activation("256x0o", [torch.abs])
>>> a.irreps_out
256x0e
>>> a = Activation("256x0o+16x1e", [None, None])
>>> a.irreps_out
256x0o+16x1e

Methods:

forward(features[, dim])

Evaluate the activation function.

forward(features, dim: int = -1)

Evaluate the activation function.

Parameters:

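features (torch.Tensor) – tensor of shape (...), with the irreps laid out along dimension dim (of size irreps_in.dim)

dim (int) – dimension along which the irreps are laid out (default: -1)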

Returns:

tensor of the same shape as the input

Return type:

torch.Tensor
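The evaluation can be pictured as walking through the irrep groups along the feature dimension: each scalar group with an activation has it applied element-wise, and every group whose entry is None is copied unchanged. The plain-Python sketch below assumes a flat list of features (the dim=-1 case only) and irreps encoded as hypothetical (mul, l) tuples; activation_forward is an illustrative name, not e3nn's implementation, and rejecting an activation on a non-scalar group is the sketch's own choice.

```python
from typing import Callable, List, Optional, Sequence, Tuple

IrrepBlocks = Sequence[Tuple[int, int]]  # hypothetical encoding: (multiplicity, l)


def activation_forward(
    features: List[float],
    irreps: IrrepBlocks,
    acts: Sequence[Optional[Callable[[float], float]]],
) -> List[float]:
    if len(irreps) != len(acts):
        raise ValueError("need exactly one activation (or None) per irrep group")
    dim = sum(mul * (2 * l + 1) for mul, l in irreps)
    if len(features) != dim:
        raise ValueError(f"expected {dim} features, got {len(features)}")

    out: List[float] = []
    index = 0
    for (mul, l), act in zip(irreps, acts):
        size = mul * (2 * l + 1)  # each copy of an l irrep has 2l+1 components
        block = features[index : index + size]
        if act is None:
            out.extend(block)  # identity: copy the group unchanged
        elif l == 0:
            out.extend(act(x) for x in block)  # scalars: apply act element-wise
        else:
            raise ValueError("activations can only act on scalar (l = 0) groups")
        index += size
    return out  # same length as the input


# "2x0o+1x1e" with acts [abs, None]: the scalars are rectified, the vector is untouched
print(activation_forward([-1.0, 2.0, 0.5, -0.5, 3.0], [(2, 0), (1, 1)], [abs, None]))
# [1.0, 2.0, 0.5, -0.5, 3.0]
```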

class e3nn.nn.Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)

Bases: Module

Gate activation function.

The gate activation is a direct sum of two sets of irreps. The first set is irreps_scalars passed through the activation functions act_scalars. The second set is irreps_gated, with each irrep multiplied by one scalar from irreps_gates after those scalars have been passed through the activation functions act_gates. Mathematically, this can be written as:

\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]

where \(x_i\) and \(\phi_i\) are from irreps_scalars and act_scalars, and \(g_j\), \(\phi_j\), and \(y_j\) are from irreps_gates, act_gates, and irreps_gated.
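A short sketch of why this construction is equivariant, using notation introduced only for this note: let \(D_j(R)\) be the matrix of the irrep of \(y_j\) for a rotation \(R\), and \(p(\cdot)\) the parity of a quantity. The scalars \(x_i\) and \(g_j\) are invariant under rotations, so \(\phi_i(x_i)\) and \(\phi_j(g_j)\) are too, and a scalar factor commutes with \(D_j(R)\). Under inversion, the parity of the product is the product of the parities:

\[\phi_j(g_j)\,\big(D_j(R)\,y_j\big) = D_j(R)\,\big(\phi_j(g_j)\,y_j\big), \qquad p\big(\phi_j(g_j)\,y_j\big) = p\big(\phi_j(g_j)\big)\,p(y_j)\]

For instance, an odd gate (a 0o scalar passed through an odd function such as tanh) turns a gated 1e into 1o.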

The parameters passed in should adhere to the following conditions:

  1. len(irreps_scalars) == len(act_scalars).

  2. len(irreps_gates) == len(act_gates).

  3. irreps_gates.num_irreps == irreps_gated.num_irreps.
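Conditions 1 and 2 count irrep groups, while condition 3 counts individual irreps (the sum of multiplicities). The plain-Python sketch below checks all three on irreps strings; the (mul, l, parity) tuple encoding, the parser, and the names parse_irreps, num_irreps and check_gate_args are hypothetical stand-ins for e3nn.o3.Irreps, not e3nn's own validation code.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical encoding of an irreps string: (multiplicity, l, parity), parity +1 = "e", -1 = "o".
IrrepList = List[Tuple[int, int, int]]


def parse_irreps(s: str) -> IrrepList:
    """Parse strings such as "16x1e+16x1o"; a bare "0e" means multiplicity 1."""
    out: IrrepList = []
    for term in filter(None, s.split("+")):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def num_irreps(irreps: IrrepList) -> int:
    # every copy counts, so this is the sum of the multiplicities
    return sum(mul for mul, _, _ in irreps)


def check_gate_args(irreps_scalars: str, act_scalars: Sequence[Callable],
                    irreps_gates: str, act_gates: Sequence[Callable],
                    irreps_gated: str) -> None:
    scalars, gates, gated = map(parse_irreps, (irreps_scalars, irreps_gates, irreps_gated))
    if len(scalars) != len(act_scalars):  # condition 1: one function per scalar group
        raise ValueError("len(irreps_scalars) != len(act_scalars)")
    if len(gates) != len(act_gates):  # condition 2: one function per gate group
        raise ValueError("len(irreps_gates) != len(act_gates)")
    if num_irreps(gates) != num_irreps(gated):  # condition 3: one gate scalar per gated irrep
        raise ValueError("irreps_gates.num_irreps != irreps_gated.num_irreps")


# "16x1e+16x1o" has 2 groups but 32 irreps, matching the 32 gates in "32x0o"
check_gate_args("16x0o", [math.tanh], "32x0o", [math.tanh], "16x1e+16x1o")
print(len(parse_irreps("16x1e+16x1o")), num_irreps(parse_irreps("16x1e+16x1o")))  # 2 32
```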

Parameters:
  • irreps_scalars (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_scalars.

  • act_scalars (list of function or None) – Activation functions acting on the scalars.

  • irreps_gates (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_gates and multiplied by the irreps_gated.

  • act_gates (list of function or None) – Activation functions acting on the gates. The number of functions in the list should match the number of irrep groups in irreps_gates.

  • irreps_gated (e3nn.o3.Irreps) – Representation of the gated tensors. Must satisfy irreps_gates.num_irreps == irreps_gated.num_irreps.

Examples

>>> g = Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")
>>> g.irreps_out
16x0o+16x1o+16x1e
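The irreps_out in this example can be reproduced by parity bookkeeping: each activated scalar group follows the Activation parity rule, and each gated irrep copy has its parity multiplied by that of its activated gate. The plain-Python sketch below does this; the string parser, the tuple encoding, passing each activation's parity (+1 even, -1 odd) instead of the function itself, and the name gate_irreps_out are assumptions for illustration. Gated copies are regrouped only so the result reads like the printed irreps.

```python
from typing import List, Tuple

IrrepList = List[Tuple[int, int, int]]  # hypothetical: (mul, l, parity), +1 = "e", -1 = "o"


def parse(s: str) -> IrrepList:
    out: IrrepList = []
    for term in filter(None, s.split("+")):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def fmt(irreps: IrrepList) -> str:
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, l, p in irreps)


def act_out_parity(p_in: int, p_act: int) -> int:
    # even scalars stay even; odd scalars take the activation's parity
    return 1 if p_in == 1 else p_act


def gate_irreps_out(irreps_scalars: str, scalar_act_parities: List[int],
                    irreps_gates: str, gate_act_parities: List[int],
                    irreps_gated: str) -> str:
    scalars = [(mul, 0, act_out_parity(p, pa))
               for (mul, _, p), pa in zip(parse(irreps_scalars), scalar_act_parities)]
    # one parity per individual gate scalar, after its activation
    gate_parities = [act_out_parity(p, pa)
                     for (mul, _, p), pa in zip(parse(irreps_gates), gate_act_parities)
                     for _ in range(mul)]
    gated: IrrepList = []
    k = 0
    for mul, l, p in parse(irreps_gated):
        for _ in range(mul):
            new = (l, p * gate_parities[k])  # multiplying by a scalar multiplies parities
            k += 1
            if gated and gated[-1][1:] == new:
                gated[-1] = (gated[-1][0] + 1, *new)  # extend the current group
            else:
                gated.append((1, *new))
    return fmt(scalars + gated)


# torch.tanh is odd, so both activations are given parity -1
print(gate_irreps_out("16x0o", [-1], "32x0o", [-1], "16x1e+16x1o"))
# 16x0o+16x1o+16x1e
```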

Methods:

forward(features)

Evaluate the gated activation function.

Attributes:

irreps_in

Input representations.

irreps_out

Output representations.

forward(features)

Evaluate the gated activation function.

Parameters:

features (torch.Tensor) – tensor of shape (..., irreps_in.dim)

Returns:

tensor of shape (..., irreps_out.dim)

Return type:

torch.Tensor
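Once the input has been split into its scalar, gate, and gated parts, the evaluation reduces to the formula above. The plain-Python sketch below assumes that split has already happened (how Gate orders the parts inside irreps_in is not reproduced here), uses a single activation for all scalars and a single one for all gates, and stores each gated irrep copy as a list of its 2l+1 components; gate_forward is a hypothetical name, not part of e3nn.

```python
import math
from typing import Callable, List, Tuple


def gate_forward(
    scalars: List[float],
    act_scalar: Callable[[float], float],
    gates: List[float],
    act_gate: Callable[[float], float],
    gated: List[List[float]],  # one list of 2l+1 components per gated irrep copy
) -> Tuple[List[float], List[List[float]]]:
    """Hypothetical sketch of phi(x_i) (+) phi(g_j) * y_j on pre-split inputs."""
    if len(gates) != len(gated):
        raise ValueError("need one gate scalar per gated irrep copy")
    out_scalars = [act_scalar(x) for x in scalars]
    # each activated gate scales every component of its irrep copy by the same number
    out_gated = [[a * c for c in y] for a, y in zip(map(act_gate, gates), gated)]
    return out_scalars, out_gated


def relu(x: float) -> float:
    return max(x, 0.0)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


print(gate_forward([-1.0, 0.5], relu, [0.0], sigmoid, [[2.0, -4.0, 6.0]]))
# ([0.0, 0.5], [[1.0, -2.0, 3.0]])
```

Because each gated copy is only rescaled by a number, any rotation applied to a copy before gating gives the same result as applying it afterwards.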

property irreps_in

Input representations.

property irreps_out

Output representations.