Spherical Tensor

There exist four types of functions on the sphere, depending on how parity acts on them. This choice determines the representation (irreps) of the coefficients:

import torch
from e3nn.io import SphericalTensor

print(SphericalTensor(lmax=2, p_val=1, p_arg=1))
print(SphericalTensor(lmax=2, p_val=1, p_arg=-1))
print(SphericalTensor(lmax=2, p_val=-1, p_arg=1))
print(SphericalTensor(lmax=2, p_val=-1, p_arg=-1))
1x0e+1x1e+1x2e
1x0e+1x1o+1x2e
1x0o+1x1o+1x2o
1x0o+1x1e+1x2o
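The four strings above follow a single rule: the \(l\)-th component has parity \(p_v p_a^l\), written e (even) for +1 and o (odd) for -1. A minimal plain-Python sketch of that rule reproduces them. The helper irreps_string is hypothetical and not part of e3nn; it only builds the label string.

```python
def irreps_string(lmax: int, p_val: int, p_arg: int) -> str:
    """Hypothetical helper: irreps label of a SphericalTensor, one copy per l."""
    parts = []
    for l in range(lmax + 1):
        parity = p_val * p_arg ** l  # parity of the l-th block
        parts.append(f"1x{l}{'e' if parity == 1 else 'o'}")
    return "+".join(parts)


# Same loop order as the print statements above.
for p_val in (+1, -1):
    for p_arg in (+1, -1):
        print(irreps_string(2, p_val, p_arg))
```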
import plotly.graph_objects as go

def plot(traces):
    traces = [go.Surface(**d) for d in traces]
    fig = go.Figure(data=traces)
    fig.show()

The following plot shows the four possible behaviors of a function on the sphere under parity.

  1. The first ball shows \(f(x)\), unaffected by parity (p_val=1, p_arg=1).

  2. For p_val=1 and p_arg=-1, the signal is flipped over the sphere (each point goes to its antipode), but the colors are unchanged.

  3. For p_val=-1 and p_arg=1, only the value of the signal changes sign.

  4. For p_val=-1 and p_arg=-1, both happen at the same time: the signal is flipped over the sphere and its value changes sign.
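Because \(Y^l(-x) = (-1)^l Y^l(x)\), the parity \(-I\) acts on the coefficients block by block: the \(l\)-th block, of size \(2l+1\), is multiplied by \(p_v p_a^l\). The plain-Python sketch below applies this rule to the coefficient vector [0.8, 0.0, 0.0, 1.0] used for the plot (\(l_\mathrm{max} = 1\)). The helper apply_parity is hypothetical; it is expected to agree with D_from_matrix(parity) @ x in the plotting code, but this sketch does not call e3nn.

```python
def apply_parity(coeffs, lmax, p_val, p_arg):
    """Hypothetical helper: coefficients of P f, block l scaled by p_val * p_arg**l."""
    out = []
    start = 0
    for l in range(lmax + 1):
        sign = p_val * p_arg ** l
        block = coeffs[start:start + 2 * l + 1]
        out.extend(sign * c for c in block)
        start += 2 * l + 1
    return out


x = [0.8, 0.0, 0.0, 1.0]  # the l=0 coefficient, then the three l=1 coefficients
for p_val in (+1, -1):
    for p_arg in (+1, -1):
        print(p_val, p_arg, apply_parity(x, 1, p_val, p_arg))
```

Changing p_arg alone only affects the odd-l block (the pattern is mirrored through the origin), while changing p_val flips every block (the colors invert).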

lmax = 1
x = torch.tensor([0.8] + [0.0, 0.0, 1.0])

parity = -torch.eye(3)

x = torch.stack([
    SphericalTensor(lmax, p_val, p_arg).D_from_matrix(parity) @ x
    for p_val in [+1, -1]
    for p_arg in [+1, -1]
])
centers = torch.tensor([
    [-3.0, 0.0, 0.0],
    [-1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [3.0, 0.0, 0.0],
])

st = SphericalTensor(lmax, 1, 1)  # p_val and p_arg set arbitrarily here
plot(st.plotly_surface(x, centers=centers, radius=False))
class e3nn.io.SphericalTensor(lmax, p_val, p_arg)

Bases: e3nn.o3._irreps.Irreps

Representation of a signal on the sphere.

A SphericalTensor contains the coefficients \(A^l\) of a function \(f\) defined on the sphere:

\[f(x) = \sum_{l=0}^{l_\mathrm{max}} A^l \cdot Y^l(x)\]

The way this function transforms under parity, \(f \longrightarrow P f\), is described by the two parameters \(p_v\) and \(p_a\):

\[\begin{aligned}(P f)(x) &= p_v f(p_a x)\\ &= \sum_{l=0}^{l_\mathrm{max}} p_v p_a^l A^l \cdot Y^l(x)\end{aligned}\]

Parameters
  • lmax (int) – \(l_\mathrm{max}\)

  • p_val ({+1, -1}) – \(p_v\)

  • p_arg ({+1, -1}) – \(p_a\)

Examples

>>> SphericalTensor(3, 1, 1)
1x0e+1x1e+1x2e+1x3e
>>> SphericalTensor(3, 1, -1)
1x0e+1x1o+1x2e+1x3o
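The second line of the parity formula \((P f)(x) = p_v f(p_a x)\) follows from \(Y^l(-x) = (-1)^l Y^l(x)\). A small numerical check in plain Python compares both sides at a random point on the sphere. It assumes an unnormalized real basis of harmonic polynomials up to \(l = 2\); its normalization and ordering differ from e3nn's spherical harmonics, but neither affects the parity of each \(l\).

```python
import math
import random


def basis(x, y, z):
    """Assumed, unnormalized real harmonic polynomials for l = 0, 1, 2."""
    return [
        [1.0],
        [x, y, z],
        [x * y, y * z, 2 * z * z - x * x - y * y, x * z, x * x - y * y],
    ]


def f(coeffs, point):
    """f(x) = sum_l A^l . Y^l(x), with coeffs[l] holding A^l."""
    return sum(
        sum(a * y_lm for a, y_lm in zip(A_l, Y_l))
        for A_l, Y_l in zip(coeffs, basis(*point))
    )


random.seed(0)
coeffs = [[random.gauss(0, 1) for _ in range(2 * l + 1)] for l in range(3)]
v = [random.gauss(0, 1) for _ in range(3)]
norm = math.sqrt(sum(c * c for c in v))
point = [c / norm for c in v]  # a random point on the unit sphere

errors = []
for p_v in (+1, -1):
    for p_a in (+1, -1):
        lhs = p_v * f(coeffs, [p_a * c for c in point])  # p_v f(p_a x)
        scaled = [[p_v * p_a**l * a for a in A_l] for l, A_l in enumerate(coeffs)]
        rhs = f(scaled, point)  # sum_l p_v p_a^l A^l . Y^l(x)
        errors.append(abs(lhs - rhs))
print(max(errors))
```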

Methods:

  • find_peaks(signal[, res]) – Locate peaks on the sphere

  • from_samples_on_s2(positions, values[, res]) – Convert a set of positions on the sphere and values into a spherical tensor

  • norms(signal) – The norms of each l component

  • plot(signal[, center, res, radius, relu, ...]) – Create a surface in order to make a plot

  • plotly_surface(signals[, centers, res, ...]) – Create traces for plotly

  • signal_on_grid(signal[, res, normalization]) – Evaluate the signal on a grid on the sphere

  • signal_xyz(signal, r) – Evaluate the signal on given points on the sphere

  • sum_of_diracs(positions, values) – Sum (almost-) Dirac deltas

  • with_peaks_at(vectors[, values]) – Create a spherical tensor with peaks

find_peaks(signal, res=100)

Locate peaks on the sphere

Examples

>>> s = SphericalTensor(4, 1, -1)
>>> pos = torch.tensor([
...     [4.0, 0.0, 4.0],
...     [0.0, 5.0, 0.0],
... ])
>>> x = s.with_peaks_at(pos)
>>> pos, val = s.find_peaks(x)
>>> pos[val > 4.0].mul(10).round().abs()
tensor([[ 7.,  0.,  7.],
        [ 0., 10.,  0.]])
>>> val[val > 4.0].mul(10).round().abs()
tensor([57., 50.])
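One simple way to locate peaks is to sample the signal on a latitude/longitude grid of resolution res and keep the samples that are strictly larger than all their grid neighbors. The plain-Python sketch below does this for an arbitrary function of a unit vector. It is an illustration under assumptions, not necessarily what e3nn's find_peaks does internally; the helper find_peaks_on_grid and its angle convention (polar angle measured from +z) are hypothetical.

```python
import math


def find_peaks_on_grid(f, res=51):
    """Hypothetical helper: strict local maxima of f sampled on a lat/lon grid.

    Assumed convention: beta is the polar angle from +z, alpha the azimuth.
    """
    betas = [(i + 0.5) * math.pi / res for i in range(res)]  # avoids the poles
    alphas = [2 * math.pi * j / res for j in range(res)]

    def point(i, j):
        b, a = betas[i], alphas[j]
        return (math.sin(b) * math.cos(a), math.sin(b) * math.sin(a), math.cos(b))

    grid = [[f(point(i, j)) for j in range(res)] for i in range(res)]
    peaks = []
    for i in range(res):
        for j in range(res):
            neighbors = [
                grid[i + di][(j + dj) % res]  # the azimuth wraps around
                for di in (-1, 0, 1)
                for dj in (-1, 0, 1)
                if (di, dj) != (0, 0) and 0 <= i + di < res  # no wrap over the poles
            ]
            if all(grid[i][j] > n for n in neighbors):
                peaks.append((point(i, j), grid[i][j]))
    return peaks


# f(x) = x-coordinate has a single maximum at (1, 0, 0) with value 1.
peaks = find_peaks_on_grid(lambda p: p[0])
print(peaks)
```

An odd res places a row of samples exactly on the equator, so the maximum of this test function is hit exactly. For real signals, a threshold on the value (as with val > 4.0 in the example) helps discard small spurious maxima.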
from_samples_on_s2(positions: torch.Tensor, values: torch.Tensor, res=100) → torch.Tensor

Convert a set of positions on the sphere and values into a spherical tensor

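Parameters
  • positions (torch.Tensor) – tensor of shape (..., N, 3)

  • values (torch.Tensor) – tensor of shape (..., N)

Returns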

tensor of shape (..., self.dim)

Return type

torch.Tensor

Examples

>>> s = SphericalTensor(2, 1, 1)
>>> pos = torch.tensor([
...     [
...         [0.0, 0.0, 1.0],
...         [0.0, 0.0, -1.0],
...     ],
...     [
...         [0.0, 1.0, 0.0],
...         [0.0, -1.0, 0.0],
...     ],
... ], dtype=torch.float64)
>>> val = torch.tensor([
...     [
...         1.0,
...         -1.0,
...     ],
...     [
...         1.0,
...         -1.0,
...     ],
... ], dtype=torch.float64)
>>> s.from_samples_on_s2(pos, val, res=200).long()
tensor([[0, 0, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 0, 0, 0, 0, 0]])
>>> pos = torch.empty(2, 0, 10, 3)
>>> val = torch.empty(2, 0, 10)
>>> s.from_samples_on_s2(pos, val)
tensor([], size=(2, 0, 9))
norms(signal)

The norms of each l component

Parameters

signal (torch.Tensor) – tensor of shape (..., dim)

Returns

tensor of shape (..., lmax+1)

Return type

torch.Tensor

Examples

>>> s = SphericalTensor(1, 1, -1)
>>> s.norms(torch.tensor([1.5, 0.0, 3.0, 4.0]))
tensor([1.5000, 5.0000])
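Each \(l\) component occupies \(2l+1\) consecutive entries, so computing the norms amounts to splitting the coefficient vector into blocks and taking their Euclidean norms. A plain-Python sketch, assuming the blocks are stored in increasing order of \(l\) as the irreps string suggests (the helper norms here is a hypothetical stand-alone function, not the e3nn method):

```python
import math


def norms(coeffs, lmax):
    """Hypothetical helper: Euclidean norm of each l block (size 2l+1)."""
    out = []
    start = 0
    for l in range(lmax + 1):
        block = coeffs[start:start + 2 * l + 1]
        out.append(math.sqrt(sum(c * c for c in block)))
        start += 2 * l + 1
    return out


print(norms([1.5, 0.0, 3.0, 4.0], lmax=1))  # [1.5, 5.0]
```

Since each block transforms under rotations with an orthogonal matrix, these norms are rotation invariant.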
plot(signal, center=None, res=100, radius=True, relu=False, normalization='integral')

Create a surface in order to make a plot

plotly_surface(signals, centers=None, res=100, radius=True, relu=False, normalization='integral')

Create traces for plotly

Examples

>>> import plotly.graph_objects as go
>>> x = SphericalTensor(4, +1, +1)
>>> traces = x.plotly_surface(x.randn(-1))
>>> traces = [go.Surface(**d) for d in traces]
>>> fig = go.Figure(data=traces)
signal_on_grid(signal, res=100, normalization='integral')

Evaluate the signal on a grid on the sphere

signal_xyz(signal, r)

Evaluate the signal on given points on the sphere

\[f(\vec x / \|\vec x\|)\]
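Parameters
  • signal (torch.Tensor) – tensor of shape (*A, self.dim)

  • r (torch.Tensor) – tensor of shape (*B, 3)

Returns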

tensor of shape (*A, *B)

Return type

torch.Tensor

Examples

>>> s = SphericalTensor(3, 1, -1)
>>> s.signal_xyz(s.randn(2, 1, 3, -1), torch.randn(2, 4, 3)).shape
torch.Size([2, 1, 3, 2, 4])
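Since the signal is evaluated at \(\vec x / \|\vec x\|\), rescaling a point does not change the result. A plain-Python sketch for \(l_\mathrm{max} = 1\), assuming component-normalized real harmonics \(Y^0 = 1\) and \(Y^1(\hat x) = \sqrt{3}\,\hat x\) ordered (x, y, z); the helper signal_xyz_l1 is hypothetical and handles a single point rather than batched shapes:

```python
import math


def signal_xyz_l1(coeffs, r):
    """Hypothetical helper: evaluate an lmax=1 signal at r / |r|."""
    n = math.sqrt(sum(c * c for c in r))
    x, y, z = (c / n for c in r)  # project the point onto the unit sphere
    basis = [1.0, math.sqrt(3) * x, math.sqrt(3) * y, math.sqrt(3) * z]
    return sum(a * b for a, b in zip(coeffs, basis))


coeffs = [0.5, 0.0, 0.0, 1.0]
print(signal_xyz_l1(coeffs, [0.0, 0.0, 1.0]))
print(signal_xyz_l1(coeffs, [0.0, 0.0, 7.0]))  # same value: only the direction matters
```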
sum_of_diracs(positions: torch.Tensor, values: torch.Tensor) → torch.Tensor

Sum (almost-) Dirac deltas

\[f(x) = \sum_i v_i \delta^L(\vec r_i)\]

where \(\delta^L\) is an approximation of a Dirac delta, band-limited to the components \(l \le l_\mathrm{max}\) that the spherical tensor can represent.

Parameters
  • positions (torch.Tensor) – \(\vec r_i\) tensor of shape (..., N, 3)

  • values (torch.Tensor) – \(v_i\) tensor of shape (..., N)

Returns

tensor of shape (..., self.dim)

Return type

torch.Tensor

Examples

>>> s = SphericalTensor(7, 1, -1)
>>> pos = torch.tensor([
...     [1.0, 0.0, 0.0],
...     [0.0, 1.0, 0.0],
... ])
>>> val = torch.tensor([
...     -1.0,
...     1.0,
... ])
>>> x = s.sum_of_diracs(pos, val)
>>> s.signal_xyz(x, torch.eye(3)).mul(10.0).round()
tensor([-10.,  10.,  -0.])
>>> s.sum_of_diracs(torch.empty(1, 0, 2, 3), torch.empty(2, 0, 1)).shape
torch.Size([2, 0, 64])
>>> s.sum_of_diracs(torch.randn(1, 3, 2, 3), torch.randn(2, 1, 1)).shape
torch.Size([2, 3, 64])
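A band-limited delta centered at \(\hat r_i\) can be built from coefficients proportional to \(Y^l(\hat r_i)\), because \(\sum_l Y^l(\hat r_i) \cdot Y^l(x)\) concentrates around \(x = \hat r_i\) as \(l_\mathrm{max}\) grows. The plain-Python sketch below assumes component-normalized real harmonics \(Y^0 = 1\), \(Y^1(\hat r) = \sqrt{3}\,\hat r\) (so \(\sum_l Y^l(\hat r) \cdot Y^l(\hat r) = (l_\mathrm{max}+1)^2 = 4\)) and divides by that sum so that an isolated delta evaluates to \(v_i\) at its own position. The basis, the scaling and the helper names are assumptions, not e3nn's code.

```python
import math


def sh_l1(r):
    """Assumed basis: [Y^0, Y^1] = [1, sqrt(3) * r_hat] (component normalization)."""
    n = math.sqrt(sum(c * c for c in r))
    return [1.0] + [math.sqrt(3) * c / n for c in r]


def sum_of_diracs_l1(positions, values):
    """Hypothetical helper: coefficients of sum_i v_i delta^L(r_i), truncated at lmax=1."""
    scale = 1.0 / (1 + 1) ** 2  # 1 / (lmax+1)^2 with lmax = 1
    coeffs = [0.0] * 4
    for r, v in zip(positions, values):
        for k, y in enumerate(sh_l1(r)):
            coeffs[k] += scale * v * y
    return coeffs


def evaluate(coeffs, r):
    """f(r_hat) = sum_k coeffs[k] * Y_k(r_hat)."""
    return sum(a * y for a, y in zip(coeffs, sh_l1(r)))


x = sum_of_diracs_l1([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], [-1.0, 1.0])
print([evaluate(x, e) for e in ([1, 0, 0], [0, 1, 0], [0, 0, 1])])
```

With only \(l \le 1\), the two deltas overlap strongly, so the signal reaches about -0.75 and 0.75 at the two positions instead of -1 and 1; the example above uses \(l_\mathrm{max} = 7\), where the deltas are much sharper.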
with_peaks_at(vectors, values=None)

Create a spherical tensor with peaks

The peaks are located in the directions \(\vec r_i / \|\vec r_i\|\) and have amplitude \(\|\vec r_i \|\), or the given values when values is provided.

Parameters
  • vectors (torch.Tensor) – \(\vec r_i\) tensor of shape (N, 3)

  • values (torch.Tensor, optional) – values at the peaks, tensor of shape (N)

Returns

tensor of shape (self.dim,)

Return type

torch.Tensor

Examples

>>> s = SphericalTensor(4, 1, -1)
>>> pos = torch.tensor([
...     [1.0, 0.0, 0.0],
...     [3.0, 4.0, 0.0],
... ])
>>> x = s.with_peaks_at(pos)
>>> s.signal_xyz(x, pos).long()
tensor([1, 5])
>>> val = torch.tensor([
...     -1.5,
...     2.0,
... ])
>>> x = s.with_peaks_at(pos, val)
>>> s.signal_xyz(x, pos)
tensor([-1.5000,  2.0000])
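One way to obtain a signal that takes prescribed values in given directions is to write the coefficients as \(x = \sum_b s_b Y(\hat r_b)\) and solve the linear system \(\sum_b [Y(\hat r_a) \cdot Y(\hat r_b)]\, s_b = v_a\). The plain-Python sketch below does this for \(l_\mathrm{max} = 1\), assuming component-normalized real harmonics and a hand-written Gaussian elimination (e3nn may use a least-squares solver instead). The helper names are hypothetical, and the construction only guarantees the values at \(\hat r_a\), not that they are local maxima.

```python
import math


def sh_l1(r):
    """Assumed basis: [Y^0, Y^1] = [1, sqrt(3) * r_hat] (component normalization)."""
    n = math.sqrt(sum(c * c for c in r))
    return [1.0] + [math.sqrt(3) * c / n for c in r]


def solve(A, b):
    """Gaussian elimination with partial pivoting; assumes A is non-singular."""
    n = len(b)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda i: abs(M[i][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for i in range(col + 1, n):
            factor = M[i][col] / M[col][col]
            for j in range(col, n + 1):
                M[i][j] -= factor * M[col][j]
    s = [0.0] * n
    for i in reversed(range(n)):
        s[i] = (M[i][n] - sum(M[i][j] * s[j] for j in range(i + 1, n))) / M[i][i]
    return s


def with_peaks_at_l1(vectors, values=None):
    """Hypothetical helper: coefficients whose signal equals values[a] at vectors[a] / |vectors[a]|."""
    if values is None:
        values = [math.sqrt(sum(c * c for c in v)) for v in vectors]  # amplitude |r_a|
    Y = [sh_l1(v) for v in vectors]
    A = [[sum(p * q for p, q in zip(ya, yb)) for yb in Y] for ya in Y]  # Gram matrix
    s = solve(A, values)
    return [sum(s[b] * Y[b][k] for b in range(len(Y))) for k in range(4)]


def evaluate(coeffs, r):
    """f(r_hat) = sum_k coeffs[k] * Y_k(r_hat)."""
    return sum(a * y for a, y in zip(coeffs, sh_l1(r)))


pos = [[1.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
x = with_peaks_at_l1(pos)
print([evaluate(x, p) for p in pos])  # close to [1.0, 5.0]
```

With more directions than coefficients (here more than 4), the Gram matrix becomes singular and a least-squares solve is needed instead of solve.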