Spherical Activation

class e3nn.nn.S2Activation(irreps: e3nn.o3._irreps.Irreps, act, res, normalization='component', lmax_out=None, random_rot=False)

Bases: torch.nn.modules.module.Module

Apply a nonlinearity to the signal on the sphere.

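Maps the input to a signal on the sphere, applies the nonlinearity pointwise, and projects the result back onto the spherical harmonics.
The signal on the sphere is a quasiregular representation of \(O(3)\): a rotation only moves the points of the sphere around, so a pointwise operation on such a signal commutes with rotations (up to the discretization error of the grid). Writing \(Y^l(x)\) for the spherical harmonics of degree \(l\) evaluated at a point \(x\) on the sphere, the map is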
\[\{A^l\}_l \mapsto \{\int \phi(\sum_l A^l \cdot Y^l(x)) Y^j(x) dx\}_j\]
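The formula above can be sketched numerically in plain NumPy. This is a minimal illustration under stated assumptions, not e3nn's implementation: it restricts to lmax = 1, writes out orthonormal real spherical harmonics by hand (their normalization and component ordering differ from e3nn's 'component' convention), and uses a Gauss-Legendre grid in cos(theta) times a uniform grid in phi as the quadrature. `real_sh_l1`, `sphere_grid`, and `s2_activation_sketch` are hypothetical helpers. Projecting only onto degrees 0 and 1 plays the role of `lmax_out = 1`: any higher-degree content created by the nonlinearity is discarded.

```python
import numpy as np


def real_sh_l1(xyz):
    """Hypothetical helper: orthonormal real spherical harmonics for l = 0, 1.

    Returns shape (..., 4) = [Y^0, Y^1_x, Y^1_y, Y^1_z], normalized so that the
    integral of each Y squared over the unit sphere is 1 (NOT e3nn's convention).
    """
    c0 = np.sqrt(1.0 / (4.0 * np.pi))
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    ones = np.ones(xyz.shape[:-1])
    return np.concatenate([c0 * ones[..., None], c1 * xyz], axis=-1)


def sphere_grid(res):
    """Hypothetical helper: quadrature points and weights on the unit sphere."""
    u, w_u = np.polynomial.legendre.leggauss(res)       # cos(theta) nodes; weights sum to 2
    n_phi = 2 * res
    phi = 2.0 * np.pi * np.arange(n_phi) / n_phi        # uniform azimuthal nodes
    U, PHI = np.meshgrid(u, phi, indexing="ij")         # shape (res, n_phi)
    s = np.sqrt(1.0 - U ** 2)                           # sin(theta)
    xyz = np.stack([s * np.cos(PHI), s * np.sin(PHI), U], axis=-1).reshape(-1, 3)
    w = np.repeat(w_u, n_phi) * (2.0 * np.pi / n_phi)   # weights sum to 4*pi
    return xyz, w


def s2_activation_sketch(coeffs, act, res=20):
    """{A^l}_l -> { integral of act(sum_l A^l . Y^l(x)) Y^j(x) dx }_j for lmax = 1.

    coeffs: array of shape (..., 4); returns an array of the same shape.
    """
    xyz, w = sphere_grid(res)
    Y = real_sh_l1(xyz)              # (n_points, 4): harmonics sampled on the grid
    signal = coeffs @ Y.T            # (..., n_points): f(x) = sum_l A^l . Y^l(x)
    return (act(signal) * w) @ Y     # (..., 4): quadrature of act(f(x)) * Y^j(x)


a = np.array([0.5, 0.1, -0.2, 0.3])            # (A^0, A^1_x, A^1_y, A^1_z)
print(s2_activation_sketch(a, lambda t: t))    # identity act: recovers a up to rounding
print(s2_activation_sketch(a, np.tanh))        # pointwise tanh on the sphere
```

With the identity as `act`, the round trip returns the input coefficients because the harmonics are orthonormal and the grid integrates their products exactly. With an odd activation such as `tanh` and an input that has only a degree-1 part, the degree-0 output vanishes, which reflects the parity bookkeeping that the `irreps` parameter encodes.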
Parameters
  • irreps (o3.Irreps) – input representation of the form [(1, (l, p_val * (p_arg)^l)) for l in [0, ..., lmax]]

  • act (function) – activation function \(\phi\)

  • res (int) – resolution of the grid on the sphere (higher values are more accurate but more expensive)

  • normalization ({'norm', 'component'}) – normalization convention of the spherical harmonics used when mapping to and from the grid

  • lmax_out (int, optional) – maximum l of the output

  • random_rot (bool) – rotate the grid randomly
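To make the form of `irreps` concrete: each degree l appears once, with parity p_val * p_arg**l, and the total dimension is the sum of 2l + 1 over l = 0..lmax, i.e. (lmax + 1)**2. The pure-Python sketch below spells this out as an irreps string in the style of e3nn's `"1x0e+1x1o"` notation. `spherical_tensor_irreps` and `spherical_tensor_dim` are hypothetical helpers for illustration, not e3nn functions.

```python
def spherical_tensor_irreps(lmax, p_val, p_arg):
    """Hypothetical helper: one copy of each l in 0..lmax with parity p_val * p_arg**l."""
    if p_val not in (1, -1) or p_arg not in (1, -1):
        raise ValueError("p_val and p_arg must be +1 or -1")
    terms = []
    for l in range(lmax + 1):
        parity = p_val * p_arg ** l
        terms.append(f"1x{l}{'e' if parity == 1 else 'o'}")  # 'e' = even, 'o' = odd
    return "+".join(terms)


def spherical_tensor_dim(lmax):
    # Each degree l contributes 2l + 1 components; the sum is (lmax + 1)**2.
    return sum(2 * l + 1 for l in range(lmax + 1))


print(spherical_tensor_irreps(5, p_val=+1, p_arg=-1))
print(spherical_tensor_dim(5))
```

For the `io.SphericalTensor(5, p_val=+1, p_arg=-1)` input used in the example, this gives alternating parities `1x0e+1x1o+1x2e+1x3o+1x4e+1x5o` and a feature dimension of 36.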

Examples

>>> import torch
>>> from e3nn import io
>>> from e3nn.nn import S2Activation
>>> m = S2Activation(io.SphericalTensor(5, p_val=+1, p_arg=-1), torch.tanh, 100)

Methods:

forward(features) – Evaluate the activation.

forward(features)

Evaluate the activation on the input features.

Parameters

features (torch.Tensor) – tensor \(\{A^l\}_l\) of shape (..., self.irreps_in.dim)

Returns

tensor of shape (..., self.irreps_out.dim)

Return type

torch.Tensor
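In the shapes above, `...` stands for any number of leading batch dimensions; only the last axis holds the coefficients. For a spherical-tensor input with one copy of each l, that axis concatenates A^0, A^1, ..., A^lmax in increasing l, so the block for degree l occupies indices l**2 through (l + 1)**2 - 1. A minimal NumPy sketch of that layout, using a hypothetical helper `split_by_l` that is not part of e3nn:

```python
import numpy as np


def split_by_l(features, lmax):
    """Hypothetical helper: split (..., (lmax + 1)**2) features into blocks A^l.

    Assumes one copy of each l = 0..lmax in increasing order; block l has
    shape (..., 2l + 1) and covers indices l**2 .. (l + 1)**2 - 1.
    """
    dim = (lmax + 1) ** 2
    if features.shape[-1] != dim:
        raise ValueError(f"expected last dimension {dim}, got {features.shape[-1]}")
    return [features[..., l * l:(l + 1) ** 2] for l in range(lmax + 1)]


features = np.arange(2 * 36).reshape(2, 36)   # batch of 2, lmax = 5
blocks = split_by_l(features, lmax=5)
print([b.shape for b in blocks])              # one (2, 2l + 1) block per degree
```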