Euclidean neural networks

What is e3nn?

e3nn is a Python library built on PyTorch for creating neural networks that are equivariant under the group $$O(3)$$.

Where to start?

e3nn API

o3

All functions in this module are accessible via the o3 submodule:

from e3nn import o3

R = o3.rand_matrix(10)  # 10 random rotation matrices, shape (10, 3, 3)
D = o3.Irreps.spherical_harmonics(4).D_from_matrix(R)  # shape (10, 25, 25), since 1 + 3 + 5 + 7 + 9 = 25

Overview

Parametrization of Rotations
Matrix Parametrization

e3nn.o3.rand_matrix(*shape, requires_grad=False, dtype=None, device=None)[source]

random rotation matrix

Parameters

*shape (int) –

Returns

tensor of shape $$(\mathrm{shape}, 3, 3)$$

Return type

torch.Tensor

e3nn.o3.matrix_x(angle: torch.Tensor) [source]

matrix of rotation around X axis

Parameters

angle (torch.Tensor) – tensor of any shape $$(...)$$

Returns

matrices of shape $$(..., 3, 3)$$

Return type

torch.Tensor

e3nn.o3.matrix_y(angle: torch.Tensor) [source]

matrix of rotation around Y axis

Parameters

angle (torch.Tensor) – tensor of any shape $$(...)$$

Returns

matrices of shape $$(..., 3, 3)$$

Return type

torch.Tensor

e3nn.o3.matrix_z(angle: torch.Tensor) [source]

matrix of rotation around Z axis

Parameters

angle (torch.Tensor) – tensor of any shape $$(...)$$

Returns

matrices of shape $$(..., 3, 3)$$

Return type

torch.Tensor
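
The three elementary rotations can be written out by hand. The dependency-free sketch below assumes that matrix_x, matrix_y and matrix_z return the standard right-handed rotation matrices; rot_x, rot_y, rot_z, matmul, transpose and apply are hypothetical helpers on nested lists, not part of e3nn.

```python
# Sketch only: assumes matrix_x/y/z follow the standard right-handed convention.
import math
from typing import List

Matrix = List[List[float]]


def rot_x(a: float) -> Matrix:
    """Rotation by a radians around the X axis."""
    c, s = math.cos(a), math.sin(a)
    return [[1.0, 0.0, 0.0],
            [0.0, c, -s],
            [0.0, s, c]]


def rot_y(a: float) -> Matrix:
    """Rotation by a radians around the Y axis."""
    c, s = math.cos(a), math.sin(a)
    return [[c, 0.0, s],
            [0.0, 1.0, 0.0],
            [-s, 0.0, c]]


def rot_z(a: float) -> Matrix:
    """Rotation by a radians around the Z axis."""
    c, s = math.cos(a), math.sin(a)
    return [[c, -s, 0.0],
            [s, c, 0.0],
            [0.0, 0.0, 1.0]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(A: Matrix) -> Matrix:
    return [[A[j][i] for j in range(3)] for i in range(3)]


def apply(A: Matrix, v: List[float]) -> List[float]:
    """Matrix-vector product A v."""
    return [sum(A[i][k] * v[k] for k in range(3)) for i in range(3)]


# A quarter turn around Z sends the X axis onto the Y axis.
print([round(x, 6) for x in apply(rot_z(math.pi / 2), [1.0, 0.0, 0.0])])  # [0.0, 1.0, 0.0]
```

Each of these matrices is orthogonal with determinant 1, which is what makes it a rotation rather than a general linear map.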

Euler Angles Parametrization

e3nn.o3.identity_angles(*shape, requires_grad=False, dtype=None, device=None)[source]

angles of the identity rotation

Parameters

*shape (int) –

Returns

• alpha (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

• beta (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

• gamma (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

e3nn.o3.rand_angles(*shape, requires_grad=False, dtype=None, device=None)[source]

random rotation angles

Parameters

*shape (int) –

Returns

• alpha (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

• beta (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

• gamma (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

e3nn.o3.compose_angles(a1, b1, c1, a2, b2, c2)[source]

compose angles

Computes $$(a, b, c)$$ such that $$R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)$$
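
The parameter descriptions of Irrep.D_from_angles ($$\gamma$$ around Y applied first, $$\beta$$ around X applied second, $$\alpha$$ around Y applied third) suggest the Y-X-Y convention $$R(\alpha, \beta, \gamma) = Y(\alpha) X(\beta) Y(\gamma)$$, where $$X$$ and $$Y$$ are the elementary rotations of matrix_x and matrix_y. Assuming that convention, the plain-Python sketch below checks two consequences: in a product of rotation matrices the right factor acts first (which is why $$R(a_2, b_2, c_2)$$ is "applied first" here), and the inverse of $$R(a, b, c)$$ is $$R(-c, -b, -a)$$. All names are hypothetical; the sketch does not reimplement compose_angles.

```python
# Sketch only: the Y-X-Y Euler convention is assumed, not taken from e3nn's source.
import math
from typing import List

Matrix = List[List[float]]


def rot_x(a: float) -> Matrix:
    c, s = math.cos(a), math.sin(a)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(a: float) -> Matrix:
    c, s = math.cos(a), math.sin(a)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def euler_to_matrix(alpha: float, beta: float, gamma: float) -> Matrix:
    """Y(alpha) X(beta) Y(gamma): gamma acts first, alpha acts last."""
    return matmul(rot_y(alpha), matmul(rot_x(beta), rot_y(gamma)))


def close(A: Matrix, B: Matrix, tol: float = 1e-12) -> bool:
    return all(abs(A[i][j] - B[i][j]) < tol for i in range(3) for j in range(3))


IDENTITY = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

a, b, c = 0.3, 1.1, -0.4
R = euler_to_matrix(a, b, c)
# In this convention the inverse of R(a, b, c) is R(-c, -b, -a).
print(close(matmul(R, euler_to_matrix(-c, -b, -a)), IDENTITY))  # True
```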

Parameters
• a1 (torch.Tensor) – tensor of shape $$(...)$$, (applied second)

• b1 (torch.Tensor) – tensor of shape $$(...)$$, (applied second)

• c1 (torch.Tensor) – tensor of shape $$(...)$$, (applied second)

• a2 (torch.Tensor) – tensor of shape $$(...)$$, (applied first)

• b2 (torch.Tensor) – tensor of shape $$(...)$$, (applied first)

• c2 (torch.Tensor) – tensor of shape $$(...)$$, (applied first)

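Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$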

e3nn.o3.inverse_angles(a, b, c)[source]

angles of the inverse rotation

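Parameters
• a (torch.Tensor) – tensor of shape $$(...)$$

• b (torch.Tensor) – tensor of shape $$(...)$$

• c (torch.Tensor) – tensor of shape $$(...)$$

Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$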

Quaternion Parametrization

e3nn.o3.identity_quaternion(*shape, requires_grad=False, dtype=None, device=None)[source]

quaternion of identity rotation

Parameters

*shape (int) –

Returns

tensor of shape $$(\mathrm{shape}, 4)$$

Return type

torch.Tensor

e3nn.o3.rand_quaternion(*shape, requires_grad=False, dtype=None, device=None)[source]

generate random quaternion

Parameters

*shape (int) –

Returns

tensor of shape $$(\mathrm{shape}, 4)$$

Return type

torch.Tensor

e3nn.o3.compose_quaternion(q1, q2)[source]

compose two quaternions: $$q_1 \circ q_2$$

Parameters
• q1 (torch.Tensor) – tensor of shape $$(..., 4)$$, (applied second)

• q2 (torch.Tensor) – tensor of shape $$(..., 4)$$, (applied first)

Returns

tensor of shape $$(..., 4)$$

Return type

torch.Tensor

e3nn.o3.inverse_quaternion(q)[source]

inverse of a quaternion

Works only for unit quaternions.

Parameters

q (torch.Tensor) – tensor of shape $$(..., 4)$$

Returns

tensor of shape $$(..., 4)$$

Return type

torch.Tensor
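
The quaternion operations above can be sketched in plain Python. This sketch assumes the scalar-first layout $$q = (w, x, y, z)$$ with identity $$(1, 0, 0, 0)$$, which is an assumption about e3nn's storage order. Composition is the Hamilton product $$q_1 q_2$$ (with $$q_2$$ acting first), and the inverse of a unit quaternion is its conjugate, which is why inverse_quaternion works only for unit quaternions. The function names are hypothetical.

```python
# Sketch only: the (w, x, y, z) scalar-first layout is an assumption.
import math
from typing import Tuple

Quat = Tuple[float, float, float, float]


def compose(q1: Quat, q2: Quat) -> Quat:
    """Hamilton product q1 q2; as a rotation, q2 acts first."""
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return (
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    )


def conjugate(q: Quat) -> Quat:
    """Negate the vector part; this is the inverse only for unit quaternions."""
    w, x, y, z = q
    return (w, -x, -y, -z)


def normalize(q: Quat) -> Quat:
    n = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / n for c in q)
    return (w, x, y, z)


IDENTITY: Quat = (1.0, 0.0, 0.0, 0.0)

q = normalize((0.5, -1.0, 2.0, 0.3))
err = max(abs(a - b) for a, b in zip(compose(q, conjugate(q)), IDENTITY))
print(err < 1e-12)  # True
```

For a non-unit quaternion such as $$(2, 0, 0, 0)$$, the product with its conjugate is $$(4, 0, 0, 0)$$ rather than the identity, so normalizing first matters.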

Axis-Angle Parametrization

e3nn.o3.rand_axis_angle(*shape, requires_grad=False, dtype=None, device=None)[source]

generate random rotation as axis-angle

Parameters

*shape (int) –

Returns

• axis (torch.Tensor) – tensor of shape $$(\mathrm{shape}, 3)$$

• angle (torch.Tensor) – tensor of shape $$(\mathrm{shape})$$

e3nn.o3.compose_axis_angle(axis1, angle1, axis2, angle2)[source]

compose $$(\vec x_1, \alpha_1)$$ with $$(\vec x_2, \alpha_2)$$

Parameters
• axis1 (torch.Tensor) – tensor of shape $$(..., 3)$$, (applied second)

• angle1 (torch.Tensor) – tensor of shape $$(...)$$, (applied second)

• axis2 (torch.Tensor) – tensor of shape $$(..., 3)$$, (applied first)

• angle2 (torch.Tensor) – tensor of shape $$(...)$$, (applied first)

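Returns

• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$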

Conversions
e3nn.o3.angles_to_matrix(alpha, beta, gamma)[source]

conversion from angles to matrix

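Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$

Returns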

matrices of shape $$(..., 3, 3)$$

Return type

torch.Tensor

e3nn.o3.matrix_to_angles(R)[source]

conversion from matrix to angles

Parameters

R (torch.Tensor) – matrices of shape $$(..., 3, 3)$$

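Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$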

e3nn.o3.angles_to_quaternion(alpha, beta, gamma)[source]

conversion from angles to quaternion

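Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$

Returns

tensor of shape $$(..., 4)$$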

Return type

torch.Tensor

e3nn.o3.matrix_to_quaternion(R)[source]

conversion from matrix $$R$$ to quaternion $$q$$

Parameters

R (torch.Tensor) – tensor of shape $$(..., 3, 3)$$

Returns

tensor of shape $$(..., 4)$$

Return type

torch.Tensor

e3nn.o3.axis_angle_to_quaternion(xyz, angle)[source]

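conversion from axis-angle to quaternion

Parameters
• xyz (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$

Returns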

tensor of shape $$(..., 4)$$

Return type

torch.Tensor

e3nn.o3.quaternion_to_axis_angle(q)[source]

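conversion from quaternion to axis-angle

Parameters

q (torch.Tensor) – tensor of shape $$(..., 4)$$

Returns

• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$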

e3nn.o3.matrix_to_axis_angle(R)[source]

conversion from matrix to axis-angle

Parameters

R (torch.Tensor) – tensor of shape $$(..., 3, 3)$$

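Returns

• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$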

e3nn.o3.angles_to_axis_angle(alpha, beta, gamma)[source]

conversion from angles to axis-angle

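Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$

Returns

• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$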

e3nn.o3.axis_angle_to_matrix(axis, angle)[source]

conversion from axis-angle to matrix

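Parameters
• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$

Returns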

tensor of shape $$(..., 3, 3)$$

Return type

torch.Tensor
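
One standard construction is Rodrigues' rotation formula, $$R = I + \sin\theta\, K + (1 - \cos\theta) K^2$$, where $$K$$ is the cross-product matrix of the unit axis. The plain-Python sketch below uses it; it is a textbook formula, not necessarily how axis_angle_to_matrix is implemented, and rodrigues is a hypothetical name. The axis is normalized first, so any non-zero vector is accepted.

```python
# Sketch only: Rodrigues' formula, not necessarily e3nn's implementation.
import math
from typing import List, Sequence

Matrix = List[List[float]]


def rodrigues(axis: Sequence[float], angle: float) -> Matrix:
    """Rotation by `angle` radians around `axis`."""
    n = math.sqrt(sum(a * a for a in axis))
    x, y, z = (a / n for a in axis)  # unit axis
    # K is the matrix of the map v -> axis x v
    K = [[0.0, -z, y],
         [z, 0.0, -x],
         [-y, x, 0.0]]
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)] for i in range(3)]
    s, c = math.sin(angle), math.cos(angle)
    return [[(1.0 if i == j else 0.0) + s * K[i][j] + (1.0 - c) * K2[i][j]
             for j in range(3)] for i in range(3)]


R = rodrigues([0.0, 0.0, 2.0], math.pi / 2)  # quarter turn around Z
print([[round(v, 6) for v in row] for row in R])
```

For an axis along Z this reproduces the rotation around Z, and since $$K$$ annihilates the axis, the axis itself is left unchanged by the rotation.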

e3nn.o3.quaternion_to_matrix(q)[source]

conversion from quaternion to matrix

Parameters

q (torch.Tensor) – tensor of shape $$(..., 4)$$

Returns

tensor of shape $$(..., 3, 3)$$

Return type

torch.Tensor
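
For a unit quaternion $$q = (w, x, y, z)$$ (scalar first, an assumption about the layout) the rotation matrix has a closed form, and an axis-angle pair $$(\vec n, \theta)$$ corresponds to $$q = (\cos\frac{\theta}{2}, \sin\frac{\theta}{2} \vec n)$$. This plain-Python sketch uses those textbook formulas; the function names are hypothetical and the code is not e3nn's implementation.

```python
# Sketch only: scalar-first quaternions and textbook formulas are assumed.
import math
from typing import List, Sequence, Tuple

Quat = Tuple[float, float, float, float]


def axis_angle_to_quat(axis: Sequence[float], angle: float) -> Quat:
    """(cos(angle/2), sin(angle/2) * unit axis)."""
    n = math.sqrt(sum(a * a for a in axis))
    s = math.sin(angle / 2) / n
    return (math.cos(angle / 2), axis[0] * s, axis[1] * s, axis[2] * s)


def quat_to_matrix(q: Quat) -> List[List[float]]:
    """Rotation matrix of a unit quaternion."""
    w, x, y, z = q
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


R = quat_to_matrix(axis_angle_to_quat([0.0, 0.0, 1.0], math.pi / 2))
print([[round(v, 6) for v in row] for row in R])  # a quarter turn around Z
```

Every term of the matrix is quadratic in $$q$$, so $$q$$ and $$-q$$ give the same rotation: quaternions cover each rotation twice.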

e3nn.o3.quaternion_to_angles(q)[source]

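conversion from quaternion to angles

Parameters

q (torch.Tensor) – tensor of shape $$(..., 4)$$

Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$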

e3nn.o3.axis_angle_to_angles(axis, angle)[source]

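conversion from axis-angle to angles

Parameters
• axis (torch.Tensor) – tensor of shape $$(..., 3)$$

• angle (torch.Tensor) – tensor of shape $$(...)$$

Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$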

Conversions to points on the sphere
e3nn.o3.angles_to_xyz(alpha, beta)[source]

convert $$(\alpha, \beta)$$ into a point $$(x, y, z)$$ on the sphere

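Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

Returns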

tensor of shape $$(..., 3)$$

Return type

torch.Tensor
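
Under the Y-X-Y convention described in Irrep.D_from_angles, the point on the sphere is $$\vec r = R(\alpha, \beta, 0) \vec e_y$$, which works out to $$x = \sin\beta \sin\alpha$$, $$y = \cos\beta$$, $$z = \sin\beta \cos\alpha$$; the example below ($$\beta = 0$$ gives $$(0, 1, 0)$$ for any $$\alpha$$) is consistent with this. The plain-Python sketch implements these formulas and their inverse $$\alpha = \mathrm{atan2}(x, z)$$, $$\beta = \arccos y$$. It is derived from that convention rather than copied from e3nn, and the function names are hypothetical.

```python
# Sketch only: formulas derived from the assumed Y-X-Y convention, r = R(alpha, beta, 0) e_y.
import math
from typing import Sequence, Tuple


def to_xyz(alpha: float, beta: float) -> Tuple[float, float, float]:
    """Point on the unit sphere for the angles (alpha, beta)."""
    return (math.sin(beta) * math.sin(alpha),
            math.cos(beta),
            math.sin(beta) * math.cos(alpha))


def to_angles(xyz: Sequence[float]) -> Tuple[float, float]:
    """Inverse of to_xyz; the input is normalized first."""
    n = math.sqrt(sum(c * c for c in xyz))
    x, y, z = (c / n for c in xyz)
    beta = math.acos(max(-1.0, min(1.0, y)))  # clamp guards against rounding
    alpha = math.atan2(x, z)
    return alpha, beta


print(to_angles(to_xyz(1.7, 0.9)))  # approximately (1.7, 0.9)
```

The round trip recovers $$\alpha$$ only up to $$2\pi$$ and is undefined at the poles $$\beta \in \{0, \pi\}$$, where every $$\alpha$$ gives the same point.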

Examples

>>> angles_to_xyz(torch.tensor(1.7), torch.tensor(0.0)).abs()
tensor([0., 1., 0.])

e3nn.o3.xyz_to_angles(xyz)[source]

convert a point $$\vec r = (x, y, z)$$ on the sphere into angles $$(\alpha, \beta)$$

$\vec r = R(\alpha, \beta, 0) \vec e_y$

Parameters

xyz (torch.Tensor) – tensor of shape $$(..., 3)$$

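Returns

• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$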

Irreps

A group representation $$(D,V)$$ describes the action of a group $$G$$ on a vector space $$V$$

$D : G \longrightarrow \text{linear map on } V.$

The irreducible representations, irreps for short, are the “smallest” representations.

• Any representation can be decomposed via a change of basis into a direct sum of irreps

• Any physical quantity, under the action of $$O(3)$$, transforms with a representation of $$O(3)$$

The irreps of $$SO(3)$$ are called the Wigner matrices $$D^L$$. The irreps of the group of inversion ($$\{e, I\}$$) are the trivial representation $$\sigma_+$$ and the sign representation $$\sigma_-$$:

$\begin{split}\sigma_p(g) = \left \{ \begin{array}{l} 1 \text{ if } g = e \\ p \text{ if } g = I \end{array} \right..\end{split}$

The group $$O(3)$$ is the direct product of $$SO(3)$$ and inversion

$g = r i, \quad r \in SO(3), i \in \text{inversion}.$

The irreps of $$O(3)$$ are the products of the irreps of $$SO(3)$$ and inversion. An instance of the class e3nn.o3.Irreps represents a direct sum of irreps of $$O(3)$$:

$g = r i \mapsto \bigoplus_{j=1}^n m_j \times \sigma_{p_j}(i) D^{L_j}(r)$

where $$(m_j \in \mathbb{N}, p_j = \pm 1, L_j = 0,1,2,3,\dots)_{j=1}^n$$ defines the e3nn.o3.Irreps.
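
Concretely, the dimension of such a direct sum is $$\sum_j m_j (2 L_j + 1)$$. A hypothetical mini-parser for strings in the "100x0e + 50x1e" notation used by e3nn.o3.Irreps, written only to illustrate this bookkeeping (it is not e3nn's parser and ignores the "y" parity shorthand), could look like this:

```python
# Sketch only: a hypothetical parser for "mul x l parity" terms, not e3nn's own.
import re
from typing import List, Tuple

# Matches "[mul x] l parity", e.g. "50x1e" or "2o" (multiplicity defaults to 1).
_TERM = re.compile(r"^\s*(?:(\d+)x)?(\d+)([eo])\s*$")


def parse_irreps(s: str) -> List[Tuple[int, int, int]]:
    """Return (multiplicity, l, parity) triples with parity +1 ('e') or -1 ('o')."""
    out = []
    for term in filter(None, (t.strip() for t in s.split("+"))):
        m = _TERM.match(term)
        if m is None:
            raise ValueError(f"cannot parse {term!r}")
        mul = int(m.group(1)) if m.group(1) else 1
        out.append((mul, int(m.group(2)), 1 if m.group(3) == "e" else -1))
    return out


def irreps_dim(s: str) -> int:
    """Total dimension: sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(s))


print(irreps_dim("100x0e + 50x1e"))  # 250
```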

Irreps of $$O(3)$$ are often confused with the spherical harmonics; the relation between the irreps and the spherical harmonics is explained in the Spherical Harmonics section.

class e3nn.o3.Irrep(l: Union[int, e3nn.o3._irreps.Irrep, str, tuple], p=None)[source]

Bases: tuple

Irreducible representation of $$O(3)$$

This class does not contain any data; it is a structure that describes the representation. It is typically used as an argument to other classes of the library to define the input and output representations of functions.

Parameters
• l (int) – non-negative integer, the degree of the representation, $$l = 0, 1, \dots$$

• p ({1, -1}) – the parity of the representation

Examples

Create a scalar representation ($$l=0$$) of even parity.

>>> Irrep(0, 1)
0e

Create a pseudotensor representation ($$l=2$$) of odd parity.

>>> Irrep(2, -1)
2o

Create a vector representation ($$l=1$$) with the parity of the spherical harmonics ($$(-1)^l$$ gives odd parity).

>>> Irrep("1y")
1o
>>> Irrep("2o").dim
5
>>> Irrep("2e") in Irrep("1o") * Irrep("1o")
True
>>> Irrep("1o") + Irrep("2o")
1x1o+1x2o
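
The membership test above follows the selection rule for products of $$O(3)$$ irreps: $$l_1 \otimes l_2$$ contains each $$l$$ with $$|l_1 - l_2| \leq l \leq l_1 + l_2$$ once, with parity $$p_1 p_2$$. A plain-Python sketch of that rule, with irreps written as hypothetical (l, p) tuples rather than e3nn.o3.Irrep objects:

```python
# Sketch only: irreps as plain (l, p) tuples, a hypothetical stand-in for e3nn.o3.Irrep.
from typing import List, Tuple

Irrep = Tuple[int, int]  # (l, p) with p = +1 or -1


def product(ir1: Irrep, ir2: Irrep) -> List[Irrep]:
    """Irreps contained in ir1 x ir2, each with multiplicity one."""
    (l1, p1), (l2, p2) = ir1, ir2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def name(ir: Irrep) -> str:
    l, p = ir
    return f"{l}{'e' if p == 1 else 'o'}"


print([name(ir) for ir in product((1, -1), (1, -1))])  # ['0e', '1e', '2e']
```

As a consistency check, the dimensions add up: $$1o \otimes 2e$$ has dimension $$3 \times 5 = 15 = 3 + 5 + 7$$.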

Methods:

• D_from_angles(alpha, beta, gamma[, k]) – Matrix $$p^k D^l(\alpha, \beta, \gamma)$$

• D_from_matrix(R) – Matrix of the representation, see Irrep.D_from_angles

• D_from_quaternion(q[, k]) – Matrix of the representation, see Irrep.D_from_angles

• count(_value) – Return number of occurrences of value.

• index(_value) – Return first index of value.

• is_scalar() – Equivalent to l == 0 and p == 1

• iterator([lmax]) – Iterator through all the irreps of $$O(3)$$

Attributes:

• dim – The dimension of the representation, $$2 l + 1$$.

• l – The degree of the representation, $$l = 0, 1, \dots$$.

• p – The parity of the representation, $$p = \pm 1$$.

D_from_angles(alpha, beta, gamma, k=None)[source]

Matrix $$p^k D^l(\alpha, \beta, \gamma)$$

(matrix) Representation of $$O(3)$$. $$D$$ is the representation of $$SO(3)$$, see wigner_D.

Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$ Rotation $$\alpha$$ around Y axis, applied third.

• beta (torch.Tensor) – tensor of shape $$(...)$$ Rotation $$\beta$$ around X axis, applied second.

• gamma (torch.Tensor) – tensor of shape $$(...)$$ Rotation $$\gamma$$ around Y axis, applied first.

• k (torch.Tensor, optional) – tensor of shape $$(...)$$ How many times the parity is applied.

Returns

tensor of shape $$(..., 2l+1, 2l+1)$$

Return type

torch.Tensor

See also: o3.wigner_D, Irreps.D_from_angles
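
Since $$O(3)$$ is the direct product of $$SO(3)$$ and inversion, an orthogonal matrix $$Q$$ can be split as $$Q = r\, i$$ with $$r = \det(Q)\, Q$$ proper and $$i = -I$$ exactly when $$\det(Q) = -1$$; the representation then evaluates to $$p^k D^l(r)$$ with $$k = 1$$ for improper $$Q$$. For $$Q = -I$$ this gives $$k = 1$$ and $$r = I$$, so a 1o irrep maps to $$-I$$, as in the D_from_matrix example below. The plain-Python sketch performs only this split and the scalar factor $$p^k$$; it does not compute $$D^l$$, and its names are hypothetical.

```python
# Sketch only: splits an O(3) matrix into parity count k and a proper rotation r.
from typing import List, Tuple

Matrix = List[List[float]]


def det3(A: Matrix) -> float:
    return (A[0][0] * (A[1][1] * A[2][2] - A[1][2] * A[2][1])
            - A[0][1] * (A[1][0] * A[2][2] - A[1][2] * A[2][0])
            + A[0][2] * (A[1][0] * A[2][1] - A[1][1] * A[2][0]))


def split_o3(Q: Matrix) -> Tuple[int, Matrix]:
    """Return (k, r) with Q = r (-I)^k and det(r) = +1, for orthogonal Q."""
    d = det3(Q)
    k = 0 if d > 0 else 1
    sign = 1.0 if d > 0 else -1.0
    r = [[sign * v for v in row] for row in Q]
    return k, r


def parity_factor(p: int, Q: Matrix) -> int:
    """The scalar p**k that multiplies D^l(r)."""
    k, _ = split_o3(Q)
    return p ** k


minus_eye = [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
print(split_o3(minus_eye)[0], parity_factor(-1, minus_eye))  # 1 -1
```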

D_from_matrix(R)[source]

Matrix of the representation, see Irrep.D_from_angles

Parameters

R (torch.Tensor) – tensor of shape $$(..., 3, 3)$$

Returns

tensor of shape $$(..., 2l+1, 2l+1)$$

Return type

torch.Tensor

Examples

>>> m = Irrep(1, -1).D_from_matrix(-torch.eye(3))
>>> m.long()
tensor([[-1,  0,  0],
        [ 0, -1,  0],
        [ 0,  0, -1]])

D_from_quaternion(q, k=None)[source]

Matrix of the representation, see Irrep.D_from_angles

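Parameters
• q (torch.Tensor) – tensor of shape $$(..., 4)$$

• k (torch.Tensor, optional) – tensor of shape $$(...)$$ How many times the parity is applied.

Returns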

tensor of shape $$(..., 2l+1, 2l+1)$$

Return type

torch.Tensor

count(_value)[source]

Return number of occurrences of value.

property dim: int[source]

The dimension of the representation, $$2 l + 1$$.

index(_value)[source]

Return first index of value.

Raises ValueError if the value is not present.

is_scalar() bool[source]

Equivalent to l == 0 and p == 1

classmethod iterator(lmax=None)[source]

Iterator through all the irreps of $$O(3)$$

Examples

>>> it = Irrep.iterator()
>>> next(it), next(it), next(it), next(it)
(0e, 0o, 1o, 1e)

property l: int[source]

The degree of the representation, $$l = 0, 1, \dots$$.

property p: int[source]

The parity of the representation, $$p = \pm 1$$.

class e3nn.o3.Irreps(irreps=None)[source]

Bases: tuple

Direct sum of irreducible representations of $$O(3)$$

This class does not contain any data; it is a structure that describes the representation. It is typically used as an argument to other classes of the library to define the input and output representations of functions.

dim[source]

the total dimension of the representation

Type

int

num_irreps[source]

number of irreps, i.e. the sum of the multiplicities

Type

int

ls[source]

list of $$l$$ values

Type

list of int

lmax[source]

maximum $$l$$ value

Type

int

Examples

Create a representation of 100 $$l=0$$ of even parity and 50 pseudo-vectors.

>>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
>>> x
100x0e+50x1e
>>> x.dim
250

Irreps can also be created from a string.

>>> Irreps("100x0e + 50x1e")
100x0e+50x1e
>>> Irreps("100x0e + 50x1e + 0x2e")
100x0e+50x1e+0x2e
>>> Irreps("100x0e + 50x1e + 0x2e").lmax
1
>>> Irrep("2e") in Irreps("0e + 2e")
True

Empty Irreps

>>> Irreps(), Irreps("")
(, )

Methods:

• D_from_angles(alpha, beta, gamma[, k]) – Matrix of the representation

• D_from_matrix(R) – Matrix of the representation

• D_from_quaternion(q[, k]) – Matrix of the representation

• count(ir) – Multiplicity of ir.

• index(_object) – Return first index of value.

• randn(*size[, normalization, requires_grad, ...]) – Random tensor.

• remove_zero_multiplicities() – Remove any irreps with multiplicities of zero.

• simplify() – Simplify the representations.

• slices() – List of slices corresponding to indices for each irrep.

• sort() – Sort the representations.

• spherical_harmonics(lmax[, p]) – representation of the spherical harmonics

D_from_angles(alpha, beta, gamma, k=None)[source]

Matrix of the representation

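Parameters
• alpha (torch.Tensor) – tensor of shape $$(...)$$

• beta (torch.Tensor) – tensor of shape $$(...)$$

• gamma (torch.Tensor) – tensor of shape $$(...)$$

• k (torch.Tensor, optional) – tensor of shape $$(...)$$ How many times the parity is applied.

Returns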

tensor of shape $$(..., \mathrm{dim}, \mathrm{dim})$$

Return type

torch.Tensor

D_from_matrix(R)[source]

Matrix of the representation

Parameters

R (torch.Tensor) – tensor of shape $$(..., 3, 3)$$

Returns

tensor of shape $$(..., \mathrm{dim}, \mathrm{dim})$$

Return type

torch.Tensor

D_from_quaternion(q, k=None)[source]

Matrix of the representation

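Parameters
• q (torch.Tensor) – tensor of shape $$(..., 4)$$

• k (torch.Tensor, optional) – tensor of shape $$(...)$$ How many times the parity is applied.

Returns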

tensor of shape $$(..., \mathrm{dim}, \mathrm{dim})$$

Return type

torch.Tensor

count(ir) int[source]

Multiplicity of ir.

Parameters

ir (e3nn.o3.Irrep) –

Returns

total multiplicity of ir

Return type

int

index(_object)[source]

Return first index of value.

Raises ValueError if the value is not present.

randn(*size, normalization='component', requires_grad=False, dtype=None, device=None)[source]

Random tensor.

Parameters
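• *size (list of int) – size of the output tensor, which must contain one -1

• normalization ({'component', 'norm'}) – with 'component', each component is drawn from a standard normal distribution; with 'norm', each irrep is rescaled to unit norm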

Returns

tensor of shape size where -1 is replaced by self.dim

Return type

torch.Tensor

Examples

>>> Irreps("5x0e + 10x1o").randn(5, -1, 5, normalization='norm').shape
torch.Size([5, 35, 5])
>>> random_tensor = Irreps("2o").randn(2, -1, 3, normalization='norm')
>>> random_tensor.norm(dim=1).sub(1).abs().max().item() < 1e-5
True
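
The two normalizations can be illustrated without PyTorch. This sketch assumes, based on the example above, that 'component' draws every coordinate from a standard normal distribution and 'norm' rescales each copy of each irrep to unit Euclidean norm. Irreps are given as hypothetical (mul, l) pairs (parity does not affect the shape), and only one sample is drawn, so the -1 placeholder of randn is not modelled.

```python
# Sketch only: assumed semantics of 'component' vs 'norm'; not e3nn's implementation.
import math
import random
from typing import List, Optional, Tuple


def randn_irreps(irreps: List[Tuple[int, int]], normalization: str = "component",
                 rng: Optional[random.Random] = None) -> List[float]:
    """One random vector for the given (mul, l) pairs, of length sum(mul * (2l + 1))."""
    if normalization not in ("component", "norm"):
        raise ValueError(f"unknown normalization {normalization!r}")
    rng = rng or random.Random()
    out: List[float] = []
    for mul, l in irreps:
        for _ in range(mul):
            block = [rng.gauss(0.0, 1.0) for _ in range(2 * l + 1)]  # one copy of the irrep
            if normalization == "norm":
                n = math.sqrt(sum(v * v for v in block))
                block = [v / n for v in block]  # unit norm per copy
            out.extend(block)
    return out


x = randn_irreps([(5, 0), (10, 1)], normalization="norm", rng=random.Random(0))
print(len(x))  # 35, the dimension of "5x0e + 10x1o"
```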

remove_zero_multiplicities()[source]

Remove any irreps with multiplicities of zero.

Returns

Return type

e3nn.o3.Irreps

Examples

>>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
4x0e+2x3e

simplify()[source]

Simplify the representations.

Returns

Return type

e3nn.o3.Irreps

Examples

Note that simplify does not sort the representations.

>>> Irreps("1e + 1e + 0e").simplify()
2x1e+1x0e

Equivalent representations which are separated from each other are not combined.

>>> Irreps("1e + 1e + 0e + 1e").simplify()
2x1e+1x0e+1x1e

slices()[source]

List of slices corresponding to indices for each irrep.

Examples

>>> Irreps('2x0e + 1e').slices()
[slice(0, 2, None), slice(2, 5, None)]

sort()[source]

Sort the representations.

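Returns

• irreps (e3nn.o3.Irreps) – the sorted representation

• p (tuple of int) – the permutation, p[old_index] = new_index

• inv (tuple of int) – the inverse permutation, inv[new_index] = old_index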

Examples

>>> Irreps("1e + 0e + 1e").sort().irreps
1x0e+1x1e+1x1e
>>> Irreps("2o + 1e + 0e + 1e").sort().p
(3, 1, 0, 2)
>>> Irreps("2o + 1e + 0e + 1e").sort().inv
(2, 1, 3, 0)
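
The p and inv fields can be reproduced with a stable sort. Since e3nn.o3.Irrep is a tuple, the sketch below assumes irreps are ordered like (l, p) tuples (so 0o sorts before 0e), with ties broken by original position. inv[new] is the original index of the entry now at position new, and p is the inverse permutation, p[old] = new. The function name and the (mul, (l, p)) layout are hypothetical.

```python
# Sketch only: reproduces sort()'s p/inv bookkeeping on plain tuples.
from typing import List, Tuple

Term = Tuple[int, Tuple[int, int]]  # (mul, (l, p))


def sort_irreps(irreps: List[Term]):
    """Return (sorted irreps, p, inv) with p[old] = new and inv[new] = old."""
    order = sorted(range(len(irreps)), key=lambda i: (irreps[i][1], i))
    inv = tuple(order)
    p = [0] * len(inv)
    for new, old in enumerate(inv):
        p[old] = new
    return [irreps[i] for i in inv], tuple(p), inv


# "2o + 1e + 0e + 1e"
terms = [(1, (2, -1)), (1, (1, 1)), (1, (0, 1)), (1, (1, 1))]
_, p, inv = sort_irreps(terms)
print(p, inv)  # (3, 1, 0, 2) (2, 1, 3, 0)
```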

static spherical_harmonics(lmax, p=-1)[source]

representation of the spherical harmonics

Parameters
• lmax (int) – maximum $$l$$

• p ({1, -1}) – the parity of the representation

Returns

representation of $$(Y^0, Y^1, \dots, Y^{\mathrm{lmax}})$$

Return type

e3nn.o3.Irreps

Examples

>>> Irreps.spherical_harmonics(3)
1x0e+1x1o+1x2e+1x3o
>>> Irreps.spherical_harmonics(4, p=1)
1x0e+1x1e+1x2e+1x3e+1x4e

Tensor Product

All tensor products — denoted $$\otimes$$ — share two key characteristics:

1. The tensor product is bilinear: $$(\alpha x_1 + x_2) \otimes y = \alpha x_1 \otimes y + x_2 \otimes y$$ and $$x \otimes (\alpha y_1 + y_2) = \alpha x \otimes y_1 + x \otimes y_2$$

2. The tensor product is equivariant: $$(D x) \otimes (D y) = D (x \otimes y)$$ where $$D$$ is the representation of some symmetry operation from $$E(3)$$ (sorry for the very loose notation)

The class e3nn.o3.TensorProduct implements all possible tensor products between finite direct sums of irreducible representations (e3nn.o3.Irreps). While e3nn.o3.TensorProduct provides maximum flexibility, a number of subclasses provide various typical special cases of the tensor product:

from e3nn import o3

tp = o3.FullTensorProduct(
    irreps_in1='2x0e + 3x1o',
    irreps_in2='5x0e + 7x1e'
)
print(tp)
tp.visualize();  # draws the paths (matplotlib); ';' suppresses the notebook echo
FullTensorProduct(2x0e+3x1o x 5x0e+7x1e -> 21x0o+10x0e+36x1o+14x1e+21x2o | 102 paths | 0 weights)

The full tensor product is the “natural” one. Every possible output, that is, each output irrep for every pair of input irreps, is created and returned independently. The outputs are not mixed with each other. Note how the multiplicities of the outputs are the product of the multiplicities of the respective inputs.
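
The output irreps and the path count in this summary line can be reproduced by hand: every pair of input terms $$m_1 \times ir_1$$, $$m_2 \times ir_2$$ contributes $$m_1 m_2$$ copies of each irrep in $$ir_1 \otimes ir_2$$. A plain-Python sketch of this bookkeeping (hypothetical names; irreps as (mul, l, p) triples, output sorted by (l, p) so that 0o precedes 0e):

```python
# Sketch only: bookkeeping for the full tensor product on plain tuples.
from collections import Counter
from typing import List, Tuple

Term = Tuple[int, int, int]  # (mul, l, p)


def full_tp_irreps(in1: List[Term], in2: List[Term]) -> Tuple[str, int]:
    """Simplified, sorted output irreps and total multiplicity of a full tensor product."""
    out: Counter = Counter()
    for mul1, l1, p1 in in1:
        for mul2, l2, p2 in in2:
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                out[(l, p1 * p2)] += mul1 * mul2
    text = "+".join(f"{m}x{l}{'e' if p == 1 else 'o'}" for (l, p), m in sorted(out.items()))
    return text, sum(out.values())


print(full_tp_irreps([(2, 0, 1), (3, 1, -1)], [(5, 0, 1), (7, 1, 1)]))
# ('21x0o+10x0e+36x1o+14x1e+21x2o', 102)
```

The total multiplicity, 102, matches the number of paths printed above.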

tp = o3.FullyConnectedTensorProduct(
    irreps_in1='5x0e + 5x1e',
    irreps_in2='6x0e + 4x1e',
    irreps_out='15x0e + 3x1e'
)
print(tp)
tp.visualize();
FullyConnectedTensorProduct(5x0e+5x1e x 6x0e+4x1e -> 15x0e+3x1e | 960 paths | 960 weights)

In a fully connected tensor product, all paths that lead to any of the irreps specified in irreps_out are created. Unlike e3nn.o3.FullTensorProduct, each output is a learned weighted sum of compatible paths. This allows e3nn.o3.FullyConnectedTensorProduct to produce outputs with any multiplicity; note that the example above has $$5 \times 6 + 5 \times 4 = 50$$ ways of creating scalars (0e), but the specified irreps_out has only 15 scalars, each of which is a learned weighted combination of those 50 possible scalars. The blue color in the visualization indicates that the path has these learnable weights.
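
The weight count in this summary line follows from the fully connected "uvw" mode: a path from $$m_1 \times ir_1$$ and $$m_2 \times ir_2$$ to $$m_{out} \times ir_{out}$$ carries one weight per $$(u, v, w)$$ triple, i.e. $$m_1 m_2 m_{out}$$ weights. A plain-Python sketch that sums these over all allowed paths (hypothetical names; irreps as (mul, l, p) triples):

```python
# Sketch only: counts uvw weights over all paths allowed by the selection rule.
from typing import List, Tuple

Term = Tuple[int, int, int]  # (mul, l, p)


def allowed(l1: int, p1: int, l2: int, p2: int, l: int, p: int) -> bool:
    """|l1 - l2| <= l <= l1 + l2 and p = p1 * p2."""
    return abs(l1 - l2) <= l <= l1 + l2 and p == p1 * p2


def fctp_weight_count(in1: List[Term], in2: List[Term], out: List[Term]) -> int:
    total = 0
    for m1, l1, p1 in in1:
        for m2, l2, p2 in in2:
            for mo, lo, po in out:
                if allowed(l1, p1, l2, p2, lo, po):
                    total += m1 * m2 * mo  # one weight per (u, v, w) triple
    return total


print(fctp_weight_count([(5, 0, 1), (5, 1, 1)], [(6, 0, 1), (4, 1, 1)],
                        [(15, 0, 1), (3, 1, 1)]))  # 960
```

The five allowed paths contribute $$450 + 60 + 90 + 300 + 60 = 960$$ weights.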

Not every possible output irrep needs to be included in irreps_out of an e3nn.o3.FullyConnectedTensorProduct: o3.FullyConnectedTensorProduct(irreps_in1='5x1o', irreps_in2='3x1o', irreps_out='20x0e') will only compute inner products between its inputs, since neither 1e (the output irrep of a vector cross product) nor 2e is present in irreps_out. Note also that this example has 20 output scalars, even though the given inputs can produce only 15 unique scalars; this is again allowed because each output is a learned linear combination of those 15 scalars, placing no restrictions on how many or how few outputs can be requested.

tp = o3.ElementwiseTensorProduct(
    irreps_in1='5x0e + 5x1e',
    irreps_in2='4x0e + 6x1e'
)
print(tp)
tp.visualize();
ElementwiseTensorProduct(5x0e+5x1e x 4x0e+6x1e -> 4x0e+1x1e+5x0e+5x1e+5x2e | 20 paths | 0 weights)

In the elementwise tensor product, the irreps are multiplied one by one. Note in the visualization how the inputs have been split and that the multiplicities of the outputs match the multiplicities of the inputs.
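
The splitting can be sketched as walking through both inputs copy by copy: the multiplicities are cut into the longest runs on which both inputs have a single irrep type, and each aligned pair of irreps is then multiplied. In the plain-Python sketch below (hypothetical names; irreps as (mul, l, p) triples), both inputs must have the same total multiplicity.

```python
# Sketch only: aligns two irreps lists copy by copy, then multiplies aligned runs.
from typing import List, Tuple

Term = Tuple[int, int, int]  # (mul, l, p)


def align(in1: List[Term], in2: List[Term]) -> List[Tuple[int, Tuple[int, int], Tuple[int, int]]]:
    """Cut both inputs into runs (mul, ir1, ir2) so that copy u of in1 meets copy u of in2."""
    assert sum(t[0] for t in in1) == sum(t[0] for t in in2), "total multiplicities must match"
    a, b = [list(t) for t in in1], [list(t) for t in in2]
    i = j = 0
    runs = []
    while i < len(a) and j < len(b):
        mul = min(a[i][0], b[j][0])
        if mul > 0:
            runs.append((mul, (a[i][1], a[i][2]), (b[j][1], b[j][2])))
        a[i][0] -= mul
        b[j][0] -= mul
        if a[i][0] == 0:
            i += 1
        if b[j][0] == 0:
            j += 1
    return runs


def elementwise_output(in1: List[Term], in2: List[Term]) -> Tuple[str, int]:
    parts, paths = [], 0
    for mul, (l1, p1), (l2, p2) in align(in1, in2):
        for l in range(abs(l1 - l2), l1 + l2 + 1):
            parts.append(f"{mul}x{l}{'e' if p1 * p2 == 1 else 'o'}")
            paths += mul
    return "+".join(parts), paths


print(elementwise_output([(5, 0, 1), (5, 1, 1)], [(4, 0, 1), (6, 1, 1)]))
# ('4x0e+1x1e+5x0e+5x1e+5x2e', 20)
```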

tp = o3.TensorSquare("5x1e + 2e")
print(tp)
tp.visualize();
TensorSquare(5x1e+1x2e -> 16x0e+15x1e+21x2e+5x3e+1x4e | 58 paths | 0 weights)

The tensor square operation only computes the non-zero entries of a tensor times itself. It also applies different normalization rules, taking into account that a tensor times itself is statistically different from the product of two independent tensors.
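
The multiplicities in this summary line can be recovered by counting. Within a block $$m \times l$$, the $$m(m-1)/2$$ pairs of distinct copies $$u < v$$ each contribute every $$l'$$ with $$0 \leq l' \leq 2l$$, while the $$m$$ squares $$x_u \otimes x_u$$ are symmetric and contribute only even $$l'$$ (for integer $$l$$, the antisymmetric part of $$l \otimes l$$ holds the odd $$l'$$). Two different blocks contribute $$m_1 m_2$$ copies of each allowed $$l'$$, with parity $$p_1 p_2$$. A plain-Python sketch of this counting (hypothetical names):

```python
# Sketch only: counts the irreps of the symmetric square of a direct sum of irreps.
from collections import Counter
from typing import List, Tuple

Term = Tuple[int, int, int]  # (mul, l, p)


def tensor_square_irreps(irreps: List[Term]) -> str:
    out: Counter = Counter()
    for i, (m1, l1, p1) in enumerate(irreps):
        for j, (m2, l2, p2) in enumerate(irreps):
            if j < i:
                continue  # each unordered pair of blocks is counted once
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                if i < j:
                    n = m1 * m2  # two different blocks
                else:
                    n = m1 * (m1 - 1) // 2  # distinct copies u < v
                    if l % 2 == 0:
                        n += m1  # x_u (x) x_u: symmetric part only
                out[(l, p1 * p2)] += n
    return "+".join(f"{n}x{l}{'e' if p == 1 else 'o'}"
                    for (l, p), n in sorted(out.items()) if n > 0)


print(tensor_square_irreps([(5, 1, 1), (1, 2, 1)]))  # 16x0e+15x1e+21x2e+5x3e+1x4e
```

As a check, the total dimension $$16 \cdot 1 + 15 \cdot 3 + 21 \cdot 5 + 5 \cdot 7 + 1 \cdot 9 = 210$$ equals $$20 \cdot 21 / 2$$, the number of independent entries of a symmetric $$20 \times 20$$ tensor.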

class e3nn.o3.TensorProduct(irreps_in1: e3nn.o3._irreps.Irreps, irreps_in2: e3nn.o3._irreps.Irreps, irreps_out: e3nn.o3._irreps.Irreps, instructions: List[tuple], in1_var: Optional[Union[List[float], torch.Tensor]] = None, in2_var: Optional[Union[List[float], torch.Tensor]] = None, out_var: Optional[Union[List[float], torch.Tensor]] = None, irrep_normalization: Optional[str] = None, path_normalization: Optional[str] = None, internal_weights: Optional[bool] = None, shared_weights: Optional[bool] = None, compile_left_right: bool = True, compile_right: bool = False, normalization=None, _specialized_code: Optional[bool] = None, _optimize_einsums: Optional[bool] = None)[source]

Bases: e3nn.util.codegen._mixin.CodeGenMixin, torch.nn.modules.module.Module

Tensor product with parametrized paths.

Parameters
• irreps_in1 (e3nn.o3.Irreps) – Irreps for the first input.

• irreps_in2 (e3nn.o3.Irreps) – Irreps for the second input.

• irreps_out (e3nn.o3.Irreps) – Irreps for the output.

• instructions (list of tuple) –

List of instructions (i_1, i_2, i_out, mode, train[, path_weight]).

Each instruction puts in1[i_1] $$\otimes$$ in2[i_2] into out[i_out].

• mode: str. Determines how the multiplicities are treated. Valid options are 'uvw' (fully connected), 'uvu', 'uvv', 'uuw', 'uuu', and 'uvuv'.

• train: bool. True if this path should have learnable weights, otherwise False.

• path_weight: float. A fixed multiplicative weight to apply to the output of this path. Defaults to 1. Note that setting path_weight breaks the normalization derived from in1_var/in2_var/out_var.

• in1_var (list of float, Tensor, or None) – Variance for each irrep in irreps_in1. If None, all default to 1.0.

• in2_var (list of float, Tensor, or None) – Variance for each irrep in irreps_in2. If None, all default to 1.0.

• out_var (list of float, Tensor, or None) – Variance for each irrep in irreps_out. If None, all default to 1.0.

• irrep_normalization ({'component', 'norm'}) –

The assumed normalization of the input and output representations. If it is set to “norm”:

$\| x \| = \| y \| = 1 \Longrightarrow \| x \otimes y \| = 1$

• path_normalization ({'element', 'path'}) – If set to element, each output is normalized by the total number of elements (independently of their paths). If it is set to path, each path is normalized by the total number of elements in the path, then each output is normalized by the number of paths.

• internal_weights (bool) – whether the e3nn.o3.TensorProduct contains its learnable weights as a parameter

• shared_weights (bool) –

whether the learnable weights are shared among the input’s extra dimensions

• True $$z_i = w x_i \otimes y_i$$

• False $$z_i = w_i x_i \otimes y_i$$

where $$i$$ denotes a batch-like index. shared_weights cannot be False if internal_weights is True.

• compile_left_right (bool) – whether to compile the forward function, true by default

• compile_right (bool) – whether to compile the .right function, false by default

Examples

Create a module that computes the elementwise cross product of 16 vectors with 16 vectors $$z_u = x_u \wedge y_u$$

>>> module = TensorProduct(
...     "16x1o", "16x1o", "16x1e",
...     [
...         (0, 0, 0, "uuu", False)
...     ]
... )

Now mix all 16 vectors with all 16 vectors to make 16 pseudo-vectors $$z_w = \sum_{u,v} w_{uvw} x_u \wedge y_v$$

>>> module = TensorProduct(
...     [(16, (1, -1))],
...     [(16, (1, -1))],
...     [(16, (1,  1))],
...     [
...         (0, 0, 0, "uvw", True)
...     ]
... )

With custom input variance and custom path weights:

>>> module = TensorProduct(
...     "8x0o + 8x1o",
...     "16x1o",
...     "16x1e",
...     [
...         (0, 0, 0, "uvw", True, 3),
...         (1, 0, 0, "uvw", True, 1),
...     ],
...     in2_var=[1/16]
... )

Example of a dot product:

>>> irreps = o3.Irreps("3x0e + 4x0o + 1e + 2o + 3o")
>>> module = TensorProduct(irreps, irreps, "0e", [
...     (i, i, 0, 'uuw', False)
...     for i, (mul, ir) in enumerate(irreps)
... ])

Implement $$z_u = x_u \otimes (\sum_v w_{uv} y_v)$$

>>> module = TensorProduct(
...     "8x0o + 7x1o + 3x2e",
...     "10x0e + 10x1e + 10x2e",
...     "8x0o + 7x1o + 3x2e",
...     [
...         # paths for the l=0:
...         (0, 0, 0, "uvu", True),  # 0x0->0
...         # paths for the l=1:
...         (1, 0, 1, "uvu", True),  # 1x0->1
...         (1, 1, 1, "uvu", True),  # 1x1->1
...         (1, 2, 1, "uvu", True),  # 1x2->1
...         # paths for the l=2:
...         (2, 0, 2, "uvu", True),  # 2x0->2
...         (2, 1, 2, "uvu", True),  # 2x1->2
...         (2, 2, 2, "uvu", True),  # 2x2->2
...     ]
... )

Tensor Product using the xavier uniform initialization:

>>> irreps_1 = o3.Irreps("5x0e + 10x1o + 1x2e")
>>> irreps_2 = o3.Irreps("5x0e + 10x1o + 1x2e")
>>> irreps_out = o3.Irreps("5x0e + 10x1o + 1x2e")
>>> # create a Fully Connected Tensor Product
>>> module = o3.TensorProduct(
...     irreps_1,
...     irreps_2,
...     irreps_out,
...     [
...         (i_1, i_2, i_out, "uvw", True, mul_1 * mul_2)
...         for i_1, (mul_1, ir_1) in enumerate(irreps_1)
...         for i_2, (mul_2, ir_2) in enumerate(irreps_2)
...         for i_out, (mul_out, ir_out) in enumerate(irreps_out)
...         if ir_out in ir_1 * ir_2
...     ]
... )
>>> with torch.no_grad():
...     for weight in module.weight_views():
...         mul_1, mul_2, mul_out = weight.shape
...         # formula from torch.nn.init.xavier_uniform_
...         a = (6 / (mul_1 * mul_2 + mul_out))**0.5
...         new_weight = torch.empty_like(weight)
...         new_weight.uniform_(-a, a)
...         weight[:] = new_weight
tensor(...)
>>> n = 1_000
>>> vars = module(irreps_1.randn(n, -1), irreps_2.randn(n, -1)).var(0)
>>> assert vars.min() > 1 / 3
>>> assert vars.max() < 3

Methods:

• forward(x, y[, weight]) – Evaluate $$w x \otimes y$$.

• right(y[, weight]) – Partially evaluate $$w x \otimes y$$.

• visualize([weight, plot_weight, ...]) – Visualize the connectivity of this e3nn.o3.TensorProduct

• weight_view_for_instruction(instruction[, ...]) – View of weights corresponding to instruction.

• weight_views([weight, yield_instruction]) – Iterator over weight views for each weighted instruction.

forward(x, y, weight: Optional[torch.Tensor] = None)[source]

Evaluate $$w x \otimes y$$.

Parameters
• x (torch.Tensor) – tensor of shape (..., irreps_in1.dim)

• y (torch.Tensor) – tensor of shape (..., irreps_in2.dim)

• weight (torch.Tensor or list of torch.Tensor, optional) – required if internal_weights is False. Either a tensor of shape (self.weight_numel,) if shared_weights is True or of shape (..., self.weight_numel) if shared_weights is False, or a list of tensors of shapes weight_shape (shared) or (...) + weight_shape (not shared). Use self.instructions to see which weights are used for which path.

Returns

tensor of shape (..., irreps_out.dim)

Return type

torch.Tensor

right(y, weight: Optional[torch.Tensor] = None)[source]

Partially evaluate $$w x \otimes y$$.

It returns an operator in the form of a tensor that can act on an arbitrary $$x$$.

For example, if the tensor product above is expressed as

$w_{ijk} x_i y_j \rightarrow z_k$

then the right method returns a tensor $$b_{ik}$$ such that

$w_{ijk} y_j \rightarrow b_{ik}$

$x_i b_{ik} \rightarrow z_k$

The result of this method can be applied with a tensor contraction:

torch.einsum("...ik,...i->...k", right, input)

Parameters
• y (torch.Tensor) – tensor of shape (..., irreps_in2.dim)

• weight (torch.Tensor or list of torch.Tensor, optional) – required if internal_weights is False. Either a tensor of shape (self.weight_numel,) if shared_weights is True or of shape (..., self.weight_numel) if shared_weights is False, or a list of tensors of shapes weight_shape (shared) or (...) + weight_shape (not shared). Use self.instructions to see which weights are used for which path.

Returns

tensor of shape (..., irreps_in1.dim, irreps_out.dim)

Return type

torch.Tensor
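
The identity behind right can be checked with plain lists: contracting $$w_{ijk}$$ with $$y$$ first gives $$b_{ik}$$, and contracting $$b$$ with $$x$$ afterwards gives the same $$z_k$$ as the full product. The sketch below uses random dense weights of hypothetical sizes and no irreps structure; it only illustrates the order of contractions, not how e3nn builds $$w$$.

```python
# Sketch only: dense weights without irreps structure, to show the contraction order.
import random
from typing import List

Tensor3 = List[List[List[float]]]


def full(w: Tensor3, x: List[float], y: List[float]) -> List[float]:
    """z_k = sum_ij w_ijk x_i y_j"""
    I, J, K = len(w), len(w[0]), len(w[0][0])
    return [sum(w[i][j][k] * x[i] * y[j] for i in range(I) for j in range(J)) for k in range(K)]


def right(w: Tensor3, y: List[float]) -> List[List[float]]:
    """b_ik = sum_j w_ijk y_j: an operator that still has to act on x."""
    I, J, K = len(w), len(w[0]), len(w[0][0])
    return [[sum(w[i][j][k] * y[j] for j in range(J)) for k in range(K)] for i in range(I)]


def apply_right(b: List[List[float]], x: List[float]) -> List[float]:
    """z_k = sum_i x_i b_ik, the einsum '...ik,...i->...k' without batch dimensions."""
    return [sum(x[i] * b[i][k] for i in range(len(x))) for k in range(len(b[0]))]


rng = random.Random(0)
I, J, K = 3, 4, 2
w = [[[rng.uniform(-1, 1) for _ in range(K)] for _ in range(J)] for _ in range(I)]
x = [rng.uniform(-1, 1) for _ in range(I)]
y = [rng.uniform(-1, 1) for _ in range(J)]
b = right(w, y)  # can be reused for many different x
print(max(abs(u - v) for u, v in zip(full(w, x, y), apply_right(b, x))) < 1e-12)  # True
```

Computing $$b$$ once pays off when the same $$y$$ is paired with many different $$x$$.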

visualize(weight: Optional[torch.Tensor] = None, plot_weight: bool = True, aspect_ratio=1, ax=None)[source]

Visualize the connectivity of this e3nn.o3.TensorProduct

Parameters
• weight (torch.Tensor, optional) – like weight argument to forward()

• plot_weight (bool, default True) – Whether to color paths by the sum of their weights.

• ax (matplotlib.Axes, default None) – The axes to plot on. If None, a new figure will be created.

Returns

The figure and axes on which the plot was drawn.

Return type

(fig, ax)

weight_view_for_instruction(instruction: int, weight: Optional[torch.Tensor] = None) [source]

View of weights corresponding to instruction.

Parameters
• instruction (int) – The index of the instruction to get a view on the weights for. self.instructions[instruction].has_weight must be True.

• weight (torch.Tensor, optional) – like weight argument to forward()

Returns

A view on weight or this object’s internal weights for the weights corresponding to the instruction-th instruction.

Return type

torch.Tensor

weight_views(weight: Optional[torch.Tensor] = None, yield_instruction: bool = False)[source]

Iterator over weight views for each weighted instruction.

Parameters
• weight (torch.Tensor, optional) – like weight argument to forward()

• yield_instruction (bool, default False) – Whether to also yield the corresponding instruction.

Yields
• If yield_instruction is True, yields (instruction_index, instruction, weight_view).

• Otherwise, yields weight_view.

class e3nn.o3.FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out, irrep_normalization: Optional[str] = None, path_normalization: Optional[str] = None, **kwargs)[source]

Fully-connected weighted tensor product

All possible paths allowed by $$|l_1 - l_2| \leq l_{out} \leq l_1 + l_2$$ are created. The output is a sum over different paths:

$z_w = \sum_{u,v} w_{uvw} x_u \otimes y_v + \cdots \text{other paths}$

where $$u,v,w$$ are the indices of the multiplicities.

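Parameters
• irreps_in1 (e3nn.o3.Irreps) – representation of the first input

• irreps_in2 (e3nn.o3.Irreps) – representation of the second input

• irreps_out (e3nn.o3.Irreps) – representation of the output

• irrep_normalization ({'component', 'norm'}) – see e3nn.o3.TensorProduct

• path_normalization ({'element', 'path'}) – see e3nn.o3.TensorProduct

• **kwargs – passed on to e3nn.o3.TensorProduct (for example internal_weights and shared_weights)
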
class e3nn.o3.FullTensorProduct(irreps_in1: e3nn.o3._irreps.Irreps, irreps_in2: e3nn.o3._irreps.Irreps, filter_ir_out: Optional[Iterator[e3nn.o3._irreps.Irrep]] = None, irrep_normalization: Optional[str] = None, **kwargs)[source]

Full tensor product between two irreps.

$z_{uv} = x_u \otimes y_v$

where $$u$$ and $$v$$ run over the irreps. Note that there are no weights. The output representation is determined by the two input representations.

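Parameters
• irreps_in1 (e3nn.o3.Irreps) – representation of the first input

• irreps_in2 (e3nn.o3.Irreps) – representation of the second input

• filter_ir_out (iterator of e3nn.o3.Irrep, optional) – if given, only these output irreps are computed

• irrep_normalization ({'component', 'norm'}) – see e3nn.o3.TensorProduct
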
class e3nn.o3.ElementwiseTensorProduct(irreps_in1, irreps_in2, filter_ir_out=None, irrep_normalization: Optional[str] = None, **kwargs)[source]

Elementwise connected tensor product.

$z_u = x_u \otimes y_u$

where $$u$$ runs over the irreps. Note that there are no weights. The output representation is determined by the two input representations.

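Parameters
• irreps_in1 (e3nn.o3.Irreps) – representation of the first input

• irreps_in2 (e3nn.o3.Irreps) – representation of the second input

• filter_ir_out (iterator of e3nn.o3.Irrep, optional) – if given, only these output irreps are computed

• irrep_normalization ({'component', 'norm'}) – see e3nn.o3.TensorProduct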

Examples

Elementwise scalar product

>>> ElementwiseTensorProduct("5x1o + 5x1e", "10x1e", ["0e", "0o"])
ElementwiseTensorProduct(5x1o+5x1e x 10x1e -> 5x0o+5x0e | 10 paths | 0 weights)

class e3nn.o3.TensorSquare(irreps_in: e3nn.o3._irreps.Irreps, irreps_out: Optional[e3nn.o3._irreps.Irreps] = None, filter_ir_out: Optional[Iterator[e3nn.o3._irreps.Irrep]] = None, irrep_normalization: Optional[str] = None, **kwargs)[source]

Compute the square tensor product of a tensor and decompose it into irreps

This module contains no parameters. The output representation is determined by the input representation.

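Parameters
• irreps_in (e3nn.o3.Irreps) – representation of the input

• irreps_out (e3nn.o3.Irreps, optional) – representation of the output

• filter_ir_out (iterator of e3nn.o3.Irrep, optional) – if given, only these output irreps are computed

• irrep_normalization ({'component', 'norm'}) – see e3nn.o3.TensorProduct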

Methods:

• forward(x[, weight]) – Evaluate $$w x \otimes x$$.

forward(x, weight: Optional[torch.Tensor] = None)[source]

Evaluate $$w x \otimes x$$.

Parameters
• x (torch.Tensor) – tensor of shape (..., irreps_in1.dim)

• weight (torch.Tensor or list of torch.Tensor, optional) – required if internal_weights is False. Either a tensor of shape (self.weight_numel,) if shared_weights is True or of shape (..., self.weight_numel) if shared_weights is False, or a list of tensors of shapes weight_shape (shared) or (...) + weight_shape (not shared). Use self.instructions to see which weights are used for which path.

Returns

tensor of shape (..., irreps_out.dim)

Return type

torch.Tensor

Spherical Harmonics

The spherical harmonics $$Y^l(x)$$ are functions defined on the sphere $$S^2$$. They form a basis of the space of functions on the sphere:

$\mathcal{F} = \{ S^2 \longrightarrow \mathbb{R} \}$

On this space, the group $$O(3)$$ acts naturally. Given two scalar representations $$p_a, p_v$$:

$[L(g) f](x) = p_v(g) f(p_a(g) R(g)^{-1} x), \quad \forall f \in \mathcal{F}, x \in S^2$

$$L$$ is a representation of $$O(3)$$, but it is not irreducible. It can be decomposed via a change of basis into a sum of irreps,

$Y^T L(g) Y = 0 \oplus 1 \oplus 2 \oplus 3 \oplus \dots$

where the change of basis is given by the spherical harmonics!

As a consequence, the spherical harmonics are equivariant,

$Y^l(R(g) x) = D^l(g) Y^l(x)$
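
Equivariance can be seen concretely for $$l = 2$$ in a Cartesian basis instead of the spherical-harmonic basis: the entries of the traceless symmetric matrix $$M(x) = x x^T - \frac{|x|^2}{3} I$$ are harmonic quadratic polynomials spanning the same 5-dimensional space as $$Y^2$$, and they transform as $$M(R x) = R M(x) R^T$$. The plain-Python sketch below checks this identity numerically for one rotation; it works in a different basis and normalization than e3nn, and its helper names are hypothetical.

```python
# Sketch only: Cartesian stand-in for Y^2, not e3nn's basis or normalization.
import math
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def transpose(A: Matrix) -> Matrix:
    return [[A[j][i] for j in range(3)] for i in range(3)]


def l2_tensor(x: List[float]) -> Matrix:
    """Traceless symmetric part of x x^T."""
    r2 = sum(c * c for c in x)
    return [[x[i] * x[j] - (r2 / 3.0 if i == j else 0.0) for j in range(3)] for i in range(3)]


def rotation(axis: List[float], angle: float) -> Matrix:
    """Rodrigues' formula for a rotation around a (normalized) axis."""
    n = math.sqrt(sum(a * a for a in axis))
    u = [a / n for a in axis]
    c, s = math.cos(angle), math.sin(angle)
    K = [[0.0, -u[2], u[1]], [u[2], 0.0, -u[0]], [-u[1], u[0], 0.0]]
    K2 = matmul(K, K)
    return [[(1.0 if i == j else 0.0) + s * K[i][j] + (1 - c) * K2[i][j]
             for j in range(3)] for i in range(3)]


R = rotation([1.0, -2.0, 0.5], 0.7)
x = [0.3, -0.4, 0.866]
Rx = [sum(R[i][k] * x[k] for k in range(3)) for i in range(3)]
lhs = l2_tensor(Rx)                                  # "Y^2(R x)" in Cartesian form
rhs = matmul(matmul(R, l2_tensor(x)), transpose(R))  # "D^2(R) Y^2(x)" in Cartesian form
print(max(abs(lhs[i][j] - rhs[i][j]) for i in range(3) for j in range(3)) < 1e-12)  # True
```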
r = s2_grid()

r is a grid on the sphere.