# Gate

class e3nn.nn.Activation(irreps_in, acts)

Bases: torch.nn.modules.module.Module

Scalar activation function.

Odd scalar inputs require activation functions with a defined parity (odd or even).

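Parameters
• irreps_in (e3nn.o3.Irreps) – Representation of the input.

• acts (list of function or None) – Activation functions, one per irrep group in irreps_in. Use None for non-scalar irreps or to leave a group unchanged.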

Examples

>>> a = Activation("256x0o", [torch.abs])
>>> a.irreps_out
256x0e

>>> a = Activation("256x0o+16x1e", [None, None])
>>> a.irreps_out
256x0o+16x1e
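
The parity requirement can be illustrated with a small numerical check. The sketch below is plain Python and does not use e3nn; `parity_of` is a hypothetical helper introduced only for illustration, and the sample points and tolerance are arbitrary assumptions. An even function such as `abs` turns odd scalars into even ones (as in the `256x0o` to `256x0e` example above), an odd function such as `tanh` preserves the parity, and a function with neither property (for example ReLU) has no defined parity.

```python
import math


def parity_of(act, samples=(0.1, 0.5, 1.0, 2.0), tol=1e-9):
    """Classify a scalar function as 'even', 'odd', or None using a few sample points."""
    is_even = all(abs(act(-x) - act(x)) < tol for x in samples)
    is_odd = all(abs(act(-x) + act(x)) < tol for x in samples)
    if is_even:
        return "even"
    if is_odd:
        return "odd"
    return None  # neither even nor odd: parity is undefined


def relu(x):
    # Hypothetical stand-in for a ReLU activation on a single float.
    return max(x, 0.0)


print(parity_of(abs))        # even
print(parity_of(math.tanh))  # odd
print(parity_of(relu))       # None
```

Checking a handful of points can only suggest a parity, not prove it, so treat the result as a sanity check rather than a guarantee.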


Methods:

• forward(features[, dim]) – Evaluate the activation function.

forward(features, dim=-1)

Evaluate the activation function.

Parameters

features (torch.Tensor) – tensor of shape (...)

Returns

tensor of the same shape as the input

Return type

torch.Tensor

class e3nn.nn.Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)

Bases: torch.nn.modules.module.Module

Gate activation function.

The gate activation is a direct sum of two sets of irreps. The first set of irreps is irreps_scalars passed through activation functions act_scalars. The second set of irreps is irreps_gated multiplied by the scalars irreps_gates passed through activation functions act_gates. Mathematically, this can be written as:

$$\left(\bigoplus_i \phi_i(x_i)\right) \oplus \left(\bigoplus_j \phi_j(g_j)\, y_j\right)$$

where $x_i$ and $\phi_i$ come from irreps_scalars and act_scalars, and $g_j$, $\phi_j$, and $y_j$ come from irreps_gates, act_gates, and irreps_gated, respectively.
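
To make the formula concrete, here is a minimal sketch of the same arithmetic on plain Python lists. It does not use e3nn: `gate_forward` is a hypothetical helper, a single activation is assumed for each group, and any normalization the library may apply to the activation functions is ignored. Each gate $g_j$ is a single scalar that multiplies every component of the corresponding $y_j$.

```python
import math


def gate_forward(scalars, gates, gated, act_scalar, act_gate):
    """Apply the gate formula to plain Python lists.

    scalars: list of floats (the x_i)
    gates:   list of floats (the g_j), one per gated irrep
    gated:   list of lists (the y_j), e.g. 3 components each for l = 1
    """
    out_scalars = [act_scalar(x) for x in scalars]
    # One activated gate value scales all components of its irrep.
    out_gated = [[act_gate(g) * c for c in y] for g, y in zip(gates, gated)]
    # Direct sum: concatenate the activated scalars and the flattened gated irreps.
    return out_scalars + [c for y in out_gated for c in y]


out = gate_forward(
    scalars=[0.0],
    gates=[0.0, 100.0],  # tanh(0) = 0 closes the first gate, tanh(100) ~ 1 opens the second
    gated=[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    act_scalar=math.tanh,
    act_gate=math.tanh,
)
print(out)
```

The first vector is suppressed entirely and the second passes through almost unchanged, which is the sense in which the scalars act as gates.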

The parameters passed in should adhere to the following conditions:

1. len(irreps_scalars) == len(act_scalars).

2. len(irreps_gates) == len(act_gates).

3. irreps_gates.num_irreps == irreps_gated.num_irreps.
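
The three conditions above can be sketched as a standalone check. This is an illustration only, not the library's validation code: `parse_irreps` and `check_gate_args` are hypothetical helpers, and the parser assumes simple strings of the form `"16x1e+16x1o"`. It treats `len(...)` as the number of `mul x irrep` entries and `num_irreps` as the sum of the multiplicities.

```python
def parse_irreps(text):
    """Parse a string such as "16x1e+16x1o" into a list of (mul, l, parity) tuples."""
    if not text.strip():
        return []
    entries = []
    for term in text.split("+"):
        term = term.strip()
        # A missing multiplicity (e.g. "1e") is read as multiplicity 1.
        mul, ir = term.split("x") if "x" in term else ("1", term)
        entries.append((int(mul), int(ir[:-1]), ir[-1]))
    return entries


def num_irreps(entries):
    """Total number of irreps, counting multiplicities."""
    return sum(mul for mul, _, _ in entries)


def check_gate_args(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated):
    """Return True if the three conditions on the Gate arguments hold."""
    scalars = parse_irreps(irreps_scalars)
    gates = parse_irreps(irreps_gates)
    gated = parse_irreps(irreps_gated)
    return (
        len(scalars) == len(act_scalars)
        and len(gates) == len(act_gates)
        and num_irreps(gates) == num_irreps(gated)
    )


# Only the list lengths matter here, so strings stand in for the activation functions.
print(check_gate_args("16x0o", ["tanh"], "32x0o", ["tanh"], "16x1e+16x1o"))  # True
print(check_gate_args("16x0o", ["tanh"], "32x0o", ["tanh"], "16x1e"))        # False
```

In the second call the 32 gates face only 16 gated irreps, so condition 3 fails.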

Parameters
• irreps_scalars (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_scalars.

• act_scalars (list of function or None) – Activation functions acting on the scalars.

• irreps_gates (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_gates and multiplied by the irreps_gated.

• act_gates (list of function or None) – Activation functions acting on the gates. The number of functions in the list should match the number of irrep groups in irreps_gates.

• irreps_gated (e3nn.o3.Irreps) – Representation of the gated tensors. Must satisfy irreps_gates.num_irreps == irreps_gated.num_irreps.

Examples

>>> g = Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")
>>> g.irreps_out
16x0o+16x1o+16x1e


Methods:

• forward(features) – Evaluate the gated activation function.

Attributes:

• irreps_in – Input representations.

• irreps_out – Output representations.

forward(features)

Evaluate the gated activation function.

Parameters

features (torch.Tensor) – tensor of shape (..., irreps_in.dim)

Returns

tensor of shape (..., irreps_out.dim)

Return type

torch.Tensor

property irreps_in

Input representations.

property irreps_out

Output representations.