Normalization

We define two kinds of normalization: component and norm.

Definition

component

component normalization refers to tensors whose individual components have values of order 1. More precisely, the second moment of each component is 1.

\[\langle x_i^2 \rangle = 1\]

Examples:

  • [1.0, -1.0, -1.0, 1.0]

  • [1.0, 1.0, 1.0, 1.0] the mean does not need to be zero

  • [0.0, 2.0, 0.0, 0.0] this is still fine because \(\|x\|^2 = n\)

import torch
torch.randn(10)
tensor([ 0.0287,  0.4414, -0.0600, -0.3662, -2.2076,  2.3285, -0.0219,  0.2407,
         0.5488,  0.4898])
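
Averaged over many components, the second moment is indeed close to 1. A quick numerical check (a sketch, not part of the original text):

torch.randn(100_000).pow(2).mean()  # second moment of the components, expect ≈ 1.0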

norm

norm normalization refers to tensors whose norm is close to 1.

\[\|x\| \approx 1\]

Examples:

  • [0.5, -0.5, -0.5, 0.5]

  • [0.5, 0.5, 0.5, 0.5] the mean does not need to be zero

  • [0.0, 1.0, 0.0, 0.0]

torch.randn(10) / 10**0.5
tensor([ 0.2117, -0.6425,  0.1753, -0.0785, -0.0275, -0.2487, -0.1967, -0.2721,
        -0.1275,  0.3641])

The two normalizations differ only by a factor of \(\sqrt{n}\), where \(n\) is the number of components.
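
As an illustration (a sketch, not from the original text): a component-normalized tensor has \(\|x\|^2 = n \langle x_i^2 \rangle \approx n\), so dividing by \(\sqrt{n}\) turns it into a norm-normalized one.

n = 10_000
x = torch.randn(n)   # component-normalized: second moment ≈ 1
x.norm()             # ≈ n**0.5 = 100.0
(x / n**0.5).norm()  # ≈ 1.0, now norm-normalized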

Motivation

Assuming that the weights are independent of the input and that their distribution obeys

\[\begin{aligned}
\langle w_i \rangle &= 0 \\
\langle w_i w_j \rangle &= \sigma^2 \delta_{ij}
\end{aligned}\]

It implies that the first two moments of \(x \cdot w\) (and therefore its mean and variance) depend only on the second moment of \(x\):

\[\begin{aligned}
\langle x \cdot w \rangle &= \sum_i \langle x_i w_i \rangle = \sum_i \langle x_i \rangle \langle w_i \rangle = 0 \\
\langle (x \cdot w)^2 \rangle &= \sum_{i} \sum_{j} \langle x_i w_i x_j w_j \rangle \\
&= \sum_{i} \sum_{j} \langle x_i x_j \rangle \langle w_i w_j \rangle \\
&= \sigma^2 \sum_{i} \langle x_i^2 \rangle
\end{aligned}\]
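
This identity is easy to check numerically. The sketch below (with illustrative values of \(n\) and \(\sigma\), not from the original text) draws independent \(x\) and \(w\) and compares the empirical moments of \(x \cdot w\) with the predictions above:

n, sigma = 256, 0.5
x = torch.randn(10_000, n)          # component-normalized inputs: <x_i^2> = 1
w = sigma * torch.randn(10_000, n)  # independent weights: <w_i w_j> = sigma^2 delta_ij
xw = (x * w).sum(-1)                # batched dot products x . w
xw.mean()                           # expect ≈ 0
xw.pow(2).mean()                    # expect ≈ sigma**2 * n = 64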

Testing

You can use e3nn.util.test.assert_normalized to check whether a function or module is normalized at initialization:

from e3nn.util.test import assert_normalized
from e3nn import o3

# raises if the output of the module is not normalized at initialization
assert_normalized(o3.Linear("10x0e", "10x0e"))
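
The same property can also be checked by hand: sample component-normalized inputs and measure the second moment of the output, which is essentially what assert_normalized does. A sketch (assuming o3.Irreps.randn to draw the inputs):

lin = o3.Linear("10x0e", "10x0e")
x = o3.Irreps("10x0e").randn(10_000, -1)  # component-normalized random inputs
lin(x).pow(2).mean()                      # expect ≈ 1 at initialization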