# Point inputs with periodic boundary conditions

This example shows how to give point inputs with periodic boundary conditions
(e.g. crystal data) to a Euclidean neural network built with `e3nn`

. For a specific
application, this code should be modified with a more tailored network design.

```
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data
default_dtype = torch.float64
torch.set_default_dtype(default_dtype)
```

## Example crystal structures

First, we create some crystal structures which have periodic boundary conditions.

```
# A lattice is a 3 x 3 matrix
# The first index is the lattice vector (a, b, c)
# The second index is a Cartesian index over (x, y, z)
# Polonium with Simple Cubic Lattice
po_lattice = torch.eye(3) * 3.340 # Cubic lattice with edges of length 3.34 AA
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']
# Silicon with Diamond Structure
si_lattice = torch.tensor([
[0. , 2.734364, 2.734364],
[2.734364, 0. , 2.734364],
[2.734364, 2.734364, 0. ]
])
si_coords = torch.tensor([
[1.367182, 1.367182, 1.367182],
[0. , 0. , 0. ]
])
si_types = ['Si', 'Si']
po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)
```

## Create and store periodic graph data

We use the `ase.neighborlist.neighbor_list`

algorithm and a `radial_cutoff`

distance to define which edges to include in the graph to represent
interactions with neighboring atoms. Note that for a convolutional network, the
number of layers determines the receptive field, i.e. how “far out” any given atom
can see. So even if a we use a `radial_cutoff = 3.5`

, a two layer network
effectively sees `2 * 3.5 = 7`

distance units (in this case Angstroms) away and a
three layer network `3 * 3.5 = 10.5`

distance units. We then store our data
in `torch_geometric.data.Data`

objects that we will batch with
`torch_geometric.data.DataLoader`

below.

```
radial_cutoff = 3.5 # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Po': 0, 'Si': 1}
type_onehot = torch.eye(len(type_encoding))
dataset = []
dummy_energies = torch.randn(2, 1, 1) # dummy energies for example
for crystal, energy in zip([po, si], dummy_energies):
# edge_src and edge_dst are the indices of the central and neighboring atom, respectively
# edge_shift indicates whether the neighbors are in different images / copies of the unit cell
edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)
data = torch_geometric.data.Data(
pos=torch.tensor(crystal.get_positions()),
lattice=torch.tensor(crystal.cell.array).unsqueeze(0), # We add a dimension for batching
x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]], # Using "dummy" inputs of scalars because they are all C
edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
energy=energy # dummy energy (assumed to be normalized "per atom")
)
dataset.append(data)
print(dataset)
```

```
[Data(x=[1, 2], edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], edge_shift=[7, 3], energy=[1, 1]), Data(x=[2, 2], edge_index=[2, 10], pos=[2, 3], lattice=[1, 3, 3], edge_shift=[10, 3], energy=[1, 1])]
```

The first `torch_geometric.data.Data`

object is for simple cubic Polonium which has 7
edges: 6 for nearest neighbors and 1 as a “self” edge, `6 + 1 = 7`

.
The second `torch_geometric.data.Data`

object is for diamond Silicon which has 10 edges: 4
nearest neighbors for each of the two atoms and 2 “self” edges, one for
each atom, `4 * 2 + 1 * 2 = 10`

. The lattice of each structure has a
shape of `[1, 3, 3]`

such that when we batch examples, the batched
lattices will have shape `[batch_size, 3, 3]`

.

## Graph Batches

`torch_geometric.data.DataLoader`

create batches of
differently sized structures and produces `torch_geometric.data.Data`

objects containing a batch when
iterated over.

```
batch_size = 2
dataloader = torch_geometric.data.DataLoader(dataset, batch_size=batch_size)
for data in dataloader:
print(data)
print(data.batch)
print(data.pos)
print(data.x)
```

```
DataBatch(x=[3, 2], edge_index=[2, 17], pos=[3, 3], lattice=[2, 3, 3], edge_shift=[17, 3], energy=[2, 1], batch=[3], ptr=[3])
tensor([0, 1, 1])
tensor([[0.0000, 0.0000, 0.0000],
[1.3672, 1.3672, 1.3672],
[0.0000, 0.0000, 0.0000]])
tensor([[1., 0.],
[0., 1.],
[0., 1.]])
```

`data.batch`

is the batch index which is tensor of shape
`[batch_size]`

that stores which points or “atoms” belong to which
example. In this case, since we only have two examples in our batch, the batch
tensor only contains the numbers `0`

and `1`

. The batch index is
often passed to `scatter`

operations to aggregate per examples
values,
e.g. the total energy for a single crystal structure.

For more details on batching with `torch_geometric`

, please see this
page.

## Relative distance vectors of edges with periodic boundaries

To calculate the vectors associated with each edge for a given `torch_geometric.data.Data`

object representing a single example, we use the following expression:

```
edge_src, edge_dst = data['edge_index'][0], data['edge_index'][1]
edge_vec = (data['pos'][edge_dst] - data['pos'][edge_src]
+ torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice']))
```

The first line in the definition of `edge_vec`

is simply how one normally computes
relative distance vectors given two points. The second line adds the contribution
to the relative distance vector due to crossing unit cell boundaries i.e.
if atoms belong to different images of the unit cell. As we will see below, we can
modify this expression to also include the `data['batch']`

tensor when handling
batched data.

## One Approach: Adding a Preprocessing Method to the Network

While `edge_vec`

can be stored in the `torch_geometric.data.Data`

object, it can also be calculated
by adding a preprocessing method to the Network. For this example, we create a
modified version of the example network `SimpleNetwork`

documented
here
with source code
here.
`SimpleNetwork`

is a good starting point to check your data pipeline
but should be replaced with a more tailored network for your specific
application.

```
from e3nn.nn.models.v2103.gate_points_networks import SimpleNetwork
from typing import Dict, Union
import torch_scatter
class SimplePeriodicNetwork(SimpleNetwork):
def __init__(self, **kwargs):
"""The keyword `pool_nodes` is used by SimpleNetwork to determine
whether we sum over all atom contributions per example. In this example,
we want use a mean operations instead, so we will override this behavior.
"""
self.pool = False
if kwargs['pool_nodes'] == True:
kwargs['pool_nodes'] = False
kwargs['num_nodes'] = 1.
self.pool = True
super().__init__(**kwargs)
# Overwriting preprocess method of SimpleNetwork to adapt for periodic boundary data
def preprocess(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
if 'batch' in data:
batch = data['batch']
else:
batch = data['pos'].new_zeros(data['pos'].shape[0], dtype=torch.long)
edge_src = data['edge_index'][0] # Edge source
edge_dst = data['edge_index'][1] # Edge destination
# We need to compute this in the computation graph to backprop to positions
# We are computing the relative distances + unit cell shifts from periodic boundaries
edge_batch = batch[edge_src]
edge_vec = (data['pos'][edge_dst]
- data['pos'][edge_src]
+ torch.einsum('ni,nij->nj', data['edge_shift'], data['lattice'][edge_batch]))
return batch, data['x'], edge_src, edge_dst, edge_vec
def forward(self, data: Union[torch_geometric.data.Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
# if pool_nodes was set to True, use scatter_mean to aggregate
output = super().forward(data)
if self.pool == True:
return torch_scatter.scatter_mean(output, data.batch, dim=0) # Take mean over atoms per example
else:
return output
```

We define and run the network.

```
net = SimplePeriodicNetwork(
irreps_in="2x0e", # One hot scalars (L=0 and even parity) on each atom to represent atom type
irreps_out="1x0e", # Single scalar (L=0 and even parity) to output (for example) energy
max_radius=radial_cutoff, # Cutoff radius for convolution
num_neighbors=10.0, # scaling factor based on the typical number of neighbors
pool_nodes=True, # We pool nodes to predict total energy
)
```

When we apply the network to our data, we get one scalar per example.

```
for data in dataloader:
print(net(data).shape) # One scalar per example
```

```
torch.Size([2, 1])
```