Point inputs with periodic boundary conditions
This example shows how to give point inputs with periodic boundary conditions (e.g. crystal data) to a Euclidean neural network built with e3nn. For a specific application, this code should be modified with a more tailored network design.
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data
default_dtype = torch.float64
torch.set_default_dtype(default_dtype)
Example crystal structures
First, we create some crystal structures which have periodic boundary conditions.
# A lattice is a 3 x 3 matrix
# The first index is the lattice vector (a, b, c)
# The second index is a Cartesian index over (x, y, z)
# Polonium with Simple Cubic Lattice
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 Angstrom
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']
# Silicon with Diamond Structure
si_lattice = torch.tensor([
[0. , 2.734364, 2.734364],
[2.734364, 0. , 2.734364],
[2.734364, 2.734364, 0. ]
])
si_coords = torch.tensor([
[1.367182, 1.367182, 1.367182],
[0. , 0. , 0. ]
])
si_types = ['Si', 'Si']
po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)
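The rows of each lattice matrix are the lattice vectors, so a Cartesian position equals its fractional coordinates times the lattice matrix, r = f · L. The following is a minimal pure-Python sanity check of that convention. The helper name `frac_to_cart` is hypothetical. It shows that the first Si atom sits at fractional coordinates (1/4, 1/4, 1/4) of the primitive cell. With ASE, `si.get_scaled_positions()` should return those fractional coordinates directly.

```python
# Rows of the lattice matrix are the lattice vectors a, b, c (same convention as above).
A = 2.734364
si_lattice = [
    [0.0, A, A],
    [A, 0.0, A],
    [A, A, 0.0],
]


def frac_to_cart(frac, lattice):
    """Hypothetical helper: r = f_a * a + f_b * b + f_c * c, i.e. r = f @ lattice."""
    return [sum(frac[i] * lattice[i][k] for i in range(3)) for k in range(3)]


# The first Si atom is at fractional coordinates (1/4, 1/4, 1/4) of the primitive cell.
si_atom0 = frac_to_cart([0.25, 0.25, 0.25], si_lattice)
print(si_atom0)  # approximately [1.367182, 1.367182, 1.367182]
```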
Create and store periodic graph data
We use the ase.neighborlist.neighbor_list algorithm and a radial_cutoff distance to define which edges to include in the graph to represent interactions with neighboring atoms. Note that for a convolutional network, the number of layers determines the receptive field, i.e. how “far out” any given atom can see. So even if we use a radial_cutoff = 3.5, a two-layer network can see up to 2 * 3.5 = 7 distance units (in this case Angstroms) away, and a three-layer network up to 3 * 3.5 = 10.5 distance units. We then store our data in torch_geometric.data.Data objects that we will batch with torch_geometric.data.DataLoader below.
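The receptive-field arithmetic above fits in a tiny helper, written here in plain Python with hypothetical function names. The second function inverts the relationship. Each layer passes information only along edges no longer than the cutoff, so an atom at distance d can influence another only after at least ceil(d / cutoff) layers. This is a lower bound: it assumes intermediate atoms exist to relay the message.

```python
import math


def receptive_field(num_layers, radial_cutoff):
    """Upper bound on how far (in Angstrom) information can travel through
    `num_layers` message-passing layers whose edges are at most `radial_cutoff` long."""
    return num_layers * radial_cutoff


def min_layers_to_reach(distance, radial_cutoff):
    """Lower bound on the number of layers needed to connect two atoms `distance` apart."""
    return math.ceil(distance / radial_cutoff)


radial_cutoff = 3.5
for num_layers in (1, 2, 3):
    print(num_layers, receptive_field(num_layers, radial_cutoff))
```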
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Po': 0, 'Si': 1}
type_onehot = torch.eye(len(type_encoding))

dataset = []

dummy_energies = torch.randn(2, 1, 1)  # dummy energies for example

for crystal, energy in zip([po, si], dummy_energies):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift counts how many times the edge crosses the cell along each lattice vector (a, b, c);
    # it is (0, 0, 0) when the neighbor lies in the same image of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)

    data = torch_geometric.data.Data(
        pos=torch.tensor(crystal.get_positions()),
        lattice=torch.tensor(crystal.cell.array).unsqueeze(0),  # We add a dimension for batching
        x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]],  # One-hot atom types, used as scalar (L=0) inputs
        edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        energy=energy  # dummy energy (assumed to be normalized "per atom")
    )

    dataset.append(data)

print(dataset)
[Data(x=[1, 2], edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], edge_shift=[7, 3], energy=[1, 1]), Data(x=[2, 2], edge_index=[2, 10], pos=[2, 3], lattice=[1, 3, 3], edge_shift=[10, 3], energy=[1, 1])]
The first torch_geometric.data.Data object is for simple cubic Polonium, which has 7 edges: 6 for nearest neighbors and 1 as a “self” edge, 6 + 1 = 7. Because the cell contains a single atom, these 6 neighbors are periodic images of the atom itself at 3.34 Angstroms. The next shell, at 3.34 * sqrt(2) ≈ 4.72 Angstroms, lies outside the cutoff.

The second torch_geometric.data.Data object is for diamond Silicon, which has 10 edges. Each of the two atoms has 4 nearest neighbors (at ≈ 2.37 Angstroms) and 1 “self” edge: 4 * 2 + 1 * 2 = 10. The lattice of each structure has a shape of [1, 3, 3] so that when we batch examples, the batched lattices will have shape [batch_size, 3, 3].
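These edge counts can be checked without ASE by enumerating periodic images by brute force. For every atom pair (i, j) and every integer shift S in a small range, the sketch keeps the edge if |pos[j] - pos[i] + S · lattice| is below the cutoff. The zero-length (i, i, S = 0) pair plays the role of the self edge. This is a hypothetical O(N² · shifts) helper for small cells, not ASE's algorithm. It assumes `max_shift = 2` covers enough images for these cells and a 3.5 Angstrom cutoff, and its edge ordering will generally differ from `neighbor_list`.

```python
import itertools
import math


def brute_force_neighbor_list(positions, lattice, cutoff, self_interaction=True, max_shift=2):
    """Return (src, dst, shift) triples for all pairs closer than `cutoff`,
    scanning integer image shifts in [-max_shift, max_shift]^3 (assumed sufficient)."""
    shifts = list(itertools.product(range(-max_shift, max_shift + 1), repeat=3))
    edges = []
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            for s in shifts:
                if i == j and s == (0, 0, 0) and not self_interaction:
                    continue
                # edge vector = pos[j] - pos[i] + shift @ lattice
                vec = [p_j[k] - p_i[k] + sum(s[n] * lattice[n][k] for n in range(3)) for k in range(3)]
                if math.sqrt(sum(v * v for v in vec)) < cutoff:
                    edges.append((i, j, s))
    return edges


po_lattice = [[3.34, 0.0, 0.0], [0.0, 3.34, 0.0], [0.0, 0.0, 3.34]]
po_edges = brute_force_neighbor_list([[0.0, 0.0, 0.0]], po_lattice, cutoff=3.5)

A = 2.734364
si_lattice = [[0.0, A, A], [A, 0.0, A], [A, A, 0.0]]
si_positions = [[1.367182, 1.367182, 1.367182], [0.0, 0.0, 0.0]]
si_edges = brute_force_neighbor_list(si_positions, si_lattice, cutoff=3.5)

print(len(po_edges), len(si_edges))  # 7 10
```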
Graph Batches
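torch_geometric.data.DataLoader creates batches of differently sized structures and, when iterated over, yields torch_geometric.data.Batch objects (a subclass of Data, printed as DataBatch) that each contain one batch. In recent PyTorch Geometric releases the same class is provided as torch_geometric.loader.DataLoader, and the torch_geometric.data alias is deprecated.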
batch_size = 2
dataloader = torch_geometric.data.DataLoader(dataset, batch_size=batch_size)

for data in dataloader:
    print(data)
    print(data.batch)
    print(data.pos)
    print(data.x)
DataBatch(x=[3, 2], edge_index=[2, 17], pos=[3, 3], lattice=[2, 3, 3], edge_shift=[17, 3], energy=[2, 1], batch=[3], ptr=[3])
tensor([0, 1, 1])
tensor([[0.0000, 0.0000, 0.0000],
[1.3672, 1.3672, 1.3672],
[0.0000, 0.0000, 0.0000]])
tensor([[1., 0.],
[0., 1.],
[0., 1.]])
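Batching concatenates the graphs into one large disconnected graph. The node arrays are stacked, and each graph's edge indices are offset by the number of nodes that precede it. A batch vector records which example each node belongs to. The pure-Python sketch below mimics that bookkeeping on toy graphs and reproduces the tensor([0, 1, 1]) batch vector printed above. The `collate` helper and the toy edge lists are hypothetical; PyG's real collation also handles arbitrary attributes.

```python
def collate(graphs):
    """Hypothetical sketch of graph batching: stack graphs into one disconnected graph."""
    batch, edge_src, edge_dst, lattices = [], [], [], []
    offset = 0  # number of nodes contributed by the graphs already collated
    for graph_id, graph in enumerate(graphs):
        num_nodes = graph["num_nodes"]
        batch.extend([graph_id] * num_nodes)  # every node remembers its example
        src, dst = graph["edge_index"]
        # Offset node indices so they point at this graph's rows in the stacked node arrays.
        edge_src.extend(s + offset for s in src)
        edge_dst.extend(d + offset for d in dst)
        lattices.append(graph["lattice"])  # one [3, 3] lattice per graph -> [batch_size, 3, 3]
        offset += num_nodes
    return {"batch": batch, "edge_index": [edge_src, edge_dst], "lattice": lattices}


A = 2.734364
po_graph = {"num_nodes": 1, "edge_index": [[0, 0], [0, 0]],  # toy edges
            "lattice": [[3.34, 0.0, 0.0], [0.0, 3.34, 0.0], [0.0, 0.0, 3.34]]}
si_graph = {"num_nodes": 2, "edge_index": [[0, 1], [1, 0]],  # toy edges
            "lattice": [[0.0, A, A], [A, 0.0, A], [A, A, 0.0]]}

batched = collate([po_graph, si_graph])
print(batched["batch"])       # [0, 1, 1]
print(batched["edge_index"])  # [[0, 0, 1, 2], [0, 0, 2, 1]]
```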
data.batch is the batch index. It is a tensor with one entry per point or “atom” in the batch, i.e. of shape [total number of atoms] (here [3], not [batch_size]), and it stores which example each atom belongs to. In this case, since we only have two examples in our batch, the batch tensor only contains the numbers 0 and 1. The batch index is often passed to scatter operations to aggregate per-example values, e.g. the total energy for a single crystal structure. For more details on batching with torch_geometric, please see this page.
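A scatter reduction groups values by an index and reduces each group. The sketch below, in plain Python with hypothetical per-atom values, shows sum and mean reductions over the batch vector [0, 1, 1]. torch_scatter provides tensor versions of these operations, such as scatter_sum and scatter_mean.

```python
def scatter_sum(values, index, num_segments):
    """Sum the values that share the same index (one output per segment)."""
    out = [0.0] * num_segments
    for value, i in zip(values, index):
        out[i] += value
    return out


def scatter_mean(values, index, num_segments):
    """Average the values that share the same index; empty segments return 0.0 here."""
    sums = scatter_sum(values, index, num_segments)
    counts = scatter_sum([1.0] * len(values), index, num_segments)
    return [s / c if c > 0 else 0.0 for s, c in zip(sums, counts)]


per_atom_energy = [-1.0, -2.0, -4.0]  # hypothetical per-atom contributions
batch = [0, 1, 1]                     # atom 0 -> Po example, atoms 1, 2 -> Si example

print(scatter_sum(per_atom_energy, batch, 2))   # [-1.0, -6.0]
print(scatter_mean(per_atom_energy, batch, 2))  # [-1.0, -3.0]
```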
Relative distance vectors of edges with periodic boundaries
To calculate the vectors associated with each edge for a given torch_geometric.data.Data object representing a single example, we use the following expression:
edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
+ torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice']))
The first line in the definition of edge_vec is simply how one normally computes the relative position vector between two points. The second line adds the contribution due to crossing unit cell boundaries, i.e. when the two atoms belong to different images of the unit cell. Each row of edge_shift counts how many times each lattice vector is crossed, so multiplying it by the lattice matrix gives the Cartesian offset of the neighbor's image. As we will see below, we can modify this expression to also include the data['batch'] tensor when handling batched data.
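For a batch, each edge must use the lattice of the example it belongs to. That lattice can be looked up through the batch index of the edge's source atom, batch[edge_src]. Below is a minimal pure-Python sketch of that lookup. It uses the batched Po + Si positions printed earlier and three hand-picked (hypothetical) edges, and the helper name is also hypothetical. If the Po lattice were mistakenly used for the last edge, its length would come out near 2.76 Angstroms instead of the Si nearest-neighbor distance of about 2.37.

```python
import math

A = 2.734364
lattices = [
    [[3.34, 0.0, 0.0], [0.0, 3.34, 0.0], [0.0, 0.0, 3.34]],  # example 0: Po
    [[0.0, A, A], [A, 0.0, A], [A, A, 0.0]],                  # example 1: Si
]
pos = [[0.0, 0.0, 0.0], [1.367182, 1.367182, 1.367182], [0.0, 0.0, 0.0]]
batch = [0, 1, 1]


def batched_edge_vectors(pos, edge_src, edge_dst, edge_shift, lattices, batch):
    """Hypothetical helper: edge_vec = pos[dst] - pos[src] + shift @ lattices[batch[src]]."""
    vecs = []
    for src, dst, shift in zip(edge_src, edge_dst, edge_shift):
        lattice = lattices[batch[src]]  # lattice of the example this edge belongs to
        offset = [sum(shift[n] * lattice[n][k] for n in range(3)) for k in range(3)]
        vecs.append([pos[dst][k] - pos[src][k] + offset[k] for k in range(3)])
    return vecs


edge_src = [0, 1, 1]
edge_dst = [0, 2, 2]
edge_shift = [(1, 0, 0), (0, 0, 0), (1, 0, 0)]
edge_vec = batched_edge_vectors(pos, edge_src, edge_dst, edge_shift, lattices, batch)
lengths = [math.sqrt(sum(c * c for c in v)) for v in edge_vec]
print(edge_vec[0], [round(length, 3) for length in lengths])  # [3.34, 0.0, 0.0] [3.34, 2.368, 2.368]
```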
One Approach: Adding a Preprocessing Method to the Network
While edge_vec can be stored in the torch_geometric.data.Data object, it can also be calculated by adding a preprocessing method to the Network. For this example, we create a modified version of the example network SimpleNetwork documented here with source code here.

SimpleNetwork is a good starting point to check your data pipeline but should be replaced with a more tailored network for your specific application.
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
from typing import Dict, Tuple, Union
import torch_scatter

class SimplePeriodicNetwork(SimpleNetwork):
    def __init__(self, **kwargs):
        """The keyword `pool_nodes` is used by SimpleNetwork to determine
        whether we sum over all atom contributions per example. In this example,
        we want to use a mean operation instead, so we will override this behavior.
        """
        pool = False
        if kwargs['pool_nodes']:
            kwargs['pool_nodes'] = False
            kwargs['num_nodes'] = 1.
            pool = True
        super().__init__(**kwargs)
        self.pool = pool

    # Overwriting preprocess method of SimpleNetwork to adapt for periodic boundary data
    def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> Tuple[torch.Tensor, ...]:
        if 'batch' in data:
            batch = data['batch']
        else:
            batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)

        edge_src = data['edge_index'][0]  # Edge source
        edge_dst = data['edge_index'][1]  # Edge destination

        # We need to compute this in the computation graph to backprop to positions
        # We are computing the relative distances + unit cell shifts from periodic boundaries
        edge_batch = batch[edge_src]  # example index of each edge, used to pick its lattice
        edge_vec = (data['pos'][edge_dst]
                    - data['pos'][edge_src]
                    + torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))

        return batch, data['x'], edge_src, edge_dst, edge_vec

    def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        # if pool_nodes was set to True, use scatter_mean to aggregate
        output = super().forward(data)
        if self.pool:
            if 'batch' in data:
                batch = data['batch']
            else:  # a single, unbatched example
                batch = output.new_zeros(output.shape[0], dtype=torch.long)
            return torch_scatter.scatter_mean(output, batch, dim=0)  # Take mean over atoms per example
        return output
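One reason to prefer a mean over a sum here comes from supercells. Take a perfect crystal and build a supercell from k copies of the unit cell. Every atom keeps the same periodic environment, so a network with a finite cutoff should produce the same per-atom outputs, repeated k times. Summing then makes the prediction grow with k, while the mean stays fixed. The mean therefore matches a target normalized per atom, as the dummy energies are assumed to be. A small pure-Python illustration with hypothetical per-atom outputs:

```python
def sum_pool(per_atom_outputs):
    """Extensive readout: grows with the number of atoms."""
    return sum(per_atom_outputs)


def mean_pool(per_atom_outputs):
    """Intensive readout: independent of how many copies of the cell are used."""
    return sum(per_atom_outputs) / len(per_atom_outputs)


unit_cell_outputs = [0.5, 1.5]          # hypothetical outputs for a 2-atom cell
supercell_outputs = unit_cell_outputs * 4  # same crystal described by 4 copies of the cell

print(sum_pool(unit_cell_outputs), sum_pool(supercell_outputs))    # 2.0 8.0
print(mean_pool(unit_cell_outputs), mean_pool(supercell_outputs))  # 1.0 1.0
```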
We define and run the network.
net = SimplePeriodicNetwork(
    irreps_in="2x0e",  # One hot scalars (L=0 and even parity) on each atom to represent atom type
    irreps_out="1x0e",  # Single scalar (L=0 and even parity) to output (for example) energy
    max_radius=radial_cutoff,  # Cutoff radius for convolution
    num_neighbors=10.0,  # scaling factor based on the typical number of neighbors
    pool_nodes=True,  # We pool nodes (with a mean, see above) to predict one energy per example
)
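The comment above describes num_neighbors as a scaling factor based on the typical number of neighbors. One reasonable way to estimate it is the average number of edges per node over the training data. This heuristic is an assumption here, not something e3nn prescribes. For the two toy crystals it gives (7 + 10) / (1 + 2) ≈ 5.67, counting self edges, whereas the example simply uses 10.0.

```python
def average_num_neighbors(edge_counts, node_counts):
    """Average number of edges per node over a dataset (self edges included)."""
    return sum(edge_counts) / sum(node_counts)


# Edge and node counts of the two example crystals: Po (7 edges, 1 atom), Si (10 edges, 2 atoms).
avg = average_num_neighbors([7, 10], [1, 2])
print(round(avg, 3))  # 5.667
```

With the dataset built earlier, the same quantity could be computed as sum(d.num_edges for d in dataset) / sum(d.num_nodes for d in dataset).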
When we apply the network to our data, we get one scalar per example.
for data in dataloader:
    print(net(data).shape)  # One scalar per example
torch.Size([2, 1])