Batch Normalization

class e3nn.nn.BatchNorm(irreps, eps=1e-05, momentum=0.1, affine=True, reduce='mean', instance=False, normalization='component')

Bases: torch.nn.modules.module.Module

Batch normalization for orthonormal representations

The features are normalized by the norm of each representation. Note that the norm is invariant under the group action only for orthonormal representations. The irreducible representations given by wigner_D are orthonormal.
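The invariance claim above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helpers rotate_z, norm and normalize_batch are hypothetical names introduced here, and a rotation of ordinary 3-vectors stands in for an orthonormal representation. Under those assumptions, scaling by a statistic built from norms commutes with the rotation.

```python
import math

def rotate_z(v, angle):
    """Rotate a 3-vector about the z axis (an orthonormal transformation)."""
    c, s = math.cos(angle), math.sin(angle)
    x, y, z = v
    return (c * x - s * y, s * x + c * y, z)

def norm(v):
    """Euclidean norm of a vector."""
    return math.sqrt(sum(x * x for x in v))

def normalize_batch(batch, eps=1e-5):
    """Scale every vector by 1 / sqrt(batch-mean squared norm + eps)."""
    mean_sq = sum(norm(v) ** 2 for v in batch) / len(batch)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [tuple(scale * x for x in v) for v in batch]

batch = [(1.0, 2.0, 2.0), (0.0, 3.0, 4.0)]

# Normalize first, then rotate ...
a = [rotate_z(v, 0.7) for v in normalize_batch(batch)]
# ... or rotate first, then normalize: both orders agree up to rounding,
# because the rotation leaves every norm unchanged.
b = normalize_batch([rotate_z(v, 0.7) for v in batch])
```

Because the scale factor depends only on norms, it is the same before and after the rotation, which is why the normalization preserves equivariance in this sketch.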

Parameters
  • irreps (o3.Irreps) – representation of the input features

  • eps (float) – small constant that avoids division by zero when normalizing by the variance

  • momentum (float) – momentum of the running average

  • affine (bool) – whether the module has learnable weight and bias parameters

  • reduce ({'mean', 'max'}) – reduction method applied when computing the normalization statistics

  • instance (bool) – apply instance norm instead of batch norm
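As a rough illustration of the momentum parameter, the sketch below updates a running statistic as an exponential moving average. It assumes the PyTorch-style convention in which momentum weights the new batch statistic; the helper roll_avg is a hypothetical name, and the numbers are made up for the example.

```python
def roll_avg(running, update, momentum=0.1):
    """Exponential moving average: momentum is the weight of the new value.

    Assumption: PyTorch-style convention, not confirmed by this page.
    """
    return (1.0 - momentum) * running + momentum * update

running = 1.0
# Three training batches that all report the same statistic, 2.0.
for batch_stat in [2.0, 2.0, 2.0]:
    running = roll_avg(running, batch_stat)
```

With the default momentum=0.1 the running value moves only slowly toward the batch statistic, so a larger momentum tracks recent batches more closely.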

Methods:

forward(input) – evaluate

forward(input)

Evaluate the module on input and return the normalized tensor.

Parameters
  • input (torch.Tensor) – tensor of shape (batch, ..., irreps.dim)

Returns

tensor of shape (batch, ..., irreps.dim)

Return type

torch.Tensor