Convolution

In this document we will implement an equivariant convolution with e3nn, following this formula:

\[f'_i = \frac{1}{\sqrt{z}} \sum_{j \in \partial(i)} \; f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|)\]

where

  • \(f_j, f'_i\) are the input and output node features

  • \(z\) is the average degree of the nodes

  • \(\partial(i)\) is the set of neighbors of the node \(i\)

  • \(x_{ij}\) is the relative vector

  • \(h\) is a multi-layer perceptron

  • \(Y\) denotes the spherical harmonics

  • \(x \; \otimes\!(w) \; y\) is a tensor product of \(x\) with \(y\) parametrized by some weights \(w\)

Boilerplate imports

import torch
from torch_cluster import radius_graph
from torch_scatter import scatter
from e3nn import o3, nn
from e3nn.math import soft_one_hot_linspace
import matplotlib.pyplot as plt

Let’s first define the irreps of the input and output features.

irreps_input = o3.Irreps("10x0e + 10x1e")
irreps_output = o3.Irreps("20x0e + 10x1e")
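
As a quick aside (our addition, not part of the original recipe), each 0e is a scalar and each 1e a 3-component vector, so we can check the total dimensions these irreps carry:

print(irreps_input.dim)   # 10*1 + 10*3 = 40
print(irreps_output.dim)  # 20*1 + 10*3 = 50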

And create a random graph from random positions, with edges between nodes whose relative distance is smaller than max_radius.

# create node positions
num_nodes = 100
pos = torch.randn(num_nodes, 3)  # random node positions

# create edges
max_radius = 1.8
edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=num_nodes - 1)

print(edge_src.shape)

edge_vec = pos[edge_dst] - pos[edge_src]

# compute z
num_neighbors = len(edge_src) / num_nodes  # the average degree z of the nodes
num_neighbors
torch.Size([3322])
33.22

edge_src and edge_dst contain the indices of the source and destination nodes of each edge.
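As a small aside (our addition, not part of the original recipe), we can print the first few pairs to make this concrete:

for s, d in zip(edge_src[:3].tolist(), edge_dst[:3].tolist()):
    print(f"edge from node {s} to node {d}")

We can also create some random input features.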

f_in = irreps_input.randn(num_nodes, -1)

Note that our data is generated with a normal distribution. We will take care that all the data follows the component normalization (see Normalization).

f_in.pow(2).mean()  # should be close to 1
tensor(0.9820)

Let’s start with

\[Y(x_{ij} / \|x_{ij}\|)\]
irreps_sh = o3.Irreps.spherical_harmonics(lmax=2)
print(irreps_sh)

sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
# normalize=True ensures that x is divided by |x| before computing the sh

sh.pow(2).mean()  # should be close to 1
1x0e+1x1o+1x2e
tensor(1.)

Now we need to compute \(\otimes(w)\) and \(h\). Let's create the tensor product first; it will tell us how many weights it needs.

tp = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)

print(f"{tp} needs {tp.weight_numel} weights")

tp.visualize();
FullyConnectedTensorProduct(10x0e+10x1e x 1x0e+1x1o+1x2e -> 20x0e+10x1e | 400 paths | 400 weights) needs 400 weights
[figure: visualization of the tensor product paths]

In this particular choice of irreps we can see that the l=1 component of the spherical harmonics cannot be used in the tensor product. In this example it is the equivariance to inversion that prohibits the use of l=1: the inputs and outputs all have even parity, while the l=1 harmonic (1o) is odd, so any path through it would produce odd-parity irreps that do not appear in irreps_output. If we don't want the equivariance to inversion we can declare all irreps to be even (irreps_sh = Irreps("0e + 1e + 2e")); see the example below.
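
To make the parity argument concrete, here is a small variant (our addition): with all-even spherical harmonics the l=1 component is allowed, so the tensor product gains extra paths and weights. We print the count rather than assume it:

irreps_sh_even = o3.Irreps("0e + 1e + 2e")
tp_even = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh_even, irreps_output, shared_weights=False)
print(f"{tp_even} needs {tp_even.weight_numel} weights")  # more weights than the 400 above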

To implement \(h\), which maps the relative distances to the weights of the tensor product, we will embed the distances using a basis function and then feed this embedding to a neural network. Let's create that embedding. Here are the basis functions we will use:

num_basis = 10

x = torch.linspace(0.0, 2.0, 1000)
y = soft_one_hot_linspace(
    x,
    start=0.0,
    end=max_radius,
    number=num_basis,
    basis='smooth_finite',
    cutoff=True,
)

plt.plot(x, y);
[figure: plot of the smooth_finite basis functions]

Note that these functions are all smooth and strictly zero beyond max_radius. This is useful to get a convolution that varies smoothly even though the neighbor graph has a sharp cutoff at max_radius.
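
To verify the cutoff claim (a check we added), we can evaluate the basis at distances past max_radius and confirm it vanishes:

x_beyond = torch.tensor([max_radius + 0.1, 2.0 * max_radius])
y_beyond = soft_one_hot_linspace(x_beyond, start=0.0, end=max_radius, number=num_basis, basis='smooth_finite', cutoff=True)
print(y_beyond.abs().max())  # should be exactly 0 if the basis is strictly zero beyond max_radius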

Let's use this embedding for the edge distances and normalize it properly (component normalization, i.e. second moment close to 1). Each embedding component has a second moment of about 1/num_basis, so we multiply by num_basis**0.5 to bring it close to 1.

edge_length_embedding = soft_one_hot_linspace(
    edge_vec.norm(dim=1),
    start=0.0,
    end=max_radius,
    number=num_basis,
    basis='smooth_finite',
    cutoff=True,
)
edge_length_embedding = edge_length_embedding.mul(num_basis**0.5)

print(edge_length_embedding.shape)
edge_length_embedding.pow(2).mean()  # the second moment
torch.Size([3322, 10])
tensor(0.8919)

Now we can create an MLP and feed it the embedding:

fc = nn.FullyConnectedNet([num_basis, 16, tp.weight_numel], torch.relu)
weight = fc(edge_length_embedding)

print(weight.shape)
print(len(edge_src), tp.weight_numel)

# For proper normalization, the weights also need to have mean 0
print(weight.mean(), weight.std())  # should be close to 0 and 1
torch.Size([3322, 400])
3322 400
tensor(0.0040, grad_fn=<MeanBackward0>) tensor(1.1217, grad_fn=<StdBackward0>)

Now we can compute the term

\[f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|)\]

The idea is to compute this quantity per edge, so we need to “lift” the input features to the edges. For that we use edge_src, which contains, for each edge, the index of the source node.

summand = tp(f_in[edge_src], sh, weight)

print(summand.shape)
print(summand.pow(2).mean())  # should be close to 1
torch.Size([3322, 50])
tensor(1.1945, grad_fn=<MeanBackward0>)

Only the sum over the neighbors remains:

\[f'_i = \frac{1}{\sqrt{z}} \sum_{j \in \partial(i)} \; f_j \; \otimes\!(h(\|x_{ij}\|)) \; Y(x_{ij} / \|x_{ij}\|)\]
f_out = scatter(summand, edge_dst, dim=0, dim_size=num_nodes)

f_out = f_out.div(num_neighbors**0.5)

f_out.pow(2).mean()  # should be close to 1
tensor(1.2428, grad_fn=<MeanBackward0>)
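
If scatter is unfamiliar, here is a tiny standalone example (made-up values, our addition) showing that it sums rows into their destination index, which is exactly the sum over neighbors above:

src = torch.tensor([[1.0], [2.0], [3.0]])
index = torch.tensor([0, 0, 1])
print(scatter(src, index, dim=0, dim_size=2))  # tensor([[3.], [3.]]): rows 0 and 1 are summed into node 0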

Now we can put everything into a function

def conv(f_in, pos):
    edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=len(pos) - 1)
    edge_vec = pos[edge_dst] - pos[edge_src]
    sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
    emb = soft_one_hot_linspace(edge_vec.norm(dim=1), 0.0, max_radius, num_basis, basis='smooth_finite', cutoff=True).mul(num_basis**0.5)
    return scatter(tp(f_in[edge_src], sh, fc(emb)), edge_dst, dim=0, dim_size=len(pos)).div(num_neighbors**0.5)

Now we can check the equivariance

rot = o3.rand_matrix()
D_in = irreps_input.D_from_matrix(rot)
D_out = irreps_output.D_from_matrix(rot)

# rotate before
f_before = conv(f_in @ D_in.T, pos @ rot.T)

# rotate after
f_after = conv(f_in, pos) @ D_out.T

torch.allclose(f_before, f_after, rtol=1e-4, atol=1e-4)
True
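
For reuse in a larger model, these steps could also be wrapped in a torch.nn.Module. The sketch below is one possible packaging (the class name and constructor arguments are our own choices, not an e3nn API); note that it creates fresh tp and fc parameters instead of reusing those defined above, so its outputs will differ from conv:

class Convolution(torch.nn.Module):
    def __init__(self, irreps_input, irreps_output, irreps_sh, max_radius, num_basis, num_neighbors):
        super().__init__()
        self.max_radius = max_radius
        self.num_basis = num_basis
        self.num_neighbors = num_neighbors  # the normalization constant z
        self.irreps_sh = irreps_sh
        self.tp = o3.FullyConnectedTensorProduct(irreps_input, irreps_sh, irreps_output, shared_weights=False)
        self.fc = nn.FullyConnectedNet([num_basis, 16, self.tp.weight_numel], torch.relu)

    def forward(self, f_in, pos):
        edge_src, edge_dst = radius_graph(pos, self.max_radius, max_num_neighbors=len(pos) - 1)
        edge_vec = pos[edge_dst] - pos[edge_src]
        sh = o3.spherical_harmonics(self.irreps_sh, edge_vec, normalize=True, normalization='component')
        emb = soft_one_hot_linspace(edge_vec.norm(dim=1), 0.0, self.max_radius, self.num_basis, basis='smooth_finite', cutoff=True).mul(self.num_basis**0.5)
        return scatter(self.tp(f_in[edge_src], sh, self.fc(emb)), edge_dst, dim=0, dim_size=len(pos)).div(self.num_neighbors**0.5)

conv_module = Convolution(irreps_input, irreps_output, irreps_sh, max_radius, num_basis, num_neighbors)
print(conv_module(f_in, pos).shape)  # torch.Size([100, 50])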

The tensor product dominates the execution time:

import time
wall = time.perf_counter()

edge_src, edge_dst = radius_graph(pos, max_radius, max_num_neighbors=len(pos) - 1)
edge_vec = pos[edge_dst] - pos[edge_src]
print(time.perf_counter() - wall); wall = time.perf_counter()

sh = o3.spherical_harmonics(irreps_sh, edge_vec, normalize=True, normalization='component')
print(time.perf_counter() - wall); wall = time.perf_counter()

emb = soft_one_hot_linspace(edge_vec.norm(dim=1), 0.0, max_radius, num_basis, basis='smooth_finite', cutoff=True).mul(num_basis**0.5)
print(time.perf_counter() - wall); wall = time.perf_counter()

weight = fc(emb)
print(time.perf_counter() - wall); wall = time.perf_counter()

summand = tp(f_in[edge_src], sh, weight)
print(time.perf_counter() - wall); wall = time.perf_counter()

scatter(summand, edge_dst, dim=0, dim_size=num_nodes).div(num_neighbors**0.5)
print(time.perf_counter() - wall); wall = time.perf_counter()
0.0014800009994360153
0.0013493509995896602
0.002589572000943008
0.0011077569997723913
0.007198666999101988
0.00042959100028383546