Gate
- class e3nn.nn.Activation(irreps_in, acts)
Bases: torch.nn.Module
Scalar activation function.
Odd scalar inputs require activation functions with a defined parity (odd or even).
- Parameters:
irreps_in (e3nn.o3.Irreps) – representation of the input
acts (list of function or None) – list of activation functions, one per irrep group in irreps_in; None for non-scalar irreps or for the identity
Examples
>>> a = Activation("256x0o", [torch.abs])
>>> a.irreps_out
256x0e

>>> a = Activation("256x0o+16x1e", [None, None])
>>> a.irreps_out
256x0o+16x1e
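The parity rule behind these examples can be sketched in plain Python. This is a minimal illustration, not e3nn's code: the helpers activation_parity and scalar_output_parity are hypothetical, and parity is detected numerically by comparing act(x) with act(-x) on sample points, which is an assumption about how such a test could be done. An odd scalar passed through an even function such as abs becomes even, which is why "256x0o" with torch.abs gives 256x0e.

```python
import math


def activation_parity(act, samples=None, tol=1e-5):
    """Return +1 if act is even, -1 if act is odd, 0 if it has no definite parity.

    Numerical check (an assumption of this sketch): act is compared at x and -x.
    """
    if samples is None:
        samples = [0.1 * k for k in range(1, 101)]  # points in (0, 10]
    is_even = all(abs(act(x) - act(-x)) < tol for x in samples)
    is_odd = all(abs(act(x) + act(-x)) < tol for x in samples)
    if is_even:
        return 1
    if is_odd:
        return -1
    return 0


def scalar_output_parity(p_in, act):
    """Parity of act(x) for a scalar x of parity p_in (+1 for 0e, -1 for 0o)."""
    if act is None:
        return p_in  # identity: the parity is unchanged
    if p_in == 1:
        return 1  # x is invariant under parity, so act(x) is too
    p_act = activation_parity(act)
    if p_act == 0:
        raise ValueError("an odd scalar needs an even or odd activation function")
    return p_act  # act(-x) = p_act * act(x)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


print(scalar_output_parity(-1, abs))        # 1: 0o -> 0e, as with torch.abs above
print(scalar_output_parity(-1, math.tanh))  # -1: 0o stays 0o
print(scalar_output_parity(1, sigmoid))     # 1: 0e stays 0e for any function
```

A function such as sigmoid, which is neither even nor odd, is rejected for odd scalars here, matching the requirement that odd inputs need an activation with a defined parity.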
Methods:
forward(features[, dim]) – Evaluate the activation functions on features.
- forward(features, dim: int = -1)
Evaluate the activation functions on features.
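- Parameters:
features (torch.Tensor) – tensor of shape (...); its size along dim must equal irreps_in.dim
dim (int) – dimension of features along which the irreps are laid out (default: -1, the last dimension)
- Returns:
tensor of the same shape as the input
- Return type:
torch.Tensor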
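What forward computes can be sketched on a single flat feature vector: the features are cut into consecutive blocks of size mul * (2l + 1), one per irrep group, and each block is either passed elementwise through its activation or copied unchanged. This is a minimal plain-Python sketch, not e3nn's implementation; the function apply_activation and the (mul, l, parity) tuple encoding of irreps are hypothetical stand-ins for e3nn.o3.Irreps, and it handles one sample rather than a batched torch.Tensor. Rejecting a non-None activation on an l > 0 block is an assumption consistent with the parameter description above.

```python
def apply_activation(features, irreps, acts):
    """Apply one activation per irrep group to a flat list of features.

    irreps: list of (mul, l, parity) tuples, e.g. [(2, 0, -1), (1, 1, 1)] for "2x0o+1x1e".
    acts: one entry per group; None passes the block through unchanged.
    """
    if len(irreps) != len(acts):
        raise ValueError("need one activation (or None) per irrep group")
    total = sum(mul * (2 * l + 1) for mul, l, _ in irreps)
    if total != len(features):
        raise ValueError(f"expected {total} features, got {len(features)}")
    out = []
    start = 0
    for (mul, l, _parity), act in zip(irreps, acts):
        stop = start + mul * (2 * l + 1)
        block = features[start:stop]
        if act is None:
            out.extend(block)  # identity, the only option for l > 0
        elif l == 0:
            out.extend(act(x) for x in block)  # elementwise on scalars
        else:
            raise ValueError("activation functions can only act on scalars (l = 0)")
        start = stop
    return out


# "2x0o+1x1e" with [abs, None]: the scalars are rectified, the vector is untouched
print(apply_activation([-1.0, 2.0, 0.5, -0.5, 3.0], [(2, 0, -1), (1, 1, 1)], [abs, None]))
# [1.0, 2.0, 0.5, -0.5, 3.0]
```

The output has the same length as the input, which is the flat-vector analogue of "tensor of the same shape as the input".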
- class e3nn.nn.Gate(irreps_scalars, act_scalars, irreps_gates, act_gates, irreps_gated)
Bases: torch.nn.Module
Gate activation function.
The gate activation is a direct sum of two sets of irreps. The first set is irreps_scalars, passed through the activation functions act_scalars. The second set is irreps_gated, each irrep multiplied by one scalar gate from irreps_gates after the gates have been passed through the activation functions act_gates. Mathematically, this can be written as:
\[\left(\bigoplus_i \phi_i(x_i) \right) \oplus \left(\bigoplus_j \phi_j(g_j) y_j \right)\]
where \(x_i\) and \(\phi_i\) are from irreps_scalars and act_scalars, and \(g_j\), \(\phi_j\), and \(y_j\) are from irreps_gates, act_gates, and irreps_gated.
The parameters passed in should adhere to the following conditions:
len(irreps_scalars) == len(act_scalars).
len(irreps_gates) == len(act_gates).
irreps_gates.num_irreps == irreps_gated.num_irreps, so that each gated irrep \(y_j\) has exactly one gate \(g_j\).
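The formula can be evaluated on toy data to see what each term does. This is a minimal plain-Python sketch, not e3nn's implementation: the function gate and its arguments are hypothetical, each irrep is a plain list of components (3 for l = 1), and a single function phi_x is used for all scalars and a single phi_g for all gates, a simplification of the per-group act_scalars and act_gates lists.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gate(x, g, y, phi_x, phi_g):
    """Toy version of (direct sum_i phi(x_i)) + (direct sum_j phi(g_j) y_j).

    x: scalars, g: gate scalars, y: gated irreps as lists of components.
    """
    if len(g) != len(y):
        raise ValueError("each gated irrep y_j needs exactly one gate g_j")
    scalars_out = [phi_x(xi) for xi in x]
    # phi_g(g_j) is a single number that scales every component of y_j
    gated_out = [[phi_g(gj) * c for c in yj] for gj, yj in zip(g, y)]
    return scalars_out, gated_out


scalars, gated = gate([-2.0], [0.0], [[1.0, 2.0, 3.0]], abs, sigmoid)
print(scalars)  # [2.0]
print(gated)    # [[0.5, 1.0, 1.5]]
```

Because \(\phi_j(g_j)\) is a scalar, it commutes with any rotation of \(y_j\); that is why the gated term stays equivariant, whereas applying a nonlinearity directly to the components of \(y_j\) would not.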
- Parameters:
irreps_scalars (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_scalars.
act_scalars (list of function or None) – Activation functions acting on the scalars, one per irrep group in irreps_scalars.
irreps_gates (e3nn.o3.Irreps) – Representation of the scalars that will be passed through the activation functions act_gates and then used to multiply irreps_gated.
act_gates (list of function or None) – Activation functions acting on the gates. The number of functions in the list should match the number of irrep groups in irreps_gates.
irreps_gated (e3nn.o3.Irreps) – Representation of the gated tensors. Must satisfy irreps_gates.num_irreps == irreps_gated.num_irreps.
Examples
>>> g = Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")
>>> g.irreps_out
16x0o+16x1o+16x1e
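The irreps_out in this example follows from parity bookkeeping: tanh is odd, so the 0o gates stay 0o, and multiplying a 1e irrep by an odd scalar gives 1o (and a 1o irrep gives 1e). The sketch below reproduces that bookkeeping in plain Python. It is a hedged illustration, not e3nn's implementation: irreps are encoded as hypothetical (mul, l, parity) tuples, the parity of each activation is passed in as a number instead of being detected from the function, and gate_irreps_out, out_parity and irreps_to_str are hypothetical helpers. Requiring l = 0 for scalars and gates is an assumption implied by their description as scalars.

```python
def out_parity(p_in, p_act):
    """Parity of a scalar after an activation of parity p_act (+1 even, -1 odd)."""
    return 1 if p_in == 1 else p_act  # even scalars stay even


def irreps_to_str(irreps):
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, l, p in irreps)


def gate_irreps_out(irreps_scalars, act_scalars_p, irreps_gates, act_gates_p, irreps_gated):
    """Output irreps of a gate; act_*_p hold the parity (+1/-1) of each activation."""
    if len(irreps_scalars) != len(act_scalars_p):
        raise ValueError("len(irreps_scalars) != len(act_scalars)")
    if len(irreps_gates) != len(act_gates_p):
        raise ValueError("len(irreps_gates) != len(act_gates)")
    if any(l != 0 for _, l, _ in irreps_scalars + irreps_gates):
        raise ValueError("scalars and gates must have l = 0")
    if sum(m for m, _, _ in irreps_gates) != sum(m for m, _, _ in irreps_gated):
        raise ValueError("irreps_gates.num_irreps != irreps_gated.num_irreps")

    scalars_out = [(mul, 0, out_parity(p, pa))
                   for (mul, _, p), pa in zip(irreps_scalars, act_scalars_p)]
    # one parity per individual gate scalar, in order
    gate_parities = [out_parity(p, pa)
                     for (mul, _, p), pa in zip(irreps_gates, act_gates_p)
                     for _ in range(mul)]
    gated_out = []
    k = 0
    for mul, l, p in irreps_gated:
        runs = []  # split a group only where the gate parity changes
        for _ in range(mul):
            parity = p * gate_parities[k]  # parity of phi(g_k) * y_k
            k += 1
            if runs and runs[-1][2] == parity:
                runs[-1][0] += 1
            else:
                runs.append([1, l, parity])
        gated_out.extend(tuple(r) for r in runs)
    return scalars_out + gated_out


# The example above: tanh is odd, so its parity is -1
print(irreps_to_str(gate_irreps_out(
    [(16, 0, -1)], [-1], [(32, 0, -1)], [-1], [(16, 1, 1), (16, 1, -1)])))
# 16x0o+16x1o+16x1e
```

With an even gate function such as abs, the gates would become 0e and the gated irreps would keep their parity.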
Methods:
forward(features) – Evaluate the gated activation function.
Attributes:
irreps_in – Input representations.
irreps_out – Output representations.
- forward(features)
Evaluate the gated activation function.
- Parameters:
features (torch.Tensor) – tensor of shape (..., irreps_in.dim)
- Returns:
tensor of shape (..., irreps_out.dim)
- Return type:
torch.Tensor
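The shapes follow from counting dimensions: every irrep group mul x l contributes mul * (2l + 1) components, the input carries the scalars, the gates and the gated irreps, and the output drops the gates. A worked sketch for the example above, assuming the same hypothetical (mul, l, parity) encoding and a helper gate_feature_dims that is not part of e3nn. It only counts sizes, so it makes no claim about the order in which e3nn arranges the input irreps.

```python
def irreps_dim(irreps):
    """Dimension of a list of (mul, l, parity) groups: sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def gate_feature_dims(irreps_scalars, irreps_gates, irreps_gated):
    """Return (irreps_in.dim, irreps_out.dim) for a gate."""
    d_scalars = irreps_dim(irreps_scalars)
    d_gates = irreps_dim(irreps_gates)
    d_gated = irreps_dim(irreps_gated)
    # the gates are consumed: they scale the gated irreps but are not output
    return d_scalars + d_gates + d_gated, d_scalars + d_gated


# Gate("16x0o", [torch.tanh], "32x0o", [torch.tanh], "16x1e+16x1o")
print(gate_feature_dims([(16, 0, -1)], [(32, 0, -1)], [(16, 1, 1), (16, 1, -1)]))
# (144, 112): 16 + 32 + 2 * 16 * 3 in, 16 + 2 * 16 * 3 out
```

For that example, forward therefore maps a tensor of shape (..., 144) to one of shape (..., 112).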