Model Gate of January 2021

Multipurpose equivariant neural network for point clouds. It is built with e3nn.o3.TensorProduct for the linear part and e3nn.nn.Gate for the nonlinearities.

Convolution

The linear part, the Convolution module, is inspired by the idea of depthwise separable convolutions. Its main operation is tp, which makes each atom interact with its neighbors but does not mix the channels: every tp instruction uses the "uvu" mode, so each output channel is computed from the matching channel of the node features. To mix the channels, tp is sandwiched between lin1 and lin2.
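As a rough illustration of the depthwise separable idea, here is a NumPy sketch restricted to scalar features. It is an assumption-laden simplification: the real module couples irreps with spherical harmonics through tp, and lin1/lin2 are FullyConnectedTensorProduct layers that also use the node attributes, not plain matrices. The function names and weight shapes below are hypothetical.

```python
import numpy as np


def depthwise_message_passing(x, edge_src, edge_dst, edge_weight):
    """Channel-wise (depthwise) graph convolution on scalar features.

    x: (num_nodes, channels); edge_weight: (num_edges, channels), one weight
    per edge and channel (assumed to come from a radial network).
    Output channel c only depends on input channel c: channels are not mixed.
    """
    messages = edge_weight * x[edge_src]   # (num_edges, channels)
    out = np.zeros_like(x)
    np.add.at(out, edge_dst, messages)     # sum incoming messages per node
    return out


def separable_conv(x, edge_src, edge_dst, edge_weight, w1, w2):
    """Channel mixing, then depthwise message passing, then channel mixing."""
    h = x @ w1                             # plays the role of lin1 (hypothetical)
    h = depthwise_message_passing(h, edge_src, edge_dst, edge_weight)
    return h @ w2                          # plays the role of lin2 (hypothetical)


rng = np.random.default_rng(0)
num_nodes, channels = 4, 3
x = rng.normal(size=(num_nodes, channels))
edge_src = np.array([0, 1, 2, 3])
edge_dst = np.array([1, 2, 3, 0])
edge_weight = rng.normal(size=(len(edge_src), channels))
w1 = rng.normal(size=(channels, channels))
w2 = rng.normal(size=(channels, 5))
y = separable_conv(x, edge_src, edge_dst, edge_weight, w1, w2)
print(y.shape)  # (4, 5)
```

Splitting the work this way keeps the per-edge operation cheap (one weight per channel) while the two dense mixing steps run once per node rather than once per edge.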

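from typing import Dict, Optional

import torch

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode

# smooth_cutoff and tp_path_exists are helper functions not shown in this excerpt


def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    # special case of torch_scatter.scatter with dim=0: sum the rows of src into dim_size rows
    out = src.new_zeros(dim_size, src.shape[1])
    index = index.reshape(-1, 1).expand_as(src)
    return out.scatter_add_(0, index, src)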


def radius_graph(pos, r_max, batch) -> torch.Tensor:
    # naive and inefficient version of torch_cluster.radius_graph
    r = torch.cdist(pos, pos)
    index = ((r < r_max) & (r > 0)).nonzero().T
    index = index[:, batch[index[0]] == batch[index[1]]]
    return index


@compile_mode("script")
class Convolution(torch.nn.Module):
    r"""equivariant convolution

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps`
        representation of the input node features

    irreps_node_attr : `e3nn.o3.Irreps`
        representation of the node attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes

    irreps_out : `e3nn.o3.Irreps` or None
        representation of the output node features

    number_of_basis : int
        number of basis functions onto which the edge lengths are projected

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes convolved over
    """

    def __init__(
        self,
        irreps_in,
        irreps_node_attr,
        irreps_edge_attr,
        irreps_out,
        number_of_basis,
        radial_layers,
        radial_neurons,
        num_neighbors,
    ) -> None:
        super().__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr)
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)
        self.irreps_out = o3.Irreps(irreps_out)
        self.num_neighbors = num_neighbors

        self.sc = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_out)

        self.lin1 = FullyConnectedTensorProduct(self.irreps_in, self.irreps_node_attr, self.irreps_in)

        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_in):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        instructions = [(i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(
            self.irreps_in,
            self.irreps_edge_attr,
            irreps_mid,
            instructions,
            internal_weights=False,
            shared_weights=False,
        )
        self.fc = FullyConnectedNet(
            [number_of_basis] + radial_layers * [radial_neurons] + [tp.weight_numel], torch.nn.functional.silu
        )
        self.tp = tp

        self.lin2 = FullyConnectedTensorProduct(irreps_mid, self.irreps_node_attr, self.irreps_out)

    def forward(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) -> torch.Tensor:
        weight = self.fc(edge_length_embedded)

        x = node_input
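The scatter helper at the top of the excerpt sums per-edge rows into per-node rows. A NumPy analogue (a hypothetical port for illustration, not part of e3nn) makes the role of dim_size visible: nodes that receive no edge must still get a (zero) row.

```python
import numpy as np


def scatter_sum(src, index, dim_size):
    """Sum the rows of src into dim_size buckets selected by index.

    src: (num_edges, channels); index: (num_edges,) destination of each row.
    Buckets that receive no rows stay at zero.
    """
    out = np.zeros((dim_size, src.shape[1]), dtype=src.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate correctly;
    # out[index] += src would keep only one contribution per repeated index.
    np.add.at(out, index, src)
    return out


edge_features = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
edge_dst = np.array([0, 0, 2])
result = scatter_sum(edge_features, edge_dst, dim_size=4)
# node 0 receives 1+2 and 10+20, node 2 receives 3 and 30, nodes 1 and 3 stay zero
print(result)
```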

Network

The network is a simple succession of Convolution and e3nn.nn.Gate. The activation function is SiLU for even scalars and tanh for odd scalars; the gates use a sigmoid when they are even and tanh when they are odd. When the parities (p in e3nn.o3.Irrep) are provided, the network is equivariant to O(3). To relax this constraint and make it equivariant to SO(3) only, one can simply make all the irreps parameters even (p=1 in e3nn.o3.Irrep). This is why irreps_edge_attr, the irreps of the spherical harmonics used as edge attributes, is a parameter of the class Network: one can use specific l of the spherical harmonics with the correct parity p=(-1)^l (e3nn.o3.Irreps.spherical_harmonics can be used for that), or set p=1 so that the network is not equivariant to parity.
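The parity rules above can be checked with a small stdlib sketch. The helper names are hypothetical, and the irreps strings are assumed to follow the "mulxlp" notation that o3.Irreps parses. An odd scalar flips sign under inversion, so its activation f must satisfy f(-x) = -f(x); tanh does, SiLU does not, which is why SiLU is reserved for even scalars.

```python
import math


def spherical_harmonics_irreps(lmax, respect_parity=True):
    """Irreps string for spherical harmonics up to degree lmax.

    respect_parity=True gives p = (-1)^l (O(3) equivariance);
    respect_parity=False makes every irrep even (SO(3) only).
    """
    parts = []
    for l in range(lmax + 1):
        p = "e" if (not respect_parity or l % 2 == 0) else "o"
        parts.append(f"1x{l}{p}")
    return " + ".join(parts)


def silu(x):
    return x / (1.0 + math.exp(-x))


def is_odd(f, xs):
    """True if f(-x) == -f(x) on the sample points, i.e. f preserves odd parity."""
    return all(math.isclose(f(-x), -f(x), abs_tol=1e-12) for x in xs)


print(spherical_harmonics_irreps(2))         # 1x0e + 1x1o + 1x2e
print(spherical_harmonics_irreps(2, False))  # 1x0e + 1x1e + 1x2e
xs = [0.1, 0.5, 1.0, 2.0]
print(is_odd(math.tanh, xs), is_odd(silu, xs))  # True False
```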

class Compose(torch.nn.Module):
    # chain two modules: the output of `first` is fed to `second`
    def __init__(self, first, second) -> None:
        super().__init__()
        self.first = first
        self.second = second
        self.irreps_in = self.first.irreps_in
        self.irreps_out = self.second.irreps_out

    def forward(self, *input):
        x = self.first(*input)
        return self.second(x)


class Network(torch.nn.Module):
    r"""equivariant neural network

    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features

    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features

    irreps_out : `e3nn.o3.Irreps`
        representation of the output features

    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the node attributes
        can be set to ``None`` if nodes don't have attributes

    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials

    layers : int
        number of gates (nonlinearities)

    max_radius : float
        maximum radius for the convolution

    number_of_basis : int
        number of basis functions onto which the edge lengths are projected

    radial_layers : int
        number of hidden layers in the radial fully connected network

    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network

    num_neighbors : float
        typical number of nodes within a distance ``max_radius``

    num_nodes : float
        typical number of nodes in a graph
    """

    def __init__(
        self,
        irreps_in: Optional[o3.Irreps],
        irreps_hidden: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attr: Optional[o3.Irreps],
        irreps_edge_attr: o3.Irreps,
        layers: int,
        max_radius: float,
        number_of_basis: int,
        radial_layers: int,
        radial_neurons: int,
        num_neighbors: float,
        num_nodes: float,
        reduce_output: bool = True,
    ) -> None:
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_hidden
                    if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
                ]
            )
            irreps_gated = o3.Irreps(
                [(mul, ir) for mul, ir in self.irreps_hidden if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)]
            )
            ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
        )

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """evaluate the network

        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the graph to which each node belongs, optional
        """
        if "batch" in data:
            batch = data["batch"]
        else:
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

        edge_index = radius_graph(data["pos"], self.max_radius, batch)
        edge_src = edge_index[0]
        edge_dst = edge_index[1]
        edge_vec = data["pos"][edge_src] - data["pos"][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization="component")
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length, start=0.0, end=self.max_radius, number=self.number_of_basis, basis="gaussian", cutoff=False
        ).mul(self.number_of_basis**0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and "x" in data:
            assert self.irreps_in is not None
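The edge preprocessing in forward (a radial basis expansion rescaled by sqrt(number_of_basis), and a smooth envelope that vanishes at max_radius) can be sketched in NumPy. Both functions are assumptions: gaussian_basis is a simplified stand-in for soft_one_hot_linspace(..., basis="gaussian", cutoff=False) whose exact normalization may differ, and cosine_cutoff is a common cosine envelope, not necessarily identical to the module's smooth_cutoff.

```python
import numpy as np


def gaussian_basis(r, r_max, number):
    """Project lengths r onto `number` Gaussians centred on linspace(0, r_max, number)."""
    centers = np.linspace(0.0, r_max, number)
    step = centers[1] - centers[0]
    diff = (r[:, None] - centers[None, :]) / step
    return np.exp(-diff**2)  # (num_edges, number)


def cosine_cutoff(u):
    """Smooth envelope on u = r / r_max: 1 at u = 0, 0 for u >= 1 (assumed shape)."""
    return np.where(u < 1.0, 0.5 * (np.cos(np.pi * u) + 1.0), 0.0)


r_max, number_of_basis = 2.0, 5
edge_vec = np.array([[0.5, 0.0, 0.0], [0.0, 1.5, 0.0], [1.2, 1.2, 0.0]])
edge_length = np.linalg.norm(edge_vec, axis=1)

# rescale by sqrt(number_of_basis), mirroring .mul(self.number_of_basis**0.5)
edge_length_embedded = gaussian_basis(edge_length, r_max, number_of_basis) * number_of_basis**0.5
envelope = cosine_cutoff(edge_length / r_max)  # would multiply the spherical harmonics
print(edge_length_embedded.shape, envelope)
```

Because the envelope reaches zero exactly at max_radius, an edge entering or leaving the radius graph changes the edge attributes continuously.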

Model with self-interactions and gates

Exact equivariance to \(E(3)\)

Version of January 2021

Classes:

  • Compose(first, second)

  • Convolution(irreps_in, irreps_node_attr, ...) – equivariant convolution

  • Network(irreps_in, irreps_hidden, ...[, ...]) – equivariant neural network

class e3nn.nn.models.gate_points_2101.Compose(first, second)

Bases: Module

Methods:

forward(*input)

Defines the computation performed at every call.

forward(*input)

Defines the computation performed at every call.


Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3nn.nn.models.gate_points_2101.Convolution(irreps_in, irreps_node_attr, irreps_edge_attr, irreps_out, number_of_basis, radial_layers, radial_neurons, num_neighbors)

Bases: Module

equivariant convolution

Parameters:
  • irreps_in (e3nn.o3.Irreps) – representation of the input node features

  • irreps_node_attr (e3nn.o3.Irreps) – representation of the node attributes

  • irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes

  • irreps_out (e3nn.o3.Irreps or None) – representation of the output node features

  • number_of_basis (int) – number of basis functions onto which the edge lengths are projected

  • radial_layers (int) – number of hidden layers in the radial fully connected network

  • radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network

  • num_neighbors (float) – typical number of nodes convolved over
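The parameter num_neighbors is described only as the typical number of nodes convolved over. Our reading, which is an assumption since the relevant lines of forward are not included in the excerpt, is that it serves as a normalization constant: summed messages are divided by sqrt(num_neighbors). The NumPy sketch below shows why that keeps the aggregated features at roughly unit variance.

```python
import numpy as np

rng = np.random.default_rng(0)
num_neighbors = 25
num_nodes, channels = 2000, 8

# each node sums num_neighbors independent, unit-variance messages
messages = rng.normal(size=(num_nodes, num_neighbors, channels))
summed = messages.sum(axis=1)                # variance grows like num_neighbors
normalized = summed / num_neighbors**0.5     # assumed normalization by sqrt(num_neighbors)

print(summed.var(), normalized.var())  # roughly 25 and roughly 1
```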

Methods:

forward(node_input, node_attr, edge_src, ...)

Defines the computation performed at every call.

forward(node_input, node_attr, edge_src, edge_dst, edge_attr, edge_length_embedded) → Tensor

Defines the computation performed at every call.


Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class e3nn.nn.models.gate_points_2101.Network(irreps_in: Irreps | None, irreps_hidden: Irreps, irreps_out: Irreps, irreps_node_attr: Irreps | None, irreps_edge_attr: Irreps, layers: int, max_radius: float, number_of_basis: int, radial_layers: int, radial_neurons: int, num_neighbors: float, num_nodes: float, reduce_output: bool = True)

Bases: Module

equivariant neural network

Parameters:
  • irreps_in (e3nn.o3.Irreps or None) – representation of the input features; can be set to None if nodes don’t have input features

  • irreps_hidden (e3nn.o3.Irreps) – representation of the hidden features

  • irreps_out (e3nn.o3.Irreps) – representation of the output features

  • irreps_node_attr (e3nn.o3.Irreps or None) – representation of the node attributes; can be set to None if nodes don’t have attributes

  • irreps_edge_attr (e3nn.o3.Irreps) – representation of the edge attributes. The edge attributes are \(h(r) Y(\vec r / r)\), where \(h\) is a smooth function that goes to zero at max_radius and \(Y\) are the spherical harmonics polynomials

  • layers (int) – number of gates (nonlinearities)

  • max_radius (float) – maximum radius for the convolution

  • number_of_basis (int) – number of basis functions onto which the edge lengths are projected

  • radial_layers (int) – number of hidden layers in the radial fully connected network

  • radial_neurons (int) – number of neurons in the hidden layers of the radial fully connected network

  • num_neighbors (float) – typical number of nodes within a distance max_radius

  • num_nodes (float) – typical number of nodes in a graph

Methods:

forward(data)

evaluate the network

forward(data: Dict[str, Tensor]) → Tensor

evaluate the network

Parameters:

data (torch_geometric.data.Data or dict) – data object containing:

  • pos – the positions of the nodes (atoms)

  • x – the input features of the nodes, optional

  • z – the attributes of the nodes, for instance the atom type, optional

  • batch – the graph to which each node belongs, optional
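The batch entry lets several graphs be evaluated in one call: their positions are concatenated and batch[i] records which graph node i comes from. When batch is absent, forward treats all nodes as a single graph (a batch of zeros). The NumPy sketch below is a hypothetical illustration (make_batch and radius_graph_np are not part of e3nn); radius_graph_np ports the naive radius_graph from the source and shows that no edge is created across graphs, even between nearby points.

```python
import numpy as np


def make_batch(point_clouds):
    """Concatenate (n_i, 3) point clouds and build the per-node graph index."""
    pos = np.concatenate(point_clouds, axis=0)
    batch = np.concatenate(
        [np.full(len(p), i, dtype=np.int64) for i, p in enumerate(point_clouds)]
    )
    return {"pos": pos, "batch": batch}


def radius_graph_np(pos, r_max, batch):
    """Pairs closer than r_max, without self-loops, restricted to the same graph."""
    r = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    index = np.stack(np.nonzero((r < r_max) & (r > 0)))
    return index[:, batch[index[0]] == batch[index[1]]]


graph_a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
graph_b = np.array([[0.5, 0.0, 0.0]])  # close to both nodes of graph_a, but another graph
data = make_batch([graph_a, graph_b])
edge_index = radius_graph_np(data["pos"], r_max=1.5, batch=data["batch"])
print(edge_index)  # only the two edges inside graph_a survive

# a single graph without "batch": every node is assigned to graph 0
single_batch = np.zeros(len(graph_a), dtype=np.int64)
```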