Norm-Based Activation

class e3nn.nn.NormActivation(irreps_in: Irreps, scalar_nonlinearity: Callable, normalize: bool = True, epsilon: float | None = None, bias: bool = False)

Bases: Module

Norm-based activation function.

Applies a scalar nonlinearity to the norm of each irrep and outputs a (normalized) version of that irrep multiplied by the scalar output of the scalar nonlinearity.

Parameters:

irreps_in (e3nn.o3.Irreps) – representation of the input

scalar_nonlinearity (callable) – scalar nonlinearity such as torch.sigmoid

normalize (bool) – whether to normalize the input features before multiplying them by the scalars from the nonlinearity

epsilon (float, optional) – when normalizing, norms smaller than epsilon will be clamped up to epsilon to avoid division by zero and NaN gradients. Not allowed when normalize is False.

bias (bool) – whether to apply a learnable additive bias to the inputs of the scalar_nonlinearity
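
The computation described above can be sketched for a single irrep in plain Python. This is an illustrative reimplementation, not e3nn's code: the helper names `sigmoid` and `norm_activation`, the scalar `bias` argument, the default `epsilon=1e-8`, and the point at which the norm is clamped are all assumptions made for the sketch.

```python
import math


def sigmoid(x):
    """Logistic function, standing in for torch.sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def norm_activation(vec, nonlinearity, normalize=True, epsilon=1e-8, bias=0.0):
    """Hypothetical helper: norm-based activation of one irrep (a list of floats)."""
    norm = math.sqrt(sum(c * c for c in vec))
    if normalize:
        # Assumed behaviour: clamp tiny norms up to epsilon before dividing.
        norm = max(norm, epsilon)
    # The nonlinearity sees the norm (plus an optional additive bias).
    scale = nonlinearity(norm + bias)
    if normalize:
        # Dividing by the norm turns `vec` into a unit vector times `scale`.
        scale = scale / norm
    return [scale * c for c in vec]


def length(vec):
    return math.sqrt(sum(c * c for c in vec))


normalized = norm_activation([1.0, 1.0, 1.0], sigmoid)
unnormalized = norm_activation([1.0, 1.0, 1.0], sigmoid, normalize=False)
zero = norm_activation([0.0, 0.0, 0.0], sigmoid)

print(round(length(normalized), 4))    # sigmoid(sqrt(3))
print(round(length(unnormalized), 4))  # sqrt(3) * sigmoid(sqrt(3))
print(zero)                            # stays finite thanks to the clamp
```

With normalize=True the output norm equals the nonlinearity's output; with normalize=False it is the input norm multiplied by that output.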

Examples

>>> import torch
>>> from e3nn.nn import NormActivation
>>> n = NormActivation("2x1e", torch.sigmoid)
>>> feats = torch.ones(1, 2*3)
>>> print(feats.reshape(1, 2, 3).norm(dim=-1))
tensor([[1.7321, 1.7321]])
>>> print(torch.sigmoid(feats.reshape(1, 2, 3).norm(dim=-1)))
tensor([[0.8497, 0.8497]])
>>> print(n(feats).reshape(1, 2, 3).norm(dim=-1))
tensor([[0.8497, 0.8497]])
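
The same numbers can be reproduced by hand in plain PyTorch, which may help show how the flat feature tensor maps onto the two vectors of "2x1e". This is a sketch only: it assumes normalize=True, no bias, and no epsilon clamping, and it does not call e3nn.

```python
import torch

# One sample with irreps "2x1e": two vectors of three components each.
feats = torch.ones(1, 2 * 3)

# View the flat features as (batch, multiplicity, components).
vecs = feats.reshape(1, 2, 3)

# One norm per vector; keepdim so it broadcasts over the components.
norms = vecs.norm(dim=-1, keepdim=True)

# Scalar nonlinearity on the norm, divided by the norm (normalize=True case).
scalings = torch.sigmoid(norms) / norms

# Rescale each vector and flatten back to the input layout.
out = (vecs * scalings).reshape(1, 2 * 3)

print(out.shape)
print(out.reshape(1, 2, 3).norm(dim=-1))
```

The output keeps the input shape, and the norm of each vector equals the sigmoid of its input norm, as in the example above.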

Methods:

forward(features)

Evaluate the activation.

forward(features)

Evaluate the activation.

Parameters:

features (torch.Tensor) – tensor of shape (..., irreps_in.dim)

Returns:

tensor of shape (..., irreps_in.dim)

Return type:

torch.Tensor