Equivariance Testing

The e3nn.util.test module provides tools for confirming that functions are equivariant. The main tool is equivariance_error, which computes the largest absolute difference between the function applied to transformed arguments and the transform applied to the function's output:

import e3nn.o3
from e3nn.util.test import equivariance_error

tp = e3nn.o3.FullyConnectedTensorProduct("2x0e + 3x1o", "2x0e + 3x1o", "2x1o")

equivariance_error(
    tp,
    args_in=[tp.irreps_in1.randn(1, -1), tp.irreps_in2.randn(1, -1)],
    irreps_in=[tp.irreps_in1, tp.irreps_in2],
    irreps_out=[tp.irreps_out],
)

{(0, False): tensor([8.2248e-08]), (1, False): tensor([8.3420e-08])}

The keys in the output identify the type of random transformation as (parity, did_translation), and the values are the maximum componentwise errors. For convenience, the wrapper function assert_equivariant is provided:

from e3nn.util.test import assert_equivariant

assert_equivariant(tp)

{(0, False): tensor([4.1977e-08]), (1, False): tensor([3.0510e-08])}

For typical e3nn operations, assert_equivariant can infer the input and output e3nn.o3.Irreps, generate random inputs when none are provided, and check the error against a threshold appropriate to the current torch.get_default_dtype().
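To make the measured quantity concrete, here is a minimal, library-independent sketch of the idea behind equivariance_error. It is not e3nn's implementation: it assumes a single fixed rotation about the z axis (e3nn samples random transformations), and the names toy_equivariance_error, scale_by_norm, and add_offset are hypothetical helpers introduced only for illustration.

```python
import math
import random


def rotation_z(theta):
    """3x3 rotation matrix about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def matvec(m, v):
    """Apply a 3x3 matrix to a 3-vector."""
    return [sum(m[i][j] * v[j] for j in range(3)) for i in range(3)]


def toy_equivariance_error(f, x, transform):
    """Largest componentwise |f(T x) - T f(x)| for one transform T (hypothetical helper)."""
    lhs = f(matvec(transform, x))   # function applied to the transformed input
    rhs = matvec(transform, f(x))   # transform applied to the function's output
    return max(abs(a - b) for a, b in zip(lhs, rhs))


def scale_by_norm(v):
    # Equivariant: rotations preserve the norm, so f(R v) = |v| R v = R f(v).
    n = math.sqrt(sum(c * c for c in v))
    return [n * c for c in v]


def add_offset(v):
    # Not equivariant: the fixed offset along x singles out a direction.
    return [v[0] + 1.0, v[1], v[2]]


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(3)]
R = rotation_z(0.7)

err_good = toy_equivariance_error(scale_by_norm, x, R)
err_bad = toy_equivariance_error(add_offset, x, R)
print(err_good, err_bad)
```

In this sketch err_good sits at floating-point rounding level, while err_bad is about 0.64 (the sine of the rotation angle), which is the kind of gap a dtype-dependent threshold is meant to separate.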

In addition to e3nn.o3.Irreps-like objects, irreps_in can also contain two special values:

  • 'cartesian_points': (N, 3) tensors containing XYZ points in real space that are equivariant under rotations and translations

  • None: any input or output that is invariant and should be left alone

These can be used to test models that operate on full graphs that include position information:

import torch
from torch_geometric.data import Data
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
from e3nn.util.test import assert_equivariant

# kwargs = ...
f = SimpleNetwork(**kwargs)

def wrapper(pos, x):
    data = dict(pos=pos, x=x)
    return f(data)

assert_equivariant(
    wrapper,
    irreps_in=['cartesian_points', f.irreps_in],
    irreps_out=[f.irreps_out],
)

{(0, False): tensor([5.7631e-08]),
 (0, True): tensor([1.2084e-07]),
 (1, False): tensor([7.2862e-08]),
 (1, True): tensor([6.9801e-08])}

To test equivariance on a specific graph, args_in can be used:

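# my_pos: (N, 3) positions; my_x: node features matching f.irreps_in
assert_equivariant(
    wrapper,
    irreps_in=['cartesian_points', f.irreps_in],
    args_in=[my_pos, my_x],
    irreps_out=[f.irreps_out],
)

{(0, False): tensor([4.0036e-07]),
 (0, True): tensor([5.0912e-07]),
 (1, False): tensor([5.3644e-07]),
 (1, True): tensor([1.6689e-06])}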
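The entries with did_translation set to True appear only when positions are involved. As a rough illustration of what a global translation of 'cartesian_points' checks, the plain-Python sketch below (not using e3nn; translate, relative_vectors, and centroid are hypothetical helpers) compares a translation-invariant feature with one that moves along with the points.

```python
import random


def translate(points, t):
    """Shift every point by the same vector t."""
    return [[p[i] + t[i] for i in range(3)] for p in points]


def relative_vectors(points):
    # Displacements from the first point: unchanged by a global translation.
    return [[p[i] - points[0][i] for i in range(3)] for p in points[1:]]


def centroid(points):
    # Mean position: shifts by t when the points shift by t.
    n = len(points)
    return [sum(p[i] for p in points) / n for i in range(3)]


random.seed(0)
points = [[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
t = [0.5, -1.0, 2.0]
shifted = translate(points, t)

# Error when treating relative vectors as translation-invariant outputs.
err_relative = max(
    abs(a - b)
    for u, v in zip(relative_vectors(points), relative_vectors(shifted))
    for a, b in zip(u, v)
)

# Error when (wrongly) treating the centroid as translation-invariant.
err_centroid = max(abs(a - b) for a, b in zip(centroid(points), centroid(shifted)))
print(err_relative, err_centroid)
```

A model whose outputs depend only on relative positions behaves like relative_vectors here, so its error under translation stays at rounding level; an output that depends on absolute position would show an error comparable to the size of the shift.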


assert_equivariant also logs the equivariance error to the e3nn.util.test logger with level INFO regardless of whether the test fails. When running in pytest, these logs can be seen using the “Live Logs” feature:

pytest tests/ --log-cli-level info
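Outside pytest, the same messages can be surfaced with the standard logging module. This is a minimal sketch that assumes the logger name is exactly e3nn.util.test, as stated above; the handler and format string are illustrative choices, not part of e3nn.

```python
import logging

# Logger name taken from the text above; assumed to match the module path.
logger = logging.getLogger("e3nn.util.test")
logger.setLevel(logging.INFO)

# Send records to stderr with a simple, illustrative format.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))
logger.addHandler(handler)
```

With this in place, any INFO-level record emitted on that logger is printed to the console.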