Euclidean neural networks

What is e3nn?

e3nn is a Python library based on PyTorch for creating equivariant neural networks for the group \(O(3)\).


All the functions to manipulate rotations (rotation matrices, Euler angles, quaternions, conversions, …) can be found in Parametrization of Rotations. The irreducible representations of \(O(3)\) (more info at Irreps) are represented by the class Irrep. A direct sum of multiple irreps is described by an Irreps object.
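As an illustration of one such conversion, the standard unit-quaternion-to-rotation-matrix formula is sketched below in NumPy. It assumes scalar-first quaternions \((w, x, y, z)\) and is not taken from e3nn, which provides its own conversion functions (see Parametrization of Rotations).

```python
import numpy as np


def quaternion_to_matrix(q):
    """Convert a quaternion (w, x, y, z) to a 3x3 rotation matrix.

    Assumption: scalar-first convention; the input is normalized first.
    """
    w, x, y, z = np.asarray(q, dtype=float) / np.linalg.norm(q)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


rng = np.random.default_rng(0)
R = quaternion_to_matrix(rng.normal(size=4))
# A rotation matrix is orthogonal with determinant +1.
print(np.allclose(R.T @ R, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))  # True True
```

For example, \(q = (\cos(\theta/2), 0, 0, \sin(\theta/2))\) gives a rotation by \(\theta\) about the \(z\) axis.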

If two tensors \(x\) and \(y\) transform as \(D_x = 2 \times 1_o\) (two vectors) and \(D_y = 0_e + 1_e\) (a scalar and a pseudovector), respectively, where the indices \(e\) and \(o\) denote even and odd parity,

import torch
from e3nn import o3

irreps_x = o3.Irreps('2x1o')
irreps_y = o3.Irreps('0e + 1e')

x = irreps_x.randn(-1)
y = irreps_y.randn(-1)

irreps_x.dim, irreps_y.dim
(6, 4)

their outer product is a \(6 \times 4\) matrix with two indices, \(A_{ij} = x_i y_j\).

A = torch.einsum('i,j->ij', x, y)  # outer product A_ij = x_i y_j
A
tensor([[-9.3048e-02,  6.2089e-02,  2.6819e-01,  4.9130e-03],
        [-1.2266e-02,  8.1851e-03,  3.5355e-02,  6.4767e-04],
        [-2.3882e-01,  1.5936e-01,  6.8835e-01,  1.2610e-02],
        [ 1.1373e+00, -7.5888e-01, -3.2779e+00, -6.0050e-02],
        [-8.6876e-03,  5.7971e-03,  2.5040e-02,  4.5872e-04],
        [ 8.6206e-01, -5.7524e-01, -2.4847e+00, -4.5518e-02]])
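The dimensions returned by irreps_x.dim and irreps_y.dim follow from counting components: each copy of an irrep of order \(l\) has \(2l + 1\) components, so \(2 \times 1_o\) has \(2 \times 3 = 6\) and \(0_e + 1_e\) has \(1 + 3 = 4\). The parity label says how each block reacts to the inversion \(\vec r \to -\vec r\): it is multiplied by \(+1\) (e) or \(-1\) (o), so the two vectors flip sign while the scalar and the pseudovector do not. The sketch below spells this out in plain Python/NumPy; parse_irreps is a hypothetical helper that handles only the simple syntax used on this page and is not e3nn's parser.

```python
import numpy as np


def parse_irreps(s):
    """Parse a simple irreps string such as '2x1o' or '0e + 1e'.

    Hypothetical helper: only handles the 'MULxLP' / 'LP' syntax used in the text.
    Returns a list of (mul, l, p) with p = +1 for 'e' and -1 for 'o'.
    """
    out = []
    for term in s.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        out.append((int(mul) if mul else 1, int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def irreps_dim(s):
    # Each copy of an order-l irrep contributes 2l + 1 components.
    return sum(mul * (2 * l + 1) for mul, l, p in parse_irreps(s))


def inversion_matrix(s):
    # Under inversion, every component of an irrep block is multiplied by its parity p.
    diag = [p for mul, l, p in parse_irreps(s) for _ in range(mul * (2 * l + 1))]
    return np.diag(diag).astype(float)


print(irreps_dim("2x1o"), irreps_dim("0e + 1e"))  # 6 4
print(np.diag(inversion_matrix("2x1o")))  # [-1. -1. -1. -1. -1. -1.]
```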

If a rotation is applied to the system, this matrix will transform with the representation \(D_x \otimes D_y\) (the tensor product representation).

\[A = x y^t \longrightarrow A' = D_x A D_y^t\]

The matrix of \(D_x \otimes D_y\) can be visualized as follows:

import matplotlib.pyplot as plt

R = o3.rand_matrix()
D_x = irreps_x.D_from_matrix(R)
D_y = irreps_y.D_from_matrix(R)

# Matrix of D_x ⊗ D_y, acting on A flattened row by row
plt.imshow(torch.kron(D_x, D_y), cmap='bwr', vmin=-1, vmax=1);
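Two facts make this picture meaningful: rotating \(x\) and \(y\) and then taking the outer product gives \(D_x A D_y^t\), and flattening that matrix row by row is the same as applying torch.kron(D_x, D_y) to the flattened \(A\). The NumPy sketch below checks both. It assumes the \(l = 1\) block is the \(3 \times 3\) rotation matrix itself in Cartesian coordinates; e3nn's D_from_matrix may order the \(l = 1\) components differently, so its matrices can differ from these by a fixed change of basis.

```python
import numpy as np


def random_rotation(rng):
    # Orthogonalize a Gaussian matrix with QR, then fix signs so that det = +1.
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


def block_diag(*blocks):
    # Place square blocks along the diagonal of a larger zero matrix.
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n))
    i = 0
    for b in blocks:
        k = b.shape[0]
        out[i:i + k, i:i + k] = b
        i += k
    return out


rng = np.random.default_rng(0)
R = random_rotation(rng)
# Assumption: the l = 1 block is R itself; for det(R) = +1 a pseudovector (1e)
# transforms exactly like a vector (1o).
D_x = block_diag(R, R)          # 2x1o
D_y = block_diag(np.eye(1), R)  # 0e + 1e

x, y = rng.normal(size=6), rng.normal(size=4)
A = np.outer(x, y)                  # A_ij = x_i y_j
A_rot = np.outer(D_x @ x, D_y @ y)  # outer product of the rotated inputs

print(np.allclose(A_rot, D_x @ A @ D_y.T))                                # A' = D_x A D_y^t
print(np.allclose(A_rot.reshape(-1), np.kron(D_x, D_y) @ A.reshape(-1)))  # kron acts on flattened A
```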

This representation is not irreducible (it is reducible). It can be decomposed into irreps by a change of basis. The outer product followed by the change of basis is computed by the class FullTensorProduct.
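For a single pair of vectors, \(1 \otimes 1 = 0 \oplus 1 \oplus 2\) (\(9 = 1 + 3 + 5\) components), which is what happens in each \(1_o \otimes 1_e\) block above. In Cartesian components this change of basis is the familiar split of a \(3 \times 3\) matrix into its trace, antisymmetric part, and symmetric traceless part. The NumPy sketch below works in plain Cartesian components; FullTensorProduct uses its own normalized basis, so its outputs are not these exact numbers.

```python
import numpy as np

rng = np.random.default_rng(0)
u, v = rng.normal(size=3), rng.normal(size=3)
M = np.outer(u, v)  # outer product of two vectors: 9 components

scalar = np.trace(M) / 3                            # l = 0: 1 component (u . v / 3)
antisym = (M - M.T) / 2                             # l = 1: 3 independent components
sym_traceless = (M + M.T) / 2 - scalar * np.eye(3)  # l = 2: 5 independent components

# The pieces add back up to M, so no information is lost: 9 = 1 + 3 + 5.
recombined = scalar * np.eye(3) + antisym + sym_traceless
# The antisymmetric part is the cross product u x v written as a matrix.
w = 2 * np.array([antisym[1, 2], antisym[2, 0], antisym[0, 1]])

# Under an orthogonal change of frame M -> Q M Q^T the trace (l = 0 part) is unchanged.
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
print(np.allclose(recombined, M), np.allclose(w, np.cross(u, v)),
      np.isclose(np.trace(Q @ M @ Q.T), np.trace(M)))  # True True True
```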

tp = o3.FullTensorProduct(irreps_x, irreps_y)
print(tp)
FullTensorProduct(2x1o x 1x0e+1x1e -> 2x0o+4x1o+2x2o | 8 paths | 0 weights)

tp(x, y)
tensor([ 6.3540e-02, -4.4996e-01, -9.3048e-02, -1.2266e-02, -2.3882e-01,
         1.1373e+00, -8.6876e-03,  8.6206e-01, -4.8628e-01,  1.0921e-01,
         1.8385e-01,  1.7573e+00, -3.6429e-01, -2.3220e+00,  1.1616e-01,
         1.9543e-01, -1.6287e-03,  4.8719e-01, -3.4987e-02, -4.4921e-01,
        -2.3138e+00,  3.4884e-01, -1.7566e+00,  5.0442e-01])
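The output irreps in the printed description follow from two selection rules: two irreps of orders \(l_1\) and \(l_2\) couple to every \(l\) from \(|l_1 - l_2|\) to \(l_1 + l_2\), and the output parity is the product of the input parities. The plain-Python sketch below applies these rules to \(2 \times 1_o\) and \(0_e + 1_e\). Counting one path per pair of input channels is an assumption on our part; it is the convention that reproduces the "8 paths" printed above.

```python
from collections import Counter


def full_tp_irreps(irreps1, irreps2):
    """Output multiplicities and path count of a full tensor product.

    irreps are lists of (mul, l, p) with p = +1 (even) or -1 (odd).
    Assumption: a path is counted once per (input-1 channel, input-2 channel) pair.
    """
    mults = Counter()
    paths = 0
    for mul1, l1, p1 in irreps1:
        for mul2, l2, p2 in irreps2:
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                mults[(l, p1 * p2)] += mul1 * mul2
                paths += mul1 * mul2
    return mults, paths


def irreps_str(mults):
    # Sort by l, listing even before odd, and format as e.g. '2x0o+4x1o'.
    items = sorted(mults.items(), key=lambda kv: (kv[0][0], -kv[0][1]))
    return "+".join(f"{m}x{l}{'e' if p == 1 else 'o'}" for (l, p), m in items)


mults, paths = full_tp_irreps([(2, 1, -1)], [(1, 0, 1), (1, 1, 1)])  # 2x1o and 0e + 1e
dim = sum(m * (2 * l + 1) for (l, p), m in mults.items())
print(irreps_str(mults), paths, dim)  # 2x0o+4x1o+2x2o 8 24
```

The total dimension, 24, matches the \(6 \times 4\) entries of the outer product, as it must for a change of basis.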

As a sanity check, we can verify that the representation of the tensor product is block diagonal and has the same dimension (\(24 \times 24\)) as \(D_x \otimes D_y\).

D = tp.irreps_out.D_from_matrix(R)
plt.imshow(D, cmap='bwr', vmin=-1, vmax=1);
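Beyond matching in size, the block-diagonal \(D\) and torch.kron(D_x, D_y) are related by a change of basis, so they must have the same trace for every rotation. For a rotation by angle \(\theta\), the trace of an order-\(l\) block is \(\chi_l(\theta) = \sin((l + \tfrac12)\theta) / \sin(\theta/2)\), and traces multiply under the Kronecker product. The NumPy check below is restricted to proper rotations, where parity does not contribute to the trace.

```python
import numpy as np


def chi(l, theta):
    # Trace of the order-l irrep for a rotation by angle theta (0 < theta < 2*pi).
    return np.sin((l + 0.5) * theta) / np.sin(theta / 2)


theta = np.linspace(0.1, 6.0, 50)
lhs = (2 * chi(1, theta)) * (chi(0, theta) + chi(1, theta))      # tr(D_x) * tr(D_y)
rhs = 2 * chi(0, theta) + 4 * chi(1, theta) + 2 * chi(2, theta)  # 2x0o + 4x1o + 2x2o
print(np.allclose(lhs, rhs))  # True
```

As \(\theta \to 0\) both sides tend to \(6 \times 4 = 24\), recovering the dimension check.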

FullTensorProduct is a special case of TensorProduct. Other tensor products, such as FullyConnectedTensorProduct, can contain learnable weights, which makes them very useful for building neural networks.
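To see where learnable weights can enter, assume a fully connected pattern: for every allowed path (a pair of input irreps whose product contains the output irrep), there is one independent weight per combination of input-1 channel, input-2 channel, and output channel. The plain-Python sketch below counts weights under that assumption for a hypothetical output 4x0o + 4x1o. It is an illustration, not a description of e3nn's internals; the library's own count appears in the printed description of a module (for example "| 0 weights" above).

```python
def count_fc_weights(irreps1, irreps2, irreps_out):
    """Count weights, assuming one weight per (in1 channel, in2 channel, out channel)
    for each allowed path. Irreps are lists of (mul, l, p), p = +1 even / -1 odd.
    """
    total = 0
    for mul1, l1, p1 in irreps1:
        for mul2, l2, p2 in irreps2:
            for mul_out, l_out, p_out in irreps_out:
                # A path is allowed if l_out satisfies the triangle rule and parities match.
                if abs(l1 - l2) <= l_out <= l1 + l2 and p_out == p1 * p2:
                    total += mul1 * mul2 * mul_out
    return total


x_irreps = [(2, 1, -1)]                # 2x1o
y_irreps = [(1, 0, 1), (1, 1, 1)]      # 0e + 1e
out_irreps = [(4, 0, -1), (4, 1, -1)]  # 4x0o + 4x1o (hypothetical output choice)
print(count_fc_weights(x_irreps, y_irreps, out_irreps))  # 24
```

Here three paths are allowed (\(1_o \otimes 0_e \to 1_o\), \(1_o \otimes 1_e \to 0_o\), \(1_o \otimes 1_e \to 1_o\)), each with \(2 \times 1 \times 4 = 8\) weights.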