Euclidean neural networks

What is e3nn?

e3nn is a Python library based on PyTorch for building equivariant neural networks for the group \(O(3)\).

Where to start?


All the functions for manipulating rotations (rotation matrices, Euler angles, quaternions, conversions, …) are documented in Parametrization of Rotations. The irreducible representations of \(O(3)\) (more info at Irreps) are represented by the class e3nn.o3.Irrep. A direct sum of multiple irreps is described by an e3nn.o3.Irreps object.

If two tensors \(x\) and \(y\) transform as \(D_x = 2 \times 1_o\) (two vectors) and \(D_y = 0_e + 1_e\) (a scalar and a pseudovector) respectively, where the indices \(e\) and \(o\) stand for even and odd parity,

import torch
from e3nn import o3

irreps_x = o3.Irreps('2x1o')
irreps_y = o3.Irreps('0e + 1e')

x = irreps_x.randn(-1)
y = irreps_y.randn(-1)

irreps_x.dim, irreps_y.dim
(6, 4)

their outer product is a \(6 \times 4\) matrix with two indices, \(A_{ij} = x_i y_j\).

A = torch.einsum('i,j', x, y)
A
tensor([[-0.5411, -0.0199,  0.1181, -0.3354],
        [ 2.7620,  0.1015, -0.6030,  1.7119],
        [-3.8228, -0.1405,  0.8346, -2.3693],
        [-0.1955, -0.0072,  0.0427, -0.1212],
        [-0.4380, -0.0161,  0.0956, -0.2715],
        [ 0.3092,  0.0114, -0.0675,  0.1916]])

If a rotation is applied to the system, this matrix will transform with the representation \(D_x \otimes D_y\) (the tensor product representation).

\[A = x y^t \longrightarrow A' = D_x A D_y^t\]

The matrix of this representation can be visualized with

import matplotlib.pyplot as plt

R = o3.rand_matrix()
D_x = irreps_x.D_from_matrix(R)
D_y = irreps_y.D_from_matrix(R)

plt.imshow(torch.kron(D_x, D_y), cmap='bwr', vmin=-1, vmax=1);
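The Kronecker product appears here because transforming \(A\) as a matrix, \(D_x A D_y^t\), is the same as applying \(D_x \otimes D_y\) to the row-major flattening of \(A\). The following sketch checks this identity numerically. It uses random matrices `M_x`, `M_y` and `B` as stand-ins of the right shapes (these names are illustrative and not part of e3nn), since the identity holds for any matrices, not only for representations.

```python
import torch

# Stand-ins for D_x (6 x 6), D_y (4 x 4) and A (6 x 4).
# The identity below is purely algebraic, so random matrices suffice.
M_x = torch.randn(6, 6, dtype=torch.float64)
M_y = torch.randn(4, 4, dtype=torch.float64)
B = torch.randn(6, 4, dtype=torch.float64)

# Transform B as a matrix ...
B_transformed = M_x @ B @ M_y.T

# ... or transform its row-major flattening with the Kronecker product.
B_transformed_flat = torch.kron(M_x, M_y) @ B.flatten()

matches = torch.allclose(B_transformed.flatten(), B_transformed_flat)
print(matches)
```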

This representation is reducible: it can be decomposed into irreps by a change of basis. The outer product followed by this change of basis is implemented by the class e3nn.o3.FullTensorProduct.

tp = o3.FullTensorProduct(irreps_x, irreps_y)
print(tp)
FullTensorProduct(2x1o x 1x0e+1x1e -> 2x0o+4x1o+2x2o | 8 paths | 0 weights)

tp(x, y)
tensor([-1.7276e+00,  1.6170e-01, -5.4115e-01,  2.7620e+00, -3.8228e+00,
        -1.9555e-01, -4.3801e-01,  3.0920e-01,  6.2032e-01,  1.3780e-01,
         1.1750e-02, -1.4423e-01,  9.3737e-02,  4.1573e-02, -3.3653e-01,
         1.5533e-01,  4.8305e-01,  1.8006e+00, -1.6613e+00, -7.7663e-02,
         1.8803e-02,  2.7762e-03, -2.3970e-01,  1.4059e-01])

As a sanity check, we can verify that the representation of the tensor product output is block diagonal and has the same dimension as \(D_x \otimes D_y\).

D = tp.irreps_out.D_from_matrix(R)
plt.imshow(D, cmap='bwr', vmin=-1, vmax=1);

e3nn.o3.FullTensorProduct is a special case of e3nn.o3.TensorProduct. Other special cases, such as e3nn.o3.FullyConnectedTensorProduct, can contain learnable weights, which makes them useful building blocks for neural networks.