# Model Gate of January 2021

Multipurpose equivariant neural network for point-clouds. Made with e3nn.o3.TensorProduct for the linear part and e3nn.nn.Gate for the nonlinearities.

Convolution

The linear part, the Convolution module, is inspired by the idea of depthwise separable convolution. The main operation of the Convolution module is tp. It makes the atoms interact with their neighbors but does not mix the channels. To mix the channels, it is sandwiched between lin1 and lin2.

@compile_mode('script')
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    The node features are first mixed channel-wise by ``lin1``, then sent along
    the edges through the tensor product ``tp`` (which does not mix channels and
    whose weights are produced, edge by edge, by the radial network ``fc``),
    summed on the destination nodes, mixed again by ``lin2`` and finally
    combined with the self-connection ``sc``.

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input node features

    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    irreps_out : `e3nn.o3.Irreps` or None
        representation of the output node features

    number_of_basis : int
        number of basis on which the edge length are projected

    num_neighbors : float
        typical number of nodes convolved over

    radial_layers : int
        number of hidden layers in the radial fully connected network.
        Trailing keyword with a default so that callers passing only the six
        leading arguments keep working.

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network.
        Trailing keyword with a default, for the same reason as ``radial_layers``.
    """
    def __init__(
        self,
        irreps_in,
        irreps_node_attr,
        irreps_edge_attr,
        irreps_out,
        number_of_basis,
        num_neighbors,
        radial_layers=1,
        radial_neurons=64,
    ) -> None:
        super().__init__()
        # Imported here so this class does not rely on a module-level import
        # that may or may not be present.
        from e3nn.nn import FullyConnectedNet

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_out = o3.Irreps(irreps_out)
        self.num_neighbors = num_neighbors

        # Self-connection: maps the input straight to the output irreps.
        self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)

        # Channel mixing applied before the edge-wise tensor product.
        self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

        # Enumerate every (input irrep, edge irrep) -> output irrep path whose
        # result is requested in irreps_out. Mode 'uvu' keeps the multiplicity
        # of the input, i.e. the channels are not mixed by this product.
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_in):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, 'uvu', True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        # Sorting permuted the output slots; remap the instruction outputs.
        instructions = [
            (i_1, i_2, p[i_out], mode, train)
            for i_1, i_2, i_out, mode, train in instructions
        ]

        # The weights are neither stored in the module nor shared between
        # edges: they are passed to every call (see forward).
        tp = TensorProduct(
            self.irreps_in,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        # Radial network: embedded edge length -> the tp.weight_numel weights
        # consumed by the tensor product for that edge.
        self.fc = FullyConnectedNet(
            [number_of_basis] + radial_layers * [radial_neurons] + [tp.weight_numel],
            torch.nn.functional.silu,
        )
        self.tp = tp

        # Channel mixing applied after the aggregation over the neighbors.
        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) -> torch.Tensor:
        """Apply the convolution to the node features.

        Parameters
        ----------
        node_input : torch.Tensor
            node features, indexed by node along dimension 0
        node_attr : torch.Tensor
            node attributes
        edge_src : torch.Tensor
            index of the node each edge reads its features from
        edge_dst : torch.Tensor
            index of the node each edge is accumulated on
        edge_attr : torch.Tensor
            edge attributes, one row per edge
        edge_length_embedded : torch.Tensor
            embedding of the edge lengths, input of the radial network

        Returns
        -------
        torch.Tensor
            updated node features
        """
        weight = self.fc(edge_length_embedded)

        x = node_input

        s = self.sc(x, node_attr)
        x = self.lin1(x, node_attr)

        edge_features = self.tp(x[edge_src], edge_attr, weight)
        # Sum the messages on the destination nodes and rescale by the
        # square root of the typical number of neighbors.
        x = scatter(edge_features, edge_dst, dim=0, dim_size=x.shape[0]).div(self.num_neighbors**0.5)

        x = self.lin2(x, node_attr)

        # Mix self-connection and convolution with weights sin(pi/8) and
        # cos(pi/8). Where the output mask of the self-connection is 0 the
        # convolution term is kept with weight 1 instead.
        c_s, c_x = math.sin(math.pi / 8), math.cos(math.pi / 8)
        m = self.sc.output_mask
        c_x = (1 - m) + c_x * m
        return c_s * s + c_x * x


Network

The network is a simple succession of Convolution and e3nn.nn.Gate. The activation function is SiLU when dealing with even scalars and tanh when dealing with odd scalars. When the parities (p in e3nn.o3.Irrep) are provided, the network is equivariant to O(3). To relax this constraint and make it equivariant to SO(3) only, one can simply pass all the irreps parameters as even (p=1 in e3nn.o3.Irrep). This is why irreps_edge_attr is a parameter of the class Network: one can use specific l of the spherical harmonics with the correct parity p=(-1)^l (one can use e3nn.o3.Irreps.spherical_harmonics for that), or consider that p=1 in order not to be equivariant to parity.

class Network(torch.nn.Module):
    r"""equivariant neural network

    A succession of ``layers`` blocks, each made of a ``Convolution`` followed
    by a ``Gate`` non-linearity, closed by a last ``Convolution`` that maps to
    ``irreps_out``.

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features

    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features

    irreps_out : `e3nn.o3.Irreps`
        representation of the output features

    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the nodes attributes
        can be set to ``None`` if nodes don't have attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials

    layers : int
        number of gates (non linearities)

    number_of_basis : int
        number of basis on which the edge length are projected

    num_neighbors : float
        typical number of nodes at a distance ``max_radius``

    num_nodes : float
        typical number of nodes in a graph

    reduce_output : bool
        if true, sum the node outputs of every graph and divide by
        ``num_nodes**0.5``; otherwise return the node-wise output

    max_radius : float
        maximum radius for the convolution. Trailing keyword so that existing
        positional call sites are unaffected; it has no usable default and
        `forward` raises ``ValueError`` while it is ``None``.

    radial_layers : int or None
        number of hidden layers in the radial fully connected network.
        Forwarded to ``Convolution`` as a keyword only when not ``None``.

    radial_neurons : int or None
        number of neurons in the hidden layers of the radial fully connected network.
        Forwarded to ``Convolution`` as a keyword only when not ``None``.
    """
    def __init__(
        self,
        irreps_in,
        irreps_hidden,
        irreps_out,
        irreps_node_attr,
        irreps_edge_attr,
        layers,
        number_of_basis,
        num_neighbors,
        num_nodes,
        reduce_output=True,
        max_radius=None,
        radial_layers=None,
        radial_neurons=None,
    ) -> None:
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = (irreps_in is not None)
        self.input_has_node_attr = (irreps_node_attr is not None)

        # Without input features every node carries a single even scalar.
        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        # Only pass the radial settings that were explicitly requested, so the
        # convolution is otherwise called exactly as before.
        conv_kwargs = {}
        if radial_layers is not None:
            conv_kwargs['radial_layers'] = radial_layers
        if radial_neurons is not None:
            conv_kwargs['radial_neurons'] = radial_neurons

        # Activations are selected by parity (key 1: even, key -1: odd).
        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            # Keep only the hidden irreps that can be produced from the
            # current features and the edge attributes.
            irreps_scalars = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_hidden
                if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
            ])
            irreps_gated = o3.Irreps([
                (mul, ir)
                for mul, ir in self.irreps_hidden
                if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
            ])
            # One gate scalar per gated tensor: even scalars when reachable,
            # odd scalars otherwise.
            gate_ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, gate_ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars, [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates, [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                num_neighbors,
                **conv_kwargs
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        # Last convolution, without gate, mapping to the requested output.
        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                num_neighbors,
                **conv_kwargs
            )
        )

    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """evaluate the network

        Parameters
        ----------
        data : `torch_geometric.data.Data` or `dict`
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the graph to which the node belong, optional

        Returns
        -------
        torch.Tensor
            the node-wise output, or, when ``reduce_output`` is true, its sum
            over the nodes of each graph divided by ``num_nodes**0.5``

        Raises
        ------
        ValueError
            if the network was built without ``max_radius``
        """
        if self.max_radius is None:
            raise ValueError("Network.forward requires max_radius to be set at construction")

        # Imported here so this method does not rely on a module-level import
        # that may or may not be present.
        from torch_geometric.nn import radius_graph

        if 'batch' in data:
            batch = data['batch']
        else:
            # A single graph: every node belongs to graph 0.
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        # Connect the nodes of a same graph that are closer than max_radius.
        edge_index = radius_graph(data['pos'], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data['pos'][edge_src] - data['pos'][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization='component')
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length,
            start=0.0,
            end=self.max_radius,
            number=self.number_of_basis,
            basis='gaussian',
            cutoff=False
        ).mul(self.number_of_basis**0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and 'x' in data:
            assert self.irreps_in is not None
            x = data['x']
        else:
            assert self.irreps_in is None
            x = data['pos'].new_ones((data['pos'].shape[0], 1))

        if self.input_has_node_attr and 'z' in data:
            z = data['z']
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = data['pos'].new_ones((data['pos'].shape[0], 1))

        for lay in self.layers:
            x = lay(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

        if self.reduce_output:
            return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
        else:
            return x


model with self-interactions and gates

Exact equivariance to $$E(3)$$

Version of January 2021

Classes:

- Compose(first, second)
- Convolution(irreps_in, irreps_node_attr, ...): equivariant convolution
- Network(irreps_in, irreps_hidden, ...[, ...]): equivariant neural network

class e3nn.nn.models.gate_points_2101.Compose(first, second)[source]

Bases: torch.nn.modules.module.Module

Methods:

 forward(*input) Defines the computation performed at every call.
forward(*input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3nn.nn.models.gate_points_2101.Convolution(irreps_in, irreps_node_attr, ...)[source]

Bases: torch.nn.modules.module.Module

equivariant convolution

Parameters
• irreps_in (e3nn.o3.Irreps) – representation of the input node features

• irreps_node_attr (e3nn.o3.Irreps) – representation of the node attributes

• irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes

• irreps_out (e3nn.o3.Irreps or None) – representation of the output node features

• number_of_basis (int) – number of basis on which the edge length are projected

• radial_layers (int) – number of hidden layers in the radial fully connected network

• radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network

• num_neighbors (float) – typical number of nodes convolved over

Methods:

 forward(node_input, node_attr, edge_src, ...) Defines the computation performed at every call.
forward(node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) [source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3nn.nn.models.gate_points_2101.Network(irreps_in, irreps_hidden, ...[, ...])[source]

Bases: torch.nn.modules.module.Module

equivariant neural network

Parameters
• irreps_in (e3nn.o3.Irreps or None) – representation of the input features; can be set to None if nodes don’t have input features

• irreps_hidden (e3nn.o3.Irreps) – representation of the hidden features

• irreps_out (e3nn.o3.Irreps) – representation of the output features

• irreps_node_attr (e3nn.o3.Irreps or None) – representation of the nodes attributes; can be set to None if nodes don’t have attributes

• irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes; the edge attributes are $$h(r) Y(\vec r / r)$$, where $$h$$ is a smooth function that goes to zero at max_radius and $$Y$$ are the spherical harmonics polynomials

• layers (int) – number of gates (non linearities)

• number_of_basis (int) – number of basis on which the edge length are projected

• radial_layers (int) – number of hidden layers in the radial fully connected network

• radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network

• num_neighbors (float) – typical number of nodes at a distance max_radius

• num_nodes (float) – typical number of nodes in a graph

Methods:

 forward(data) evaluate the network
forward(data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) → torch.Tensor [source]

evaluate the network

Parameters

data (torch_geometric.data.Data or dict) – data object containing:

- pos: the position of the nodes (atoms)
- x: the input features of the nodes, optional
- z: the attributes of the nodes, for instance the atom type, optional
- batch: the graph to which the node belong, optional