Gate
- class e3nn.nn.Activation(irreps_in, acts)
Bases: torch.nn.modules.module.Module
Scalar activation function.
Odd scalar inputs require activation functions with a defined parity (odd or even).
- Parameters
irreps_in (e3nn.o3.Irreps) – representation of the input
acts (list of function or None) – list of activation functions, one per irrep group of irreps_in; use None for a non-scalar irrep or for the identity
Examples
>>> a = Activation("256x0o", [torch.abs])
>>> a.irreps_out
256x0e
>>> a = Activation("256x0o+16x1e", [None, None])
>>> a.irreps_out
256x0o+16x1e
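The parity rule stated above (an odd scalar input needs an odd or even activation) can be illustrated with a small sketch. The helpers activation_parity and output_parity are hypothetical, written for this page under the assumption that parity can be detected by sampling the function on a grid of points; they are not part of the e3nn API.

```python
import torch


def activation_parity(act, n: int = 256, tol: float = 1e-6) -> int:
    """Hypothetical helper: classify act as even (+1), odd (-1) or neither (0).

    Compares act(-x) with act(x) on a grid of sample points.
    """
    x = torch.linspace(0.0, 10.0, n)
    y_pos, y_neg = act(x), act(-x)
    if (y_pos - y_neg).abs().max() < tol:
        return 1   # even: act(-x) == act(x)
    if (y_pos + y_neg).abs().max() < tol:
        return -1  # odd: act(-x) == -act(x)
    return 0       # no defined parity


def output_parity(input_parity: int, act) -> int:
    """Parity of act(x) for a scalar x of the given parity (+1 even, -1 odd)."""
    if input_parity == 1:
        return 1  # an even scalar stays even under any pointwise function
    p = activation_parity(act)
    if p == 0:
        raise ValueError("an odd scalar input needs an odd or even activation")
    return p  # even act -> even output, odd act -> odd output


print(output_parity(-1, torch.abs))   # 1, i.e. 0o becomes 0e as in the first example
print(output_parity(-1, torch.tanh))  # -1, i.e. 0o stays 0o
```

A function such as torch.relu is neither even nor odd, so under this sketch it is rejected for odd scalars but accepted for even ones.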
Methods:
forward(features[, dim]) – Evaluate the activation functions.
- forward(features, dim=-1)
Evaluate the activation functions.
- Parameters
features (torch.Tensor) – tensor of shape (...)
- Returns
tensor with the same shape as the input
- Return type
torch.Tensor
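To illustrate why forward preserves the input shape and what its dim argument selects, here is a minimal sketch assuming each activation acts on one contiguous block of channels along dim and None leaves its block untouched. The function apply_scalar_acts and its blocks argument are hypothetical and are not e3nn's implementation.

```python
import torch


def apply_scalar_acts(features, blocks, dim=-1):
    """Hypothetical sketch: apply one activation per block of channels along dim.

    blocks: list of (size, act) pairs; act=None means identity.
    """
    outputs = []
    start = 0
    for size, act in blocks:
        chunk = features.narrow(dim, start, size)  # view of this block
        outputs.append(chunk if act is None else act(chunk))
        start += size
    if start != features.shape[dim]:
        raise ValueError("block sizes must cover the whole dimension")
    return torch.cat(outputs, dim=dim)  # same shape as the input


# Like Activation("256x0o+16x1e", [torch.abs, None]): 256 scalars, then 16 vectors of size 3
x = torch.randn(4, 256 + 16 * 3)
y = apply_scalar_acts(x, [(256, torch.abs), (48, None)])
```

The vector block passes through unchanged because, in this sketch, None stands for the identity, which is the only choice that keeps non-scalar irreps equivariant.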
- class e3nn.nn.Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
Bases: torch.nn.modules.module.Module
Gate activation function.
The gate activation is a direct sum of two sets of irreps. The first set is irreps_scalars passed through the activation functions act_scalars. The second set is irreps_gated, with each gated irrep multiplied by one of the scalars irreps_gates after those scalars are passed through the activation functions act_gates. Mathematically, this can be written as:
\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]
where \(x_i\) and \(\phi_i\) are from irreps_scalars and act_scalars, and \(g_j\), \(\phi_j\), and \(y_j\) are from irreps_gates, act_gates, and irreps_gated.
The parameters passed in should adhere to the following conditions:
1. len(irreps_scalars) == len(act_scalars).
2. len(irreps_gates) == len(act_gates).
3. irreps_gates.num_irreps == irreps_gated.num_irreps.
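The formula above can be sketched in plain PyTorch for a special case: the scalars and gates are assumed to be even scalars (0e) and every gated irrep is assumed to be an l = 1 vector stored along a trailing axis of size 3. The function gate, the default activations tanh and sigmoid, and the helper rotation_z are illustrative assumptions, not the e3nn module. The rotation at the end shows why the construction is equivariant: the scalar channels ignore the rotation, and each vector y_j is only rescaled by the rotation-invariant factor phi_j(g_j).

```python
import math

import torch


def gate(scalars, gates, gated, act_scalars=torch.tanh, act_gates=torch.sigmoid):
    """Sketch of (⊕_i φ(x_i)) ⊕ (⊕_j φ(g_j) y_j) for l = 1 gated vectors.

    scalars: (..., n_s)     the x_i
    gates:   (..., n_g)     the g_j, one per gated vector
    gated:   (..., n_g, 3)  the y_j
    """
    # Scale each vector y_j by its activated gate φ(g_j)
    gated_out = act_gates(gates).unsqueeze(-1) * gated
    # Direct sum: activated scalars followed by the flattened gated vectors
    return torch.cat([act_scalars(scalars), gated_out.flatten(-2)], dim=-1)


def rotation_z(theta):
    """3x3 rotation matrix about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


torch.manual_seed(0)
x, g, y = torch.randn(8, 4), torch.randn(8, 5), torch.randn(8, 5, 3)
R = rotation_z(0.7)

out = gate(x, g, y)            # shape (8, 4 + 5 * 3)
out_rot = gate(x, g, y @ R.T)  # same inputs with every vector rotated
```

Rotating the input vectors leaves the first 4 output channels unchanged and rotates the remaining 15 channels as 5 vectors.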
- Parameters
irreps_scalars (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_scalars.
act_scalars (list of function or None) – Activation functions acting on the scalars. The number of functions in the list should match the number of irrep groups in irreps_scalars.
irreps_gates (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_gates and multiplied by irreps_gated.
act_gates (list of function or None) – Activation functions acting on the gates. The number of functions in the list should match the number of irrep groups in irreps_gates.
irreps_gated (e3nn.o3.Irreps) – Representation of the gated tensors. Must satisfy irreps_gates.num_irreps == irreps_gated.num_irreps.
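The three conditions can be checked before building the module. The sketch below is a hypothetical pure-Python check with a minimal parser for strings such as "16x1e+16x1o"; parse_irreps and check_gate_args are illustrative names, not e3nn.o3.Irreps, which handles many more cases. It shows the distinction the conditions rely on: len counts (multiplicity, irrep) groups, while num_irreps adds up the multiplicities.

```python
import re


def parse_irreps(s):
    """Hypothetical minimal parser: "16x1e+16x1o" -> [(16, 1, 1), (16, 1, -1)]."""
    if not s.strip():
        return []
    out = []
    for term in s.split("+"):
        m = re.fullmatch(r"\s*(?:(\d+)x)?(\d+)([eo])\s*", term)
        if m is None:
            raise ValueError(f"cannot parse {term!r}")
        mul = int(m.group(1) or 1)  # a missing multiplicity means 1
        out.append((mul, int(m.group(2)), 1 if m.group(3) == "e" else -1))
    return out


def check_gate_args(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated):
    scalars, gates, gated = map(parse_irreps, (irreps_scalars, irreps_gates, irreps_gated))
    # Conditions 1 and 2: one activation per (mul, irrep) group
    if len(scalars) != len(act_scalars):
        raise ValueError("need one activation per group in irreps_scalars")
    if len(gates) != len(act_gates):
        raise ValueError("need one activation per group in irreps_gates")
    # Condition 3: num_irreps sums multiplicities, one gate per gated irrep
    if sum(m for m, _, _ in gates) != sum(m for m, _, _ in gated):
        raise ValueError("irreps_gates.num_irreps must equal irreps_gated.num_irreps")


# Arguments of the example below: 1 group / 1 function twice, and 32 == 16 + 16
check_gate_args("16x0o", [abs], "32x0o", [abs], "16x1e+16x1o")
```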
Examples
>>> g = Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")
>>> g.irreps_out
16x0o+16x1o+16x1e
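The output irreps in this example follow from parity bookkeeping: a pointwise activation keeps even scalars even and gives odd scalars the parity of the function, and multiplying an irrep by a scalar gate keeps its degree l while multiplying the parities. The hypothetical helpers below encode these rules under a simplifying assumption: a single parity (+1 even, -1 odd) is given for all functions in each activation list, here -1 because torch.tanh is odd.

```python
def after_act(irreps, act_parity):
    """Even scalars stay even; odd scalars take the activation's parity."""
    return [(mul, l, 1 if p == 1 else act_parity) for mul, l, p in irreps]


def gate_irreps_out(scalars, scalar_act_parity, gates, gate_act_parity, gated):
    # One parity per individual gate, in order
    gate_parities = [p for mul, _, p in after_act(gates, gate_act_parity) for _ in range(mul)]
    # One (l, parity) per individual gated irrep, in order
    gated_flat = [(l, p) for mul, l, p in gated for _ in range(mul)]
    if len(gate_parities) != len(gated_flat):
        raise ValueError("need exactly one gate per gated irrep")

    # Multiplying by a scalar keeps l and multiplies the parities;
    # consecutive equal irreps are merged into one (mul, l, p) group
    grouped = []
    for (l, p), q in zip(gated_flat, gate_parities):
        if grouped and grouped[-1][1:] == (l, p * q):
            grouped[-1] = (grouped[-1][0] + 1, l, p * q)
        else:
            grouped.append((1, l, p * q))
    return after_act(scalars, scalar_act_parity) + grouped


def fmt(irreps):
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, l, p in irreps)


# The example above as (mul, l, parity) tuples
out = gate_irreps_out([(16, 0, -1)], -1, [(32, 0, -1)], -1, [(16, 1, 1), (16, 1, -1)])
print(fmt(out))  # 16x0o+16x1o+16x1e
```

Each odd gate flips the parity of the vector it multiplies, which is why 16x1e+16x1o comes out as 16x1o+16x1e.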
Methods:
forward(features) – Evaluate the gated activation function.
Attributes:
irreps_in – Input representations.
irreps_out – Output representations.
- forward(features)
Evaluate the gated activation function.
- Parameters
features (torch.Tensor) – tensor of shape (..., irreps_in.dim)
- Returns
tensor of shape (..., irreps_out.dim)
- Return type
torch.Tensor
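As a worked check of the shapes in forward, the dimension of a set of irreps is the sum of mul * (2l + 1) over its groups. The sketch assumes, for the Gate example above, that irreps_in holds the scalars, the gates, and the gated irreps (no particular order is assumed, since only the total dimension is computed), while irreps_out drops the gates because they are consumed by the multiplication. irreps_dim is a hypothetical helper, not e3nn's Irreps.dim.

```python
def irreps_dim(irreps):
    """Total dimension of (mul, l) groups: each irrep of degree l has 2l + 1 components."""
    return sum(mul * (2 * l + 1) for mul, l in irreps)


# The Gate example above; parities do not affect dimensions, so they are omitted
scalars = [(16, 0)]         # 16x0o
gates = [(32, 0)]           # 32x0o
gated = [(16, 1), (16, 1)]  # 16x1e+16x1o, and likewise 16x1o+16x1e after gating

dim_in = irreps_dim(scalars + gates + gated)  # 16 + 32 + 48 + 48
dim_out = irreps_dim(scalars + gated)         # 16 + 48 + 48: the gates are consumed
print(dim_in, dim_out)  # 144 112
```

Under these assumptions, forward would map a tensor of shape (..., 144) to one of shape (..., 112).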