Tetris Polynomial Example
In this example, we create an equivariant polynomial to classify Tetris pieces.
We use the following features of e3nn:
and the following features of pytorch_geometric:
The model
return data, labels
class InvariantPolynomial(torch.nn.Module):
    """Graph network whose output is a rotation-invariant polynomial of the node positions.

    Nodes closer than ``r=1.1`` are connected with ``radius_graph``. Each edge vector is
    expanded in spherical harmonics up to l=3 *without* normalization. Two fully
    connected tensor products then mix node features with those harmonics. No
    normalization or nonlinearity is applied, so the result stays a polynomial in the
    positions.

    The output irreps are ``0o + 6x0e``: one parity-odd scalar and six parity-even
    scalars per graph. Presumably there is one entry per tetris class, with the odd
    channel distinguishing mirror-image pieces (TODO confirm against the label encoding).
    """

    def __init__(self) -> None:
        super().__init__()

        self.irreps_sh: o3.Irreps = o3.Irreps.spherical_harmonics(3)
        irreps_mid = o3.Irreps("64x0e + 24x1e + 24x1o + 16x2e + 16x2o")
        irreps_out = o3.Irreps("0o + 6x0e")

        # First layer: (node features built from harmonics) x (edge harmonics) -> hidden irreps.
        self.tp1 = FullyConnectedTensorProduct(
            irreps_in1=self.irreps_sh,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_mid,
        )
        # Second layer: hidden node features x (edge harmonics) -> output scalars.
        self.tp2 = FullyConnectedTensorProduct(
            irreps_in1=irreps_mid,
            irreps_in2=self.irreps_sh,
            irreps_out=irreps_out,
        )
        self.irreps_out = self.tp2.irreps_out

    def forward(self, data) -> torch.Tensor:
        """Evaluate the polynomial on a batch of graphs.

        Args:
            data: graph batch exposing ``pos`` (node positions, indexed per node) and
                ``batch`` (graph index of each node), as read below.

        Returns:
            One row per graph (rows indexed by ``data.batch``) holding features laid
            out as ``self.irreps_out``.
        """
        num_neighbors = 2  # typical number of neighbors
        num_nodes = 4  # typical number of nodes
        # Actual node count. It is passed as dim_size so that node-level scatters always
        # return one row per node. Without it, nodes with no neighbours at the end of
        # the index range would be dropped, and the final per-graph scatter would
        # see a size mismatch.
        num_graph_nodes = data.pos.shape[0]

        edge_src, edge_dst = radius_graph(x=data.pos, r=1.1, batch=data.batch)  # tensors of indices representing the graph
        edge_vec = data.pos[edge_src] - data.pos[edge_dst]
        edge_sh = o3.spherical_harmonics(
            l=self.irreps_sh,
            x=edge_vec,
            normalize=False,  # here we don't normalize otherwise it would not be a polynomial
            normalization="component",
        )

        # For each node, the initial features are the sum of the spherical harmonics of the neighbors
        node_features = scatter(edge_sh, edge_dst, dim=0, dim_size=num_graph_nodes).div(num_neighbors**0.5)

        # For each edge, tensor product the features on the source node with the spherical harmonics
        edge_features = self.tp1(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=num_graph_nodes).div(num_neighbors**0.5)

        edge_features = self.tp2(node_features[edge_src], edge_sh)
        node_features = scatter(edge_features, edge_dst, dim=0, dim_size=num_graph_nodes).div(num_neighbors**0.5)

        # For each graph, all the node's features are summed
        return scatter(node_features, data.batch, dim=0).div(num_nodes**0.5)
Training
# == Model and optimizer ==
f = InvariantPolynomial()
optim = torch.optim.Adam(f.parameters(), lr=1e-2)

# == Train ==
# Each step evaluates the model on the whole `data` object and takes one Adam step
# on the summed squared error against `labels`.
for step in range(200):
    pred = f(data)
    residual = pred - labels
    loss = (residual ** 2).sum()

    optim.zero_grad()
    loss.backward()
    optim.step()

    # Only report every 10th step.
    if step % 10 != 0:
        continue

    # A sample counts as correct only if every one of its outputs rounds to the label.
    correct = pred.round().eq(labels).all(dim=1)
    accuracy = correct.double().mean(dim=0).item()
    print(f"epoch {step:5d} | loss {loss:<10.1f} | {100 * accuracy:5.1f}% accuracy")
Full code here