Tetris Gate Example
Building on top of the Tetris Polynomial Example, the following is added:
"""Classify tetris using gate activation function
Implement a equivariant model using gates to fit the tetris dataset
Exact equivariance to :math:`E(3)`
>>> test()
"""
import logging
import torch
from torch_cluster import radius_graph
from torch_geometric.data import Data, DataLoader
from torch_scatter import scatter
from e3nn import o3
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.math import soft_one_hot_linspace
from e3nn.util.test import assert_equivariant
def tetris():
pos = [
[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)], # chiral_shape_1
[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)], # chiral_shape_2
[(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)], # square
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)], # line
[(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)], # corner
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)], # L
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)], # T
[(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)], # zigzag
]
pos = torch.tensor(pos, dtype=torch.get_default_dtype())
# Since chiral shapes are the mirror of one another we need an *odd* scalar to distinguish them
labels = torch.tensor([
[+1, 0, 0, 0, 0, 0, 0], # chiral_shape_1
[-1, 0, 0, 0, 0, 0, 0], # chiral_shape_2
[0, 1, 0, 0, 0, 0, 0], # square
[0, 0, 1, 0, 0, 0, 0], # line
[0, 0, 0, 1, 0, 0, 0], # corner
[0, 0, 0, 0, 1, 0, 0], # L
[0, 0, 0, 0, 0, 1, 0], # T
[0, 0, 0, 0, 0, 0, 1], # zigzag
], dtype=torch.get_default_dtype())
    # apply a random rotation to each shape: o3.rand_matrix(len(pos)) samples
    # one rotation matrix per shape, and the einsum rotates that shape's points
    pos = torch.einsum('zij,zaj->zai', o3.rand_matrix(len(pos)), pos)
# put in torch_geometric format
    dataset = [Data(pos=p) for p in pos]
data = next(iter(DataLoader(dataset, batch_size=len(dataset))))
return data, labels
def mean_std(name, x):
print(f"{name} \t{x.mean():.1f} ± ({x.var(0).mean().sqrt():.1f}|{x.std():.1f})")
class Convolution(torch.nn.Module):
def __init__(self, irreps_in, irreps_sh, irreps_out, num_neighbors) -> None:
super().__init__()
self.num_neighbors = num_neighbors
tp = FullyConnectedTensorProduct(
irreps_in1=irreps_in,
irreps_in2=irreps_sh,
irreps_out=irreps_out,
internal_weights=False,
shared_weights=False,
)
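        # radial MLP: maps the 3 edge-length features to one set of
        # tensor-product weights per edge (hence shared_weights=False above)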
self.fc = FullyConnectedNet([3, 256, tp.weight_numel], torch.relu)
self.tp = tp
self.irreps_out = self.tp.irreps_out
def forward(self, node_features, edge_src, edge_dst, edge_attr, edge_scalars) -> torch.Tensor:
weight = self.fc(edge_scalars)
edge_features = self.tp(node_features[edge_src], edge_attr, weight)
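        # normalize the sum of messages by the square root of the typical
        # number of neighbors so the output variance stays close to 1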
node_features = scatter(edge_features, edge_dst, dim=0).div(self.num_neighbors**0.5)
return node_features
class Network(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.num_neighbors = 3.8 # typical number of neighbors
self.irreps_sh = o3.Irreps.spherical_harmonics(3)
irreps = self.irreps_sh
# First layer with gate
gate = Gate(
"16x0e + 16x0o", [torch.relu, torch.abs], # scalar
"8x0e + 8x0o + 8x0e + 8x0o", [torch.relu, torch.tanh, torch.relu, torch.tanh], # gates (scalars)
"16x1o + 16x1e" # gated tensors, num_irreps has to match with gates
)
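        # gate.irreps_in is the concatenation scalars + gates + gated irreps
        # (what the convolution below must produce); gate.irreps_out is the
        # activated scalars followed by the gated irreps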
self.conv = Convolution(irreps, self.irreps_sh, gate.irreps_in, self.num_neighbors)
self.gate = gate
irreps = self.gate.irreps_out
# Final layer
self.final = Convolution(irreps, self.irreps_sh, "0o + 6x0e", self.num_neighbors)
self.irreps_out = self.final.irreps_out
def forward(self, data) -> torch.Tensor:
num_nodes = 4 # typical number of nodes
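        # build edges between all pairs of points closer than r=2.5,
        # restricted to pairs within the same graph of the batch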
edge_src, edge_dst = radius_graph(x=data.pos, r=2.5, batch=data.batch)
edge_vec = data.pos[edge_src] - data.pos[edge_dst]
edge_attr = o3.spherical_harmonics(
l=self.irreps_sh,
x=edge_vec,
normalize=True,
normalization='component'
)
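        # embed edge lengths in a smooth finite basis; the factor sqrt(3) =
        # sqrt(number of basis functions) keeps the embedding roughly
        # normalized to unit variance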
edge_length_embedded = soft_one_hot_linspace(
x=edge_vec.norm(dim=1),
start=0.5,
end=2.5,
number=3,
basis='smooth_finite',
cutoff=True
) * 3**0.5
x = scatter(edge_attr, edge_dst, dim=0).div(self.num_neighbors**0.5)
x = self.conv(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
x = self.gate(x)
x = self.final(x, edge_src, edge_dst, edge_attr, edge_length_embedded)
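        # per-graph readout: sum node outputs within each shape and divide by
        # the square root of the typical number of nodes to keep the scale stable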
return scatter(x, data.batch, dim=0).div(num_nodes**0.5)
def main():
data, labels = tetris()
f = Network()
print("Built a model:")
print(f)
optim = torch.optim.Adam(f.parameters(), lr=1e-3)
# == Training ==
for step in range(200):
pred = f(data)
loss = (pred - labels).pow(2).sum()
optim.zero_grad()
loss.backward()
optim.step()
if step % 10 == 0:
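            # a shape counts as correct only if every rounded output
            # component matches its signed one-hot label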
accuracy = pred.round().eq(labels).all(dim=1).double().mean(dim=0).item()
print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")
# == Check equivariance ==
    # Because the model outputs (pseudo)scalars, we can directly check its
    # equivariance on the same data with new random rotations:
print("Testing equivariance directly...")
rotated_data, _ = tetris()
error = f(rotated_data) - f(data)
print(f"Equivariance error = {error.abs().max().item():.1e}")
print("Testing equivariance using `assert_equivariance`...")
# We can also use the library's `assert_equivariant` helper
# `assert_equivariant` also tests parity and translation, and
# can handle non-(psuedo)scalar outputs.
# To "interpret" between it and torch_geometric, we use a small wrapper:
def wrapper(pos, batch):
return f(Data(pos=pos, batch=batch))
# `assert_equivariant` uses logging to print a summary of the equivariance error,
# so we enable logging
logging.basicConfig(level=logging.INFO)
assert_equivariant(
wrapper,
# We provide the original data that `assert_equivariant` will transform...
args_in=[data.pos, data.batch],
# ...in accordance with these irreps...
irreps_in=[
"cartesian_points", # pos has vector 1o irreps, but is also translation equivariant
None, # `None` indicates invariant, possibly non-floating-point data
],
# ...and confirm that the outputs transform correspondingly for these irreps:
irreps_out=[f.irreps_out],
)
if __name__ == '__main__':
main()
def test():
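    # float64 is needed to reach the strict 1e-10 equivariance tolerance below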
torch.set_default_dtype(torch.float64)
data, labels = tetris()
f = Network()
pred = f(data)
loss = (pred - labels).pow(2).sum()
loss.backward()
rotated_data, _ = tetris()
error = f(rotated_data) - f(data)
assert error.abs().max() < 1e-10
def profile():
data, labels = tetris()
data = data.to(device='cuda')
labels = labels.to(device='cuda')
f = Network()
f.to(device='cuda')
optim = torch.optim.Adam(f.parameters(), lr=1e-2)
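    # mutable counter so the trace_handler closure can number exported traces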
called_num = [0]
def trace_handler(p):
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
p.export_chrome_trace("test_trace_" + str(called_num[0]) + ".json")
called_num[0] += 1
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=50,
warmup=1,
active=1),
on_trace_ready=trace_handler
) as p:
for _ in range(52):
pred = f(data)
loss = (pred - labels).pow(2).sum()
optim.zero_grad()
loss.backward()
optim.step()
p.step()
Full code here
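As a quick sanity check of the gate construction above, here is a minimal sketch (using only the `Gate` constructor already imported in the script) that prints the irreps the gate consumes and produces; these are exactly the `gate.irreps_in` and `gate.irreps_out` wired into the two convolutions:

import torch
from e3nn.nn import Gate

gate = Gate(
    "16x0e + 16x0o", [torch.relu, torch.abs],  # scalars
    "8x0e + 8x0o + 8x0e + 8x0o", [torch.relu, torch.tanh, torch.relu, torch.tanh],  # gates
    "16x1o + 16x1e",  # gated irreps
)
print(gate.irreps_in)   # scalars + gates + gated: what the first convolution must output
print(gate.irreps_out)  # scalars + gated: what the final convolution receives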