Models of March 2021

Simple Network

Let’s create a simple network and evaluate it on random data.

import torch
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork

net = SimpleNetwork(
    irreps_in="3x0e + 2x1o",
    irreps_out="1x1o",
    max_radius=2.0,
    num_neighbors=3.0,
    num_nodes=5.0
)

pos = torch.randn(5, 3)
x = net.irreps_in.randn(5, -1)

net({
    'pos': pos,
    'x': x
})
tensor([[ 0.3497, -1.0514, -0.0380]], grad_fn=<DivBackward0>)

If we rotate the inputs,

from e3nn import o3

rot = o3.matrix_x(torch.tensor(3.14 / 3.0))
rot
tensor([[ 1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.5005, -0.8658],
        [ 0.0000,  0.8658,  0.5005]])
net({
    'pos': pos @ rot.T,
    'x': x @ net.irreps_in.D_from_matrix(rot).T
})
tensor([[ 0.3497, -0.4933, -0.9293]], grad_fn=<DivBackward0>)

it gives the same result as rotating the outputs.

net({
    'pos': pos,
    'x': x
}) @ net.irreps_out.D_from_matrix(rot).T
tensor([[ 0.3497, -0.4933, -0.9293]], grad_fn=<MmBackward0>)

Network for a graph with node/edge attributes

graph of the data

A graph is made of nodes and edges. The nodes and edges can have attributes. Usually their only attributes are the positions of the nodes \(\vec r_i\) and the relative positions of the edges \(\vec r_i - \vec r_j\). We typically don’t use the node positions because they change with the global translation of the graph. The nodes and edges can have other attributes like for instance atom type or bond type and so on.

The attributes defines the graph properties. They don’t change layer after layer (in this example). The data (node_input) flow through this graph layer after layer.

In the following network, the edges attributes are the spherical harmonics \(Y^l(\vec r_i - \vec r_j)\) plus the extra attributes provided by the user.

from e3nn.nn.models.v2103.gate_points_networks import NetworkForAGraphWithAttributes
from torch_cluster import radius_graph

max_radius = 3.0

net = NetworkForAGraphWithAttributes(
    irreps_node_input="0e+1e",
    irreps_node_attr="0e+1e",
    irreps_edge_attr="0e+1e",  # attributes in extra of the spherical harmonics
    irreps_node_output="0e+1e",
    max_radius=max_radius,
    num_neighbors=4.0,
    num_nodes=5.0,
)

num_nodes = 5
pos = torch.randn(num_nodes, 4)
edge_index = radius_graph(pos, max_radius)
num_edges = edge_index.shape[1]

net({
    'pos': pos,
    'edge_index': edge_index,
    'node_input': torch.randn(num_nodes, 4),
    'node_attr': torch.randn(num_nodes, 4),
    'edge_attr': torch.randn(num_edges, 4),
})
tensor([[-0.6150, -0.5805, -1.5379, -0.3471]], grad_fn=<DivBackward0>)