Euclidean neural networks

What is e3nn?

e3nn is a Python library based on PyTorch for creating equivariant neural networks for the group $$O(3)$$.

Demonstration

All the functions to manipulate rotations (rotation matrices, Euler angles, quaternions, conversions, …) can be found in Parametrization of Rotations. The irreducible representations (irreps) of $$O(3)$$ (more info at Irreps) are represented by the class Irrep. The direct sum of multiple irreps is described by an Irreps object.
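The strings used below, such as `'2x1o'` or `'0e + 1e'`, encode a multiplicity, a degree $$l$$ and a parity (`e` even, `o` odd), and each irrep of degree $$l$$ has dimension $$2l + 1$$. As a rough illustration of that notation only, here is a small pure-Python sketch; `parse_irreps` and `irreps_dim` are hypothetical helpers written for this example, not part of e3nn, which provides the real `o3.Irreps` class and its `dim` attribute.

```python
from typing import List, Tuple


def parse_irreps(spec: str) -> List[Tuple[int, int, int]]:
    """Hypothetical parser: '2x1o + 0e' -> [(2, 1, -1), (1, 0, 1)].

    Each entry is (mul, l, p) with p = +1 for 'e' (even) and -1 for 'o' (odd).
    """
    result = []
    for term in spec.split("+"):
        term = term.strip()
        if "x" in term:
            mul_str, ir_str = term.split("x")
            mul = int(mul_str)
        else:
            mul, ir_str = 1, term  # no multiplicity means a single copy
        ir_str = ir_str.strip()
        l = int(ir_str[:-1])
        p = {"e": 1, "o": -1}[ir_str[-1]]
        result.append((mul, l, p))
    return result


def irreps_dim(spec: str) -> int:
    # An irrep of degree l has dimension 2l + 1; the multiplicity counts copies.
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(spec))


print(parse_irreps("2x1o"))                       # [(2, 1, -1)]
print(parse_irreps("0e + 1e"))                    # [(1, 0, 1), (1, 1, 1)]
print(irreps_dim("2x1o"), irreps_dim("0e + 1e"))  # 6 4
```

These dimensions match the `(6, 4)` returned by `irreps_x.dim, irreps_y.dim` in the example that follows.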

If two tensors $$x$$ and $$y$$ transform as $$D_x = 2 \times 1_o$$ (two vectors) and $$D_y = 0_e + 1_e$$ (a scalar and a pseudovector) respectively, where the subscripts $$e$$ and $$o$$ (even and odd) indicate how each irrep transforms under parity,

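import torch
from e3nn import o3

irreps_x = o3.Irreps('2x1o')     # two (odd) vectors
irreps_y = o3.Irreps('0e + 1e')  # an even scalar and an even pseudovector

x = irreps_x.randn(-1)           # random tensor of shape (irreps_x.dim,)
y = irreps_y.randn(-1)           # random tensor of shape (irreps_y.dim,)

irreps_x.dim, irreps_y.dim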

(6, 4)


their outer product is a $$6 \times 4$$ matrix with two indices, $$A_{ij} = x_i y_j$$.

A = torch.einsum('i,j', x, y)
A

tensor([[-9.3048e-02,  6.2089e-02,  2.6819e-01,  4.9130e-03],
[-1.2266e-02,  8.1851e-03,  3.5355e-02,  6.4767e-04],
[-2.3882e-01,  1.5936e-01,  6.8835e-01,  1.2610e-02],
[ 1.1373e+00, -7.5888e-01, -3.2779e+00, -6.0050e-02],
[-8.6876e-03,  5.7971e-03,  2.5040e-02,  4.5872e-04],
[ 8.6206e-01, -5.7524e-01, -2.4847e+00, -4.5518e-02]])


If a rotation is applied to the system, this matrix will transform with the representation $$D_x \otimes D_y$$ (the tensor product representation).

$$A = x y^t \longrightarrow A' = D_x A D_y^t$$

The matrix $$D_x \otimes D_y$$ can be visualized as follows:

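import matplotlib.pyplot as plt

R = o3.rand_matrix()             # random rotation matrix
D_x = irreps_x.D_from_matrix(R)  # 6x6 matrix representing R on x
D_y = irreps_y.D_from_matrix(R)  # 4x4 matrix representing R on y

plt.imshow(torch.kron(D_x, D_y), cmap='bwr', vmin=-1, vmax=1);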
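The plot above shows the Kronecker product `torch.kron(D_x, D_y)`. To see why this is the matrix acting on $$A$$, here is a small framework-independent check in NumPy. It assumes nothing about e3nn: random orthogonal matrices stand in for `D_x` and `D_y`, since the identity holds for any matrices. Applying $$D_x \otimes D_y$$ to the row-major flattening of $$A$$ gives the flattening of $$D_x A D_y^t$$.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_orthogonal(n: int) -> np.ndarray:
    """Random orthogonal matrix from the QR decomposition of a Gaussian matrix."""
    q, _ = np.linalg.qr(rng.normal(size=(n, n)))
    return q


# Stand-ins for irreps_x.D_from_matrix(R) (6x6) and irreps_y.D_from_matrix(R) (4x4).
D_x = random_orthogonal(6)
D_y = random_orthogonal(4)
x = rng.normal(size=6)
y = rng.normal(size=4)

A = np.outer(x, y)       # A = x y^t
A_rot = D_x @ A @ D_y.T  # A' = D_x A D_y^t
K = np.kron(D_x, D_y)    # 24x24 matrix D_x (x) D_y

# K acting on the row-major flattening of A gives the flattening of A'.
print(np.allclose(K @ A.reshape(-1), A_rot.reshape(-1)))  # True
# A' is also the outer product of the transformed vectors.
print(np.allclose(A_rot, np.outer(D_x @ x, D_y @ y)))      # True
```

The row-major flattening matters: with `A.reshape(-1)`, index $$(i, j)$$ becomes $$4i + j$$, which is exactly the index ordering used by the Kronecker product.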


This representation is reducible: it can be decomposed into irreps by a change of basis. The outer product followed by this change of basis is performed by the class FullTensorProduct.
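The change of basis is easiest to see for two vectors, where $$1 \otimes 1 = 0 \oplus 1 \oplus 2$$: the $$3 \times 3$$ outer product splits into its trace (1 number), its antisymmetric part (3 numbers, equivalent to a cross product) and its symmetric traceless part (5 numbers). The NumPy sketch below is a Cartesian illustration of that idea only. e3nn works in its own basis with its own normalization, so these numbers are not what `tp(x, y)` returns, and the sketch tests proper rotations only, where the parity labels play no role.

```python
import numpy as np

rng = np.random.default_rng(1)


def random_rotation() -> np.ndarray:
    """Random proper rotation (orthogonal, det = +1) via QR."""
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one column to get det = +1
    return q


def decompose(A: np.ndarray):
    """Split a 3x3 matrix into its l = 0, 1, 2 parts (1 + 3 + 5 = 9 numbers)."""
    scalar = np.trace(A) / 3.0                               # l = 0
    anti = 0.5 * (A - A.T)
    vector = np.array([anti[1, 2], anti[2, 0], anti[0, 1]])  # l = 1
    sym_traceless = 0.5 * (A + A.T) - scalar * np.eye(3)     # l = 2
    return scalar, vector, sym_traceless


u, v = rng.normal(size=3), rng.normal(size=3)
R = random_rotation()

s, w, T = decompose(np.outer(u, v))
s_rot, w_rot, T_rot = decompose(np.outer(R @ u, R @ v))

print(np.isclose(s_rot, s))                  # True: l = 0 part is invariant
print(np.allclose(w_rot, R @ w))             # True: l = 1 part rotates like a vector
print(np.allclose(T_rot, R @ T @ R.T))       # True: l = 2 part rotates like a rank-2 tensor
print(np.allclose(w, 0.5 * np.cross(u, v)))  # True: l = 1 part is half the cross product
```

Each part transforms independently of the others, which is what "decomposed into irreps" means: in the new basis the representation matrix is block diagonal.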

tp = o3.FullTensorProduct(irreps_x, irreps_y)
print(tp)

tp(x, y)

FullTensorProduct(2x1o x 1x0e+1x1e -> 2x0o+4x1o+2x2o | 8 paths | 0 weights)

tensor([ 6.3540e-02, -4.4996e-01, -9.3048e-02, -1.2266e-02, -2.3882e-01,
1.1373e+00, -8.6876e-03,  8.6206e-01, -4.8628e-01,  1.0921e-01,
1.8385e-01,  1.7573e+00, -3.6429e-01, -2.3220e+00,  1.1616e-01,
1.9543e-01, -1.6287e-03,  4.8719e-01, -3.4987e-02, -4.4921e-01,
-2.3138e+00,  3.4884e-01, -1.7566e+00,  5.0442e-01])
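The output irreps printed above, `2x0o+4x1o+2x2o`, follow from the selection rules for $$O(3)$$: the product of degrees $$l_1$$ and $$l_2$$ contains each $$l$$ with $$|l_1 - l_2| \le l \le l_1 + l_2$$ once, the output parity is $$p_1 p_2$$, and multiplicities multiply. A pure-Python sketch that reproduces that line and its dimension is shown below. `full_tensor_product_irreps` and `to_string` are hypothetical helpers for this example, and the ordering here is simply by $$(l, p)$$.

```python
from collections import Counter
from typing import Dict, List, Tuple


def full_tensor_product_irreps(
    irreps1: List[Tuple[int, int, int]], irreps2: List[Tuple[int, int, int]]
) -> Dict[Tuple[int, int], int]:
    """Multiplicity of each output irrep (l, p) of the full tensor product.

    Inputs are lists of (mul, l, p) with p = +1 (even) or -1 (odd).
    """
    counts: Counter = Counter()
    for mul1, l1, p1 in irreps1:
        for mul2, l2, p2 in irreps2:
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                counts[(l, p1 * p2)] += mul1 * mul2  # every channel pair contributes
    return dict(sorted(counts.items()))


def to_string(counts: Dict[Tuple[int, int], int]) -> str:
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for (l, p), mul in counts.items())


out = full_tensor_product_irreps([(2, 1, -1)], [(1, 0, 1), (1, 1, 1)])  # 2x1o and 0e + 1e
print(to_string(out))                                          # 2x0o+4x1o+2x2o
print(sum(mul * (2 * l + 1) for (l, _), mul in out.items()))  # 24
```

The path $$1_o \otimes 0_e$$ contributes $$2 \times 1_o$$, and $$1_o \otimes 1_e$$ contributes $$2 \times (0_o + 1_o + 2_o)$$, giving 24 components in total.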


As a sanity check, we can verify that the representation of the tensor product output is block diagonal and has the same dimension as the outer product ($$6 \times 4 = 24$$).

D = tp.irreps_out.D_from_matrix(R)
plt.imshow(D, cmap='bwr', vmin=-1, vmax=1);
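Instead of inspecting the plot by eye, the block structure can also be checked programmatically. The sketch below uses a hypothetical helper, `is_block_diagonal`, written in NumPy. It is demonstrated on a synthetic matrix so that it runs without e3nn. The block sizes $$2l + 1$$ are read off from `2x0o+4x1o+2x2o`.

```python
from typing import Sequence

import numpy as np


def is_block_diagonal(D: np.ndarray, block_sizes: Sequence[int], atol: float = 1e-6) -> bool:
    """True if D is square of size sum(block_sizes) and ~0 outside the diagonal blocks."""
    D = np.asarray(D)
    n = sum(block_sizes)
    if D.shape != (n, n):
        return False
    mask = np.zeros((n, n), dtype=bool)  # True on the diagonal blocks
    start = 0
    for size in block_sizes:
        mask[start:start + size, start:start + size] = True
        start += size
    return bool(np.all(np.abs(D[~mask]) <= atol))


# Block sizes 2l + 1 for 2x0o+4x1o+2x2o: 2 + 12 + 10 = 24 = 6 * 4.
sizes = [1] * 2 + [3] * 4 + [5] * 2

# Synthetic matrix with random diagonal blocks, standing in for D.
rng = np.random.default_rng(0)
D_demo = np.zeros((24, 24))
start = 0
for size in sizes:
    D_demo[start:start + size, start:start + size] = rng.normal(size=(size, size))
    start += size

print(sum(sizes), is_block_diagonal(D_demo, sizes))  # 24 True
D_demo[0, 23] = 1.0                                   # break the block structure
print(is_block_diagonal(D_demo, sizes))               # False
```

With the tensors on this page, `is_block_diagonal(D.numpy(), sizes)` should return `True`.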


FullTensorProduct is a special case of TensorProduct. Other tensor products, such as FullyConnectedTensorProduct, can contain learnable weights, which makes them very useful for building neural networks.
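To make "learnable weights" more concrete, the following pure-Python sketch counts parameters for a fully connected tensor product. It is based on one assumption: there is one learnable scalar per (channel of input 1, channel of input 2, output channel) for every path allowed by the selection rules. This is our reading of "fully connected". `fully_connected_weight_count` is a hypothetical helper, so compare its result with the `weight_numel` attribute of an actual `o3.FullyConnectedTensorProduct` before relying on it.

```python
from typing import List, Tuple

IrrepsList = List[Tuple[int, int, int]]  # (mul, l, p), p = +1 even / -1 odd


def fully_connected_weight_count(irreps_in1: IrrepsList, irreps_in2: IrrepsList,
                                 irreps_out: IrrepsList) -> int:
    """Assumed count: mul1 * mul2 * mul_out weights for every allowed path."""
    total = 0
    for mul1, l1, p1 in irreps_in1:
        for mul2, l2, p2 in irreps_in2:
            for mul_out, l_out, p_out in irreps_out:
                # Path allowed if l_out is in |l1 - l2|..l1 + l2 and parities match.
                if abs(l1 - l2) <= l_out <= l1 + l2 and p_out == p1 * p2:
                    total += mul1 * mul2 * mul_out
    return total


# 2x1o (x) 0e + 1e -> 3x1o: two allowed paths (1o x 0e -> 1o and 1o x 1e -> 1o),
# each with 2 * 1 * 3 = 6 weights under the assumption above.
print(fully_connected_weight_count([(2, 1, -1)], [(1, 0, 1), (1, 1, 1)], [(3, 1, -1)]))  # 12
```

FullTensorProduct, by contrast, reports `0 weights` in its printout above: it keeps every path with a fixed coefficient instead of mixing channels with learned parameters.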