Batch Normalization

class e3nn.nn.BatchNorm(irreps: Irreps, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = 'component')

Bases: torch.nn.Module

Batch normalization for orthonormal representations.

It normalizes each channel by the norm of its representation. This norm is invariant only for orthonormal representations. The irreducible representations, whose matrices are given by wigner_D, are orthonormal, so dividing by the norm preserves equivariance.
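The invariance claim can be checked numerically with a small NumPy sketch (NumPy is used here only for illustration; it is not part of this API). A rotation matrix acting on one l=1 feature preserves its norm, while the equivalent but non-orthonormal representation D(R) = A R A^-1, built from a hypothetical change of basis A, does not.

```python
import numpy as np

def rotation_z(theta):
    """3x3 rotation about the z axis: an orthogonal (orthonormal) l=1 representation."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

x = np.array([1.0, 2.0, 3.0])  # one l=1 feature (a 3-vector)
R = rotation_z(0.7)

# Orthonormal representation: the norm is unchanged by the rotation.
norm_before = np.linalg.norm(x)
norm_rotated = np.linalg.norm(R @ x)

# Equivalent but non-orthonormal representation D(R) = A R A^-1
# for a hypothetical change of basis A: the norm is no longer preserved.
A = np.diag([1.0, 2.0, 3.0])
D = A @ R @ np.linalg.inv(A)
norm_skewed = np.linalg.norm(D @ x)

# norm_rotated equals norm_before (sqrt(14)); norm_skewed is about 4.12
print(norm_before, norm_rotated, norm_skewed)
```

Normalizing by norm_skewed would give a different scale for every rotated copy of the same feature, which is why the layer relies on orthonormal representations.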

Parameters:
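  • irreps (o3.Irreps) – representation of the input features

  • eps (float) – small constant added to the variance to avoid division by zero

  • momentum (float) – momentum of the running averages of the mean and variance

  • affine (bool) – whether to learn a weight per channel and a bias per scalar channel

  • reduce ({'mean', 'max'}) – method used to reduce the squared norms over the extra (...) dimensions of the input

  • instance (bool) – apply instance normalization (statistics computed separately for each batch element) instead of batch normalization

  • normalization ({'component', 'norm'}) – whether the squared norm of each irrep is averaged over its components ('component') or summed ('norm')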
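To show how these parameters interact, here is a NumPy sketch of the training-mode computation under stated assumptions; it is not e3nn's code. The function name irreps_batch_norm and the (mul, l, p) tuple encoding of the irreps are hypothetical. The sketch assumes that only even scalars (l = 0, p = +1) are mean-centered and receive a bias, since shifting any other irrep would break equivariance, and that running statistics follow the usual PyTorch rule running = (1 - momentum) * running + momentum * batch_stat.

```python
import numpy as np

def irreps_batch_norm(x, irreps, running_mean=None, running_var=None,
                      weight=None, bias=None, eps=1e-5, momentum=0.1,
                      reduce="mean", instance=False, normalization="component"):
    """Hypothetical training-mode sketch of irreps batch normalization.

    x has shape (batch, ..., dim); irreps is a list of (mul, l, p) tuples.
    With instance=False, running_mean (one entry per even-scalar channel) and
    running_var (one entry per channel) are updated in place.
    """
    if reduce not in ("mean", "max") or normalization not in ("component", "norm"):
        raise ValueError("unsupported reduce or normalization option")
    batch, *size, dim = x.shape
    x = x.reshape(batch, -1, dim)  # (batch, sample, dim)
    out, ix, irm, irv, iw, ib = [], 0, 0, 0, 0, 0
    for mul, l, p in irreps:
        d = 2 * l + 1
        field = x[:, :, ix:ix + mul * d].reshape(batch, -1, mul, d)
        ix += mul * d
        is_scalar = l == 0 and p == 1
        if is_scalar:  # center even scalars only
            if instance:
                mean = field.mean(axis=1).reshape(batch, 1, mul, 1)
            else:
                m = field.mean(axis=(0, 1)).reshape(mul)
                rm = running_mean[irm:irm + mul]
                running_mean[irm:irm + mul] = (1 - momentum) * rm + momentum * m
                mean = m.reshape(1, 1, mul, 1)
            irm += mul
            field = field - mean
        # Squared norm per channel: averaged ('component') or summed ('norm') over 2l+1 components.
        sq = field ** 2
        sq = sq.mean(axis=3) if normalization == "component" else sq.sum(axis=3)
        sq = sq.mean(axis=1) if reduce == "mean" else sq.max(axis=1)  # (batch, mul)
        if not instance:
            sq = sq.mean(axis=0)  # (mul,)
            rv = running_var[irv:irv + mul]
            running_var[irv:irv + mul] = (1 - momentum) * rv + momentum * sq
        irv += mul
        scale = (sq + eps) ** -0.5  # (mul,) or (batch, mul)
        if weight is not None:  # affine=True: one weight per channel
            scale = scale * weight[iw:iw + mul]
            iw += mul
        field = field * scale.reshape(-1, 1, mul, 1)
        if bias is not None and is_scalar:  # bias only on even scalars
            field = field + bias[ib:ib + mul].reshape(mul, 1)
            ib += mul
        out.append(field.reshape(batch, -1, mul * d))
    return np.concatenate(out, axis=2).reshape(batch, *size, dim)

rng = np.random.default_rng(0)
irreps = [(3, 0, 1), (2, 1, -1)]          # hypothetical "3x0e + 2x1o": dim = 3 + 2 * 3 = 9
x = rng.normal(loc=2.0, size=(16, 5, 9))  # (batch, ..., dim) with one extra axis
running_mean, running_var = np.zeros(3), np.ones(5)
out = irreps_batch_norm(x, irreps, running_mean, running_var)
```

Because the scale depends only on norms and the mean is subtracted only from even scalars, rotating the l=1 channels of x rotates the corresponding channels of out in the same way.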

Methods:

forward(input) – Apply the normalization to input.

forward(input) → Tensor

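Apply the normalization to input. In training mode, or whenever instance=True, the statistics are computed from input itself (and, for batch normalization, folded into the running averages); in evaluation mode with instance=False, the stored running averages are used instead.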

Parameters:

input (torch.Tensor) – tensor of shape (batch, ..., irreps.dim)

Returns:

tensor of shape (batch, ..., irreps.dim)

Return type:

torch.Tensor
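The shape contract above can be illustrated with a short NumPy sketch under stated assumptions. irreps.dim is the sum of mul * (2l + 1) over the irreps (for an o3.Irreps object it is available directly as .dim), and any extra axes between the batch axis and the feature axis can be flattened into a single axis and restored afterwards, so the output has the same shape as the input. The helper irreps_dim and the (mul, l, p) encoding are hypothetical.

```python
import numpy as np

def irreps_dim(irreps):
    """Total feature dimension: each of the mul copies of irrep l has 2l+1 components."""
    return sum(mul * (2 * l + 1) for mul, l, _p in irreps)

irreps = [(8, 0, 1), (4, 1, -1), (2, 2, 1)]  # hypothetical "8x0e + 4x1o + 2x2e"
dim = irreps_dim(irreps)                     # 8*1 + 4*3 + 2*5 = 30

# Extra axes between batch and features are flattened into one axis
# for the statistics, then restored so the output shape matches the input.
x = np.zeros((32, 7, 5, dim))                # (batch, ..., irreps.dim)
batch, *size, d = x.shape
flat = x.reshape(batch, -1, d)               # (32, 35, 30)
restored = flat.reshape(batch, *size, d)     # (32, 7, 5, 30)
```

Flattening only the middle axes keeps the batch axis separate, which is what distinguishes batch statistics (averaged over batch and the extra axes) from instance statistics (averaged over the extra axes only).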