Model Gate of January 2021
Multipurpose equivariant neural network for point clouds. Made with e3nn.o3.TensorProduct for the linear part and e3nn.nn.Gate for the nonlinearities.
Convolution
The linear part, module Convolution, is inspired by the depthwise separable convolution idea.
The main operation of the Convolution module is tp. It makes the atoms interact with their neighbors but does not mix the channels. To mix the channels, it is sandwiched between lin1 and lin2.
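The depthwise separable split can be sketched in plain Python, assuming scalar-only features stored as nested lists (no irreps and no e3nn). The hypothetical `channelwise_message_passing` stands in for tp: channel u of a neighbor only feeds output channel u, with one weight per edge and channel. The hypothetical `mix_channels` stands in for lin1/lin2: it applies the same linear map across channels at every node.

```python
from typing import List, Tuple

Features = List[List[float]]  # features[node][channel]


def channelwise_message_passing(
    x: Features, edges: List[Tuple[int, int]], edge_weights: List[List[float]]
) -> Features:
    """Stand-in for tp + aggregation: channel u of a neighbor only feeds channel u."""
    out = [[0.0] * len(x[0]) for _ in x]
    for (src, dst), w in zip(edges, edge_weights):
        for u in range(len(x[0])):
            out[dst][u] += w[u] * x[src][u]  # one weight per (edge, channel), no cross terms
    return out


def mix_channels(x: Features, matrix: List[List[float]]) -> Features:
    """Stand-in for lin1 / lin2: the same linear map across channels at every node."""
    return [[sum(m_vu * x_u for m_vu, x_u in zip(row_v, node)) for row_v in matrix] for node in x]


x = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]  # channel 1 is zero on every node
edges = [(0, 1), (1, 0), (2, 1)]          # (source, destination)
edge_weights = [[0.5, 0.7], [1.0, 1.0], [2.0, -1.0]]

h = channelwise_message_passing(x, edges, edge_weights)  # [[2.0, 0.0], [6.5, 0.0], [0.0, 0.0]]
y = mix_channels(h, [[1.0, 0.0], [1.0, 1.0]])             # [[2.0, 2.0], [6.5, 6.5], [0.0, 0.0]]
```

Channel 1 stays zero after the interaction step and becomes non-zero only after the mixing step. In the real module, tp is an equivariant tensor product and its per-edge weights come from a radial network, not from fixed numbers.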
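```python
import torch

def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # special case of torch_scatter.scatter with dim=0
    out = src.new_zeros(dim_size, src.shape[1])
    index = index.reshape(-1, 1).expand_as(src)
    return out.scatter_add_(0, index, src)

def radius_graph(pos, r_max, batch) -> torch.Tensor:
    # naive and inefficient version of torch_cluster.radius_graph
    r = torch.cdist(pos, pos)
    index = ((r < r_max) & (r > 0)).nonzero().T  # pairs within r_max, no self-loops
    index = index[:, batch[index[0]] == batch[index[1]]]  # drop pairs across graphs
    return index
```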
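You can check what the two helpers compute without PyTorch. The sketch below rewrites them as loops for illustration only. `scatter_sum` and `radius_graph_naive` are hypothetical names, not e3nn functions. Like the original, the graph keeps ordered pairs with 0 < distance < r_max and drops pairs whose nodes belong to different graphs of the batch.

```python
import math
from typing import List, Sequence, Tuple


def scatter_sum(src: Sequence[Sequence[float]], index: Sequence[int], dim_size: int) -> List[List[float]]:
    """Row i of src is added to row index[i] of the output (dim=0 scatter-add)."""
    width = len(src[0]) if src else 0
    out = [[0.0] * width for _ in range(dim_size)]
    for row, i in zip(src, index):
        for c, value in enumerate(row):
            out[i][c] += value
    return out


def radius_graph_naive(
    pos: Sequence[Sequence[float]], r_max: float, batch: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Ordered pairs (a, b) with 0 < |pos[a] - pos[b]| < r_max inside the same graph."""
    src: List[int] = []
    dst: List[int] = []
    for a in range(len(pos)):  # row-major order, like Tensor.nonzero()
        for b in range(len(pos)):
            r = math.dist(pos[a], pos[b])
            if 0.0 < r < r_max and batch[a] == batch[b]:
                src.append(a)
                dst.append(b)
    return src, dst


pos = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.0, 0.0]]
batch = [0, 0, 0, 1]  # node 3 belongs to a second graph
src, dst = radius_graph_naive(pos, r_max=1.2, batch=batch)
# src == [0, 0, 1, 2], dst == [1, 2, 0, 0]: node 3 is close to node 0 but in another graph
in_degree = scatter_sum([[1.0]] * len(dst), dst, dim_size=len(pos))
# in_degree == [[2.0], [1.0], [1.0], [0.0]]
```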
```python
import torch
from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input node features
    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes
    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
    irreps_out : `e3nn.o3.Irreps` or None
        representation of the output node features
    number_of_basis : int
        number of basis functions onto which the edge lengths are projected
    radial_layers : int
        number of hidden layers in the radial fully connected network
    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network
    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_in,
        irreps_node_attr,
        irreps_edge_attr,
        irreps_out,
        number_of_basis,
        radial_layers,
        radial_neurons,
        num_neighbors,
    ) -> None:
        super().__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_out = o3.Irreps(irreps_out)
        self.num_neighbors = num_neighbors

        # self-connection and first channel-mixing layer, both conditioned on the node attributes
        self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)
        self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

        # one "uvu" path per (input irrep, edge irrep, allowed output irrep): channels are not mixed
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_in):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        # renumber the output of each instruction to match the sorted irreps_mid
        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        # weights are not stored in tp: they are produced per edge by the radial network fc
        tp = TensorProduct(
            self.irreps_in,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(
            [number_of_basis] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu
        )
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) -> torch.Tensor:
        weight = self.fc(edge_length_embedded)  # one set of tp weights per edge

        x = node_input
        # ... (the rest of forward is not included in this excerpt)
```
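You can replay the instruction loop on a toy example in plain Python. This sketch encodes irreps as (mul, l, p) tuples instead of e3nn.o3.Irreps. It applies the selection rule of ir_in * ir_edge (\(|l_1 - l_2| \le l \le l_1 + l_2\), \(p = p_1 p_2\)) and keeps the outputs present in irreps_out. It then sorts irreps_mid and remaps each instruction's output index, which is what the permutation p returned by irreps_mid.sort() is used for. Two details are assumptions for illustration, not taken from this page: the sort key (l, p), and a weight count of mul_in × mul_edge per weighted "uvu" path.

```python
from typing import List, Tuple

Irrep = Tuple[int, int]          # (l, p), p = +1 even, -1 odd
MulIrrep = Tuple[int, int, int]  # (mul, l, p)


def product(ir1: Irrep, ir2: Irrep) -> List[Irrep]:
    """Irreps in ir1 x ir2: |l1 - l2| <= l <= l1 + l2 with parity p1 * p2."""
    (l1, p1), (l2, p2) = ir1, ir2
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def build_uvu_instructions(irreps_in: List[MulIrrep], irreps_edge: List[MulIrrep], irreps_out: List[MulIrrep]):
    allowed = {(l, p) for _, l, p in irreps_out}
    irreps_mid: List[MulIrrep] = []
    instructions = []
    for i, (mul, l_in, p_in) in enumerate(irreps_in):
        for j, (_, l_e, p_e) in enumerate(irreps_edge):
            for l_out, p_out in product((l_in, p_in), (l_e, p_e)):
                if (l_out, p_out) in allowed:
                    k = len(irreps_mid)
                    irreps_mid.append((mul, l_out, p_out))  # keeps the input multiplicity
                    instructions.append((i, j, k, "uvu", True))
    # sort irreps_mid (key (l, p) assumed) and remap every output index old -> new
    order = sorted(range(len(irreps_mid)), key=lambda k: irreps_mid[k][1:])
    new_index = {old: new for new, old in enumerate(order)}
    irreps_mid = [irreps_mid[k] for k in order]
    instructions = [(i1, i2, new_index[k], mode, train) for i1, i2, k, mode, train in instructions]
    return irreps_mid, instructions


def uvu_weight_numel(irreps_in: List[MulIrrep], irreps_edge: List[MulIrrep], instructions) -> int:
    # assumption: each weighted "uvu" path has mul_in * mul_edge weights
    return sum(irreps_in[i][0] * irreps_edge[j][0] for i, j, *_ in instructions)


irreps_in = [(4, 1, -1), (8, 0, 1)]               # "4x1o + 8x0e"
irreps_edge = [(1, 0, 1), (1, 1, -1), (1, 2, 1)]  # "0e + 1o + 2e"
irreps_out = [(8, 0, 1), (8, 1, -1)]              # "8x0e + 8x1o"
irreps_mid, instructions = build_uvu_instructions(irreps_in, irreps_edge, irreps_out)
numel = uvu_weight_numel(irreps_in, irreps_edge, instructions)
fc_sizes = [10] + 1 * [100] + [numel]  # number_of_basis=10, radial_layers=1, radial_neurons=100
# irreps_mid == [(4, 0, 1), (8, 0, 1), (4, 1, -1), (4, 1, -1), (8, 1, -1)], numel == 28
```

The last line mirrors how fc is sized: the radial network must output one full set of tp weights per edge, because the weights are neither internal nor shared.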
Network
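The network is a simple succession of Convolution and e3nn.nn.Gate layers.
The activation function is SiLU for even scalars and tanh for odd scalars. Odd scalars need an odd function such as tanh, so that the output still flips sign under inversion. The gates use sigmoid when even and tanh when odd.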
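The following plain-Python sketch shows the bookkeeping behind each Gate layer. It assumes irreps are (mul, l, p) tuples and ignores the tp_path_exists filter, that is, it assumes every hidden irrep is reachable. Scalars (l = 0) go through act, and every non-scalar channel gets one extra scalar gate. The preceding Convolution therefore has to output scalars, gates and gated tensors, while the next layer only receives scalars and gated tensors. The order used for conv_out is an assumption for illustration. The parity check shows why odd scalars use tanh: tanh is an odd function, while SiLU is neither even nor odd.

```python
import math
from typing import List, Tuple

MulIrrep = Tuple[int, int, int]  # (mul, l, p), p = +1 even, -1 odd


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


ACT = {1: silu, -1: math.tanh}           # like `act` in Network.__init__
ACT_GATES = {1: sigmoid, -1: math.tanh}  # like `act_gates`


def parity_of(f, xs=(0.3, 1.0, 2.5)) -> int:
    """1 if f looks even, -1 if odd, 0 if neither (checked on a few points)."""
    if all(math.isclose(f(-x), f(x)) for x in xs):
        return 1
    if all(math.isclose(f(-x), -f(x)) for x in xs):
        return -1
    return 0


def dim(irreps: List[MulIrrep]) -> int:
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def split_for_gate(irreps_hidden: List[MulIrrep], gate_parity: int = 1):
    """Bookkeeping of one Gate layer (ignores the tp_path_exists filter)."""
    scalars = [(mul, l, p) for mul, l, p in irreps_hidden if l == 0]
    gated = [(mul, l, p) for mul, l, p in irreps_hidden if l > 0]
    gates = [(mul, 0, gate_parity) for mul, _, _ in gated]  # one gate scalar per gated channel
    conv_out = scalars + gates + gated  # what the Convolution must produce (order assumed)
    gate_out = scalars + gated          # what the next layer receives
    return scalars, gates, gated, conv_out, gate_out


hidden = [(16, 0, 1), (16, 0, -1), (8, 1, 1), (8, 1, -1)]  # "16x0e+16x0o+8x1e+8x1o"
scalars, gates, gated, conv_out, gate_out = split_for_gate(hidden)
# dim(conv_out) == 96, dim(gate_out) == 80
# parity_of(ACT[-1]) == -1 (tanh is odd), parity_of(ACT[1]) == 0 (SiLU is neither)
```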
When the parities (p in e3nn.o3.Irrep) are provided, the network is equivariant to O(3). To relax this constraint and make it equivariant to SO(3) only, one can simply pass all the irreps parameters as even (p=1 in e3nn.o3.Irrep).
This is why irreps_edge_attr, the irreps of the spherical harmonics used as edge attributes, is a parameter of the class Network. One can use specific l of the spherical harmonics with the correct parity \(p = (-1)^l\) (e3nn.o3.Irreps.spherical_harmonics can be used for that), or consider that p=1 in order not to be equivariant to parity.
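The two options differ only in the parities written into irreps_edge_attr. The hypothetical helper below builds both irreps strings in e3nn's "1x0e+1x1o" notation for a given maximum l. For the O(3) case, e3nn.o3.Irreps.spherical_harmonics(lmax) should give the same irreps. The inversion check uses unnormalized l = 1 and l = 2 components to show where \(p = (-1)^l\) comes from.

```python
def sh_irreps_string(lmax: int, parity_equivariant: bool = True) -> str:
    """Irreps of Y^0..Y^lmax, one copy each: p = (-1)^l, or p = 1 for everything."""
    parts = []
    for l in range(lmax + 1):
        p = (-1) ** l if parity_equivariant else 1
        parts.append(f"1x{l}{'e' if p == 1 else 'o'}")
    return "+".join(parts)


def y1(v):  # l = 1 components, up to normalization: linear in v
    return tuple(v)


def y2_xy(v):  # one l = 2 component, up to normalization: quadratic in v
    return v[0] * v[1]


o3_irreps = sh_irreps_string(2)          # "1x0e+1x1o+1x2e": equivariant to O(3)
so3_irreps = sh_irreps_string(2, False)  # "1x0e+1x1e+1x2e": SO(3) only
v = (0.6, 0.8, 0.0)
minus_v = tuple(-c for c in v)           # inversion r -> -r
# y1 flips sign under inversion (odd, p = -1); y2_xy is unchanged (even, p = +1)
```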
```python
import torch


class Compose(torch.nn.Module):
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second
        self.irreps_in = self.first.irreps_in
        self.irreps_out = self.second.irreps_out

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)
```
```python
from typing import Dict, Optional

import torch
from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import Gate

# Convolution, Compose and radius_graph are defined above; smooth_cutoff and
# tp_path_exists are helper functions not shown in this excerpt.


class Network(torch.nn.Module):
    r"""equivariant neural network

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features
    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features
    irreps_out : `e3nn.o3.Irreps`
        representation of the output features
    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the node attributes
        can be set to ``None`` if nodes don't have attributes
    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials
    layers : int
        number of gates (nonlinearities)
    max_radius : float
        maximum radius for the convolution
    number_of_basis : int
        number of basis functions onto which the edge lengths are projected
    radial_layers : int
        number of hidden layers in the radial fully connected network
    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network
    num_neighbors : float
        typical number of nodes within a distance ``max_radius``
    num_nodes : float
        typical number of nodes in a graph
    """

    def __init__(
        self,
        irreps_in: Optional[o3.Irreps],
        irreps_hidden: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attr: o3.Irreps,
        irreps_edge_attr: Optional[o3.Irreps],
        layers: int,
        max_radius: float,
        number_of_basis: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
        num_nodes: float,
        reduce_output: bool = True,
    ) -> None:
        super().__init__()

        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        # activation per parity of the scalars: 1 = even, -1 = odd (odd scalars need an odd function)
        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            # keep only the hidden irreps that the convolution can actually produce
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_hidden
                    if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
                ]
            )
            irreps_gated = o3.Irreps(
                [(mul, ir) for mul, ir in self.irreps_hidden if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)]
            )
            # one scalar gate per gated channel: even if such a path exists, odd otherwise
            ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        # final convolution to irreps_out, without a nonlinearity
        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
        )

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """evaluate the network

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the graph to which each node belongs, optional
        """
        if "batch" in data:
            batch = data["batch"]
        else:
            # single graph: every node belongs to graph 0
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

        edge_index = radius_graph(data["pos"], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data["pos"][edge_src] - data["pos"][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization="component")
        edge_length = edge_vec.norm(dim=1)
        # expand each edge length on number_of_basis Gaussians (input of the radial networks)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length, start=0.0, end=self.max_radius, number=self.number_of_basis, basis="gaussian", cutoff=False
        ).mul(self.number_of_basis**0.5)
        # edge attributes h(r) Y(r / r): spherical harmonics damped to zero at max_radius
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and "x" in data:
            assert self.irreps_in is not None
            # ... (the rest of forward is not included in this excerpt)
```
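The following dependency-free sketch shows the edge preprocessing at the start of forward for a single edge. The edge length is expanded on number_of_basis Gaussian bumps over [0, max_radius] and rescaled by sqrt(number_of_basis). The spherical harmonics of the edge direction are multiplied by a cutoff that is 1 up to max_radius / 2 and falls smoothly to 0 at max_radius. The Gaussian width (one grid step) and the cosine ramp are stand-ins chosen for illustration. They are not guaranteed to match soft_one_hot_linspace(..., basis="gaussian") or smooth_cutoff exactly, and only l ≤ 1 harmonics are computed.

```python
import math
from typing import List, Sequence, Tuple


def gaussian_basis(r: float, r_max: float, number: int) -> List[float]:
    """`number` Gaussians centred on an even grid over [0, r_max], width = grid step."""
    step = r_max / (number - 1)
    return [math.exp(-(((r - i * step) / step) ** 2)) for i in range(number)]


def cutoff_sketch(u: float) -> float:
    """1 for u <= 0.5, 0 for u >= 1, cosine ramp in between (u = r / r_max)."""
    if u <= 0.5:
        return 1.0
    if u >= 1.0:
        return 0.0
    return 0.5 * (1.0 + math.cos(2.0 * math.pi * (u - 0.5)))


def edge_features(
    pos_src: Sequence[float], pos_dst: Sequence[float], r_max: float, number_of_basis: int
) -> Tuple[float, List[float], List[float]]:
    vec = [a - b for a, b in zip(pos_src, pos_dst)]  # edge_vec = pos[src] - pos[dst]
    r = math.sqrt(sum(c * c for c in vec))
    unit = [c / r for c in vec]
    sh = [1.0] + [math.sqrt(3.0) * c for c in unit]  # l = 0 and l = 1, "component" normalization
    embedded = [b * math.sqrt(number_of_basis) for b in gaussian_basis(r, r_max, number_of_basis)]
    edge_attr = [cutoff_sketch(r / r_max) * y for y in sh]  # h(r) * Y(r / |r|)
    return r, embedded, edge_attr


r, embedded, edge_attr = edge_features([1.0, 0.0, 0.0], [0.0, 0.0, 0.0], r_max=2.0, number_of_basis=5)
# r == 1.0; embedded peaks at index 2 (the Gaussian centred at 1.0); edge_attr ≈ [1.0, 1.732, 0.0, 0.0]
```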
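Model with self-interactions and gates.
Exact equivariance to \(E(3)\).
Version of January 2021.
Classes:
- Compose(first, second)
- Convolution(irreps_in, irreps_node_attr, ...): equivariant convolution
- Network(irreps_in, irreps_hidden, ...): equivariant neural network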
- class e3nn.nn.models.gate_points_2101.Compose(first, second)
Bases: Module
- forward(*input)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class e3nn.nn.models.gate_points_2101.Convolution(irreps_in, irreps_node_attr, irreps_edge_attr, irreps_out, number_of_basis, radial_layers, radial_neurons, num_neighbors)
Bases: Module
equivariant convolution
- Parameters:
  - irreps_in (e3nn.o3.Irreps) – representation of the input node features
  - irreps_node_attr (e3nn.o3.Irreps) – representation of the node attributes
  - irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes
  - irreps_out (e3nn.o3.Irreps or None) – representation of the output node features
  - number_of_basis (int) – number of basis functions onto which the edge lengths are projected
  - radial_layers (int) – number of hidden layers in the radial fully connected network
  - radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network
  - num_neighbors (float) – typical number of nodes convolved over
- forward(node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class e3nn.nn.models.gate_points_2101.Network(irreps_in: Irreps | None, irreps_hidden: Irreps, irreps_out: Irreps, irreps_node_attr: Irreps, irreps_edge_attr: Irreps | None, layers: int, max_radius: float, number_of_basis: int, radial_layers: int, radial_neurons: int, num_neighbors: float, num_nodes: float, reduce_output: bool = True)
Bases: Module
equivariant neural network
- Parameters:
  - irreps_in (e3nn.o3.Irreps or None) – representation of the input features; can be set to None if nodes don’t have input features
  - irreps_hidden (e3nn.o3.Irreps) – representation of the hidden features
  - irreps_out (e3nn.o3.Irreps) – representation of the output features
  - irreps_node_attr (e3nn.o3.Irreps or None) – representation of the node attributes; can be set to None if nodes don’t have attributes
  - irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes; the edge attributes are \(h(r) Y(\vec r / r)\), where \(h\) is a smooth function that goes to zero at max_radius and \(Y\) are the spherical harmonics polynomials
  - layers (int) – number of gates (nonlinearities)
  - max_radius (float) – maximum radius for the convolution
  - number_of_basis (int) – number of basis functions onto which the edge lengths are projected
  - radial_layers (int) – number of hidden layers in the radial fully connected network
  - radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network
  - num_neighbors (float) – typical number of nodes within a distance max_radius
  - num_nodes (float) – typical number of nodes in a graph
- forward(data: Dict[str, Tensor]) → Tensor
evaluate the network
- Parameters:
  - data (torch_geometric.data.Data or dict) – data object containing:
    - pos: the position of the nodes (atoms)
    - x: the input features of the nodes, optional
    - z: the attributes of the nodes, for instance the atom type, optional
    - batch: the graph to which each node belongs, optional