Spherical Tensor
There exist four types of functions on the sphere, depending on how parity affects them. This choice determines the representation of the coefficients:
import torch
from e3nn.io import SphericalTensor
print(SphericalTensor(lmax=2, p_val=1, p_arg=1))
print(SphericalTensor(lmax=2, p_val=1, p_arg=-1))
print(SphericalTensor(lmax=2, p_val=-1, p_arg=1))
print(SphericalTensor(lmax=2, p_val=-1, p_arg=-1))
1x0e+1x1e+1x2e
1x0e+1x1o+1x2e
1x0o+1x1o+1x2o
1x0o+1x1e+1x2o
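The pattern in this output follows from the parity rule of SphericalTensor: the degree-l block of coefficients picks up the sign \(p_v p_a^l\), written e for +1 and o for -1. The pure-Python sketch below rebuilds the same strings without e3nn. The helper spherical_tensor_irreps is a hypothetical name used only for illustration.

```python
def spherical_tensor_irreps(lmax: int, p_val: int, p_arg: int) -> str:
    """Hypothetical helper: irreps string of a SphericalTensor-like signal.

    The degree-l block picks up the sign p_val * p_arg**l under parity,
    written 'e' (even, +1) or 'o' (odd, -1).
    """
    if p_val not in (1, -1) or p_arg not in (1, -1):
        raise ValueError("p_val and p_arg must be +1 or -1")
    parts = []
    for l in range(lmax + 1):
        parity = p_val * p_arg ** l
        parts.append(f"1x{l}{'e' if parity == 1 else 'o'}")
    return "+".join(parts)


# Same loop order as the prints above: (1, 1), (1, -1), (-1, 1), (-1, -1)
for p_val in (1, -1):
    for p_arg in (1, -1):
        print(spherical_tensor_irreps(2, p_val, p_arg))
```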
import plotly.graph_objects as go

def plot(traces):
    traces = [go.Surface(**d) for d in traces]
    fig = go.Figure(data=traces)
    fig.show()
The following graph shows the four possible behaviors of a function on the sphere under parity.
- The first ball shows \(f(x)\) unaffected by parity (p_val=1, p_arg=1).
- With p_val=1 but p_arg=-1, the signal is flipped over the sphere, but the colors are unchanged.
- With p_val=-1 and p_arg=1, only the value of the signal flips its sign.
- With p_val=-1 and p_arg=-1, both happen at the same time: the signal flips over the sphere and its value flips sign.
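lmax = 1
# coefficients: 0.8 for l=0, then the l=1 block (0, 0, 1)
x = torch.tensor([0.8] + [0.0, 0.0, 1.0])
# parity (inversion) acts on 3D space as x -> -x
parity = -torch.eye(3)
# apply the parity representation for each (p_val, p_arg) combination
x = torch.stack([
    SphericalTensor(lmax, p_val, p_arg).D_from_matrix(parity) @ x
    for p_val in [+1, -1]
    for p_arg in [+1, -1]
])
# one center per ball, spread along the x axis
centers = torch.tensor([
    [-3.0, 0.0, 0.0],
    [-1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [3.0, 0.0, 0.0],
])
st = SphericalTensor(lmax, 1, 1)  # p_val and p_arg set arbitrarily here
plot(st.plotly_surface(x, centers=centers, radius=False))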
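For the parity matrix -I used above, the parity formula of SphericalTensor says that block l of the coefficients is simply multiplied by \(p_v p_a^l\), so the four rows of x can be predicted by hand. The pure-Python sketch below does that bookkeeping. apply_parity is a hypothetical helper, and it assumes the coefficients are stored as consecutive blocks of size 2l+1 in increasing l, which matches the dimension (lmax+1)^2.

```python
def apply_parity(coeffs, lmax, p_val, p_arg):
    """Hypothetical helper: apply inversion to spherical-tensor coefficients.

    Assumes coeffs is laid out as blocks of size 2l+1 for l = 0..lmax;
    block l is multiplied by p_val * p_arg**l.
    """
    if len(coeffs) != (lmax + 1) ** 2:
        raise ValueError("expected (lmax + 1)**2 coefficients")
    out, start = [], 0
    for l in range(lmax + 1):
        sign = p_val * p_arg ** l
        block = coeffs[start:start + 2 * l + 1]
        out.extend(sign * c for c in block)
        start += 2 * l + 1
    return out


x = [0.8, 0.0, 0.0, 1.0]  # same coefficients as above, lmax = 1
for p_val in (1, -1):
    for p_arg in (1, -1):
        print(p_val, p_arg, apply_parity(x, 1, p_val, p_arg))
```

The first row is unchanged, the second flips only the l=1 block, the third flips everything, and the fourth flips only the l=0 block. These are the four balls in the plot.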
- class e3nn.io.SphericalTensor(lmax, p_val, p_arg)
Bases: e3nn.o3._irreps.Irreps
Representation of a signal on the sphere.
A SphericalTensor contains the coefficients \(A^l\) of a function \(f\) defined on the sphere:
\[f(x) = \sum_{l=0}^{l_\mathrm{max}} A^l \cdot Y^l(x)\]
The way this function is transformed by parity, \(f \longrightarrow P f\), is described by the two parameters \(p_v\) and \(p_a\):
\[\begin{aligned}(P f)(x) &= p_v f(p_a x)\\ &= \sum_{l=0}^{l_\mathrm{max}} p_v p_a^l A^l \cdot Y^l(x)\end{aligned}\]
where the second line uses \(Y^l(-x) = (-1)^l Y^l(x)\).
- Parameters
lmax (int) – \(l_\mathrm{max}\)
p_val ({+1, -1}) – \(p_v\)
p_arg ({+1, -1}) – \(p_a\)
Examples
>>> SphericalTensor(3, 1, 1)
1x0e+1x1e+1x2e+1x3e
>>> SphericalTensor(3, 1, -1)
1x0e+1x1o+1x2e+1x3o
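The second equality in the parity formula rests on \(Y^l(-x) = (-1)^l Y^l(x)\). The pure-Python check below compares \(p_v f(p_a x)\) with \(\sum_l p_v p_a^l A^l \cdot Y^l(x)\) at a random point on the sphere for all four choices of \((p_v, p_a)\). As an assumption, it uses unnormalized real harmonic polynomials up to l = 2 as the basis. This is not e3nn's normalization or ordering, but any degree-l harmonic polynomial has the same parity.

```python
import math
import random


# Unnormalized real harmonic polynomials of degree l = 0, 1, 2.
# Assumption: any basis of degree-l harmonic polynomials works for this check,
# since each one satisfies Y(-x) = (-1)**l Y(x); this is NOT e3nn's basis.
def basis(x, y, z):
    return [
        [1.0],                                   # l = 0
        [x, y, z],                               # l = 1
        [x * y, y * z, x * z, x * x - y * y,
         2 * z * z - x * x - y * y],             # l = 2
    ]


def f(coeffs, point):
    """f(x) = sum_l A^l . Y^l(x) for coefficient blocks coeffs[l]."""
    return sum(
        sum(a * y for a, y in zip(block, ys))
        for block, ys in zip(coeffs, basis(*point))
    )


def parity_lhs(coeffs, point, p_v, p_a):
    """(P f)(x) = p_v f(p_a x)."""
    return p_v * f(coeffs, [p_a * c for c in point])


def parity_rhs(coeffs, point, p_v, p_a):
    """sum_l p_v p_a^l A^l . Y^l(x)."""
    scaled = [[p_v * p_a ** l * a for a in block] for l, block in enumerate(coeffs)]
    return f(scaled, point)


random.seed(0)
A = [[random.gauss(0, 1) for _ in range(2 * l + 1)] for l in range(3)]
v = [random.gauss(0, 1) for _ in range(3)]
norm = math.sqrt(sum(c * c for c in v))
point = [c / norm for c in v]  # a point on the unit sphere

for p_v in (1, -1):
    for p_a in (1, -1):
        lhs = parity_lhs(A, point, p_v, p_a)
        rhs = parity_rhs(A, point, p_v, p_a)
        print(p_v, p_a, math.isclose(lhs, rhs))
```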
Methods:
find_peaks(signal[, res]) – Locate peaks on the sphere
from_samples_on_s2(positions, values[, res]) – Convert a set of positions on the sphere and values into a spherical tensor
norms(signal) – The norms of each l component
plot(signal[, center, res, radius, relu, ...]) – Create a surface in order to make a plot
plotly_surface(signals[, centers, res, ...]) – Create traces for plotly
signal_on_grid(signal[, res, normalization]) – Evaluate the signal on a grid on the sphere
signal_xyz(signal, r) – Evaluate the signal on given points on the sphere
sum_of_diracs(positions, values) – Sum of (almost) Dirac deltas
with_peaks_at(vectors[, values]) – Create a spherical tensor with peaks
- find_peaks(signal, res=100)
Locate peaks on the sphere
Examples
>>> s = SphericalTensor(4, 1, -1)
>>> pos = torch.tensor([
...     [4.0, 0.0, 4.0],
...     [0.0, 5.0, 0.0],
... ])
>>> x = s.with_peaks_at(pos)
>>> pos, val = s.find_peaks(x)
>>> pos[val > 4.0].mul(10).round().abs()
tensor([[ 7., 0., 7.],
        [ 0., 10., 0.]])
>>> val[val > 4.0].mul(10).round().abs()
tensor([57., 50.])
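Peak finding on the sphere amounts to sampling the signal and keeping local maxima. The sketch below uses a swapped-in brute-force technique, not e3nn's implementation. It samples the sphere with a Fibonacci lattice and calls a sample a peak when it is larger than every other sample within a fixed angular radius. The test signal, the helper names (fibonacci_sphere, find_peaks_bruteforce) and the radius are all assumptions chosen for illustration. A value threshold keeps only the tall peaks, similar to the val > 4.0 filter above.

```python
import math


def fibonacci_sphere(n):
    """n roughly uniform unit vectors (Fibonacci lattice)."""
    golden = math.pi * (3.0 - math.sqrt(5.0))
    points = []
    for i in range(n):
        z = 1.0 - (2.0 * i + 1.0) / n
        r = math.sqrt(1.0 - z * z)
        phi = golden * i
        points.append((r * math.cos(phi), r * math.sin(phi), z))
    return points


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def find_peaks_bruteforce(func, n=1000, radius=0.25):
    """Return (point, value) pairs that beat every sample within `radius` rad."""
    points = fibonacci_sphere(n)
    values = [func(p) for p in points]
    cos_r = math.cos(radius)
    peaks = []
    for i, p in enumerate(points):
        if all(
            values[i] > values[j]
            for j, q in enumerate(points)
            if j != i and dot(p, q) > cos_r
        ):
            peaks.append((p, values[i]))
    return peaks


# Test signal: two sharp bumps of height about e**10, toward (1, 0, 1)/sqrt(2) and (0, 1, 0).
a = (1 / math.sqrt(2), 0.0, 1 / math.sqrt(2))
b = (0.0, 1.0, 0.0)


def signal(p):
    return math.exp(10 * dot(p, a)) + math.exp(10 * dot(p, b))


# Keep only the tall peaks.
peaks = [(p, v) for p, v in find_peaks_bruteforce(signal) if v > 1e4]
for p, v in peaks:
    print([round(c, 2) for c in p], round(v))
```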
- from_samples_on_s2(positions: torch.Tensor, values: torch.Tensor, res=100) → torch.Tensor
Convert a set of positions on the sphere and values into a spherical tensor
- Parameters
positions (torch.Tensor) – tensor of shape (..., N, 3)
values (torch.Tensor) – tensor of shape (..., N)
- Returns
tensor of shape (..., self.dim)
- Return type
torch.Tensor
Examples
>>> s = SphericalTensor(2, 1, 1)
>>> pos = torch.tensor([
...     [
...         [0.0, 0.0, 1.0],
...         [0.0, 0.0, -1.0],
...     ],
...     [
...         [0.0, 1.0, 0.0],
...         [0.0, -1.0, 0.0],
...     ],
... ], dtype=torch.float64)
>>> val = torch.tensor([
...     [
...         1.0,
...         -1.0,
...     ],
...     [
...         1.0,
...         -1.0,
...     ],
... ], dtype=torch.float64)
>>> s.from_samples_on_s2(pos, val, res=200).long()
tensor([[0, 0, 0, 3, 0, 0, 0, 0, 0],
        [0, 0, 3, 0, 0, 0, 0, 0, 0]])
>>> pos = torch.empty(2, 0, 10, 3)
>>> val = torch.empty(2, 0, 10)
>>> s.from_samples_on_s2(pos, val)
tensor([], size=(2, 0, 9))
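One way to turn scattered samples into coefficients is assumed here rather than taken from e3nn's source. First, assign each point of a latitude/longitude grid the value of its nearest sample. Then project the gridded signal onto an orthonormal real basis by numerical integration. The sketch stops at lmax = 1 and assumes the basis \(Y_{00} = 1/\sqrt{4\pi}\), \(Y_1 = \sqrt{3/(4\pi)}\,(x, y, z)\). The name from_samples_l1 is hypothetical.

```python
import math


def from_samples_l1(positions, values, res=100):
    """Hypothetical sketch: samples on S2 -> (c_00, c_1x, c_1y, c_1z).

    1. Each point of a res x res latitude/longitude grid takes the value of
       its nearest sample (largest dot product).
    2. The resulting grid signal is projected onto an orthonormal real basis
       for l <= 1: Y_00 = 1/sqrt(4 pi), Y_1 = sqrt(3 / (4 pi)) * (x, y, z).
    The (x, y, z) ordering and the normalization are assumptions.
    """
    units = []
    for p in positions:
        n = math.sqrt(sum(c * c for c in p))
        units.append([c / n for c in p])

    c = [0.0, 0.0, 0.0, 0.0]
    d_theta, d_phi = math.pi / res, 2 * math.pi / res
    for j in range(res):
        theta = (j + 0.5) * d_theta  # midpoint rule in theta
        for k in range(res):
            phi = k * d_phi
            x = (math.sin(theta) * math.cos(phi),
                 math.sin(theta) * math.sin(phi),
                 math.cos(theta))
            nearest = max(range(len(units)),
                          key=lambda i: sum(u * v for u, v in zip(x, units[i])))
            f = values[nearest]
            w = math.sin(theta) * d_theta * d_phi  # area element on the sphere
            ys = [1 / math.sqrt(4 * math.pi)] + [math.sqrt(3 / (4 * math.pi)) * xi for xi in x]
            for m in range(4):
                c[m] += w * f * ys[m]
    return c


coeffs = from_samples_l1([[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]], [1.0, -1.0])
print([round(v, 3) for v in coeffs])
```

For the ±z samples used here, the gridded field is sign(z). Only the z coefficient survives, with value close to \(\sqrt{3\pi} \approx 3.07\), since \(\int \mathrm{sign}(z)\, z \, d\Omega = 2\pi\).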
- norms(signal)
The norms of each l component
- Parameters
signal (torch.Tensor) – tensor of shape (..., dim)
- Returns
tensor of shape (..., lmax+1)
- Return type
torch.Tensor
Examples
>>> s = SphericalTensor(1, 1, -1)
>>> s.norms(torch.tensor([1.5, 0.0, 3.0, 4.0]))
tensor([1.5000, 5.0000])
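norms reduces each degree-l block of 2l+1 coefficients to its Euclidean length. Rotations act on each block through an orthogonal matrix, so these norms do not change under rotation. The pure-Python sketch below reproduces the example. norms_per_l is a hypothetical name, and it assumes the block layout l = 0, 1, ..., lmax.

```python
import math


def norms_per_l(coeffs, lmax):
    """Euclidean norm of each degree-l block (size 2l+1) of a coefficient vector."""
    if len(coeffs) != (lmax + 1) ** 2:
        raise ValueError("expected (lmax + 1)**2 coefficients")
    out, start = [], 0
    for l in range(lmax + 1):
        block = coeffs[start:start + 2 * l + 1]
        out.append(math.sqrt(sum(c * c for c in block)))
        start += 2 * l + 1
    return out


print(norms_per_l([1.5, 0.0, 3.0, 4.0], lmax=1))  # [1.5, 5.0]
```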
- plot(signal, center=None, res=100, radius=True, relu=False, normalization='integral')
Create a surface in order to make a plot
- plotly_surface(signals, centers=None, res=100, radius=True, relu=False, normalization='integral')
Create traces for plotly
Examples
>>> import plotly.graph_objects as go
>>> x = SphericalTensor(4, +1, +1)
>>> traces = x.plotly_surface(x.randn(-1))
>>> traces = [go.Surface(**d) for d in traces]
>>> fig = go.Figure(data=traces)
- signal_on_grid(signal, res=100, normalization='integral')
Evaluate the signal on a grid on the sphere
- signal_xyz(signal, r)
Evaluate the signal on given points on the sphere
\[f(\vec x / \|\vec x\|)\]
- Parameters
signal (torch.Tensor) – tensor of shape (*A, self.dim)
r (torch.Tensor) – tensor of shape (*B, 3)
- Returns
tensor of shape (*A, *B)
- Return type
torch.Tensor
Examples
>>> s = SphericalTensor(3, 1, -1)
>>> s.signal_xyz(s.randn(2, 1, 3, -1), torch.randn(2, 4, 3)).shape
torch.Size([2, 1, 3, 2, 4])
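signal_xyz evaluates \(f(\vec x / \|\vec x\|)\), so only the direction of each point matters. Every signal is evaluated at every point, which is where the (*A, *B) output shape comes from. The pure-Python sketch below does this for lmax = 1. signal_xyz_l1 is a hypothetical helper, and it assumes the orthonormal basis \(Y_{00} = 1/\sqrt{4\pi}\), \(Y_1 = \sqrt{3/(4\pi)}\,(x, y, z)\) in that order.

```python
import math


def signal_xyz_l1(signals, points):
    """Hypothetical sketch of f(x / |x|) for lmax = 1 signals.

    signals: list of coefficient lists [c_00, c_1x, c_1y, c_1z] (the "A" axis)
    points:  list of nonzero 3-vectors (the "B" axis)
    returns: nested list indexed [a][b], mirroring the (*A, *B) output shape.
    Basis assumption: Y_00 = 1/sqrt(4 pi), Y_1 = sqrt(3 / (4 pi)) * (x, y, z).
    """
    out = []
    for c in signals:
        row = []
        for p in points:
            n = math.sqrt(sum(v * v for v in p))
            u = [v / n for v in p]  # project the point onto the unit sphere
            y = [1 / math.sqrt(4 * math.pi)] + [math.sqrt(3 / (4 * math.pi)) * v for v in u]
            row.append(sum(a * b for a, b in zip(c, y)))
        out.append(row)
    return out


sig = [[1.0, 0.0, 0.0, 2.0], [0.0, 1.0, 0.0, 0.0]]
pts = [[0.0, 0.0, 1.0], [0.0, 0.0, 5.0], [3.0, 4.0, 0.0]]
vals = signal_xyz_l1(sig, pts)
print(len(vals), len(vals[0]))  # 2 3
```

The first two points differ only in length, so they give the same value for every signal.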
- sum_of_diracs(positions: torch.Tensor, values: torch.Tensor) → torch.Tensor
Sum of (almost) Dirac deltas
\[f(x) = \sum_i v_i \delta^L(\vec r_i)\]
where \(\delta^L\) is an approximation of a Dirac delta.
- Parameters
positions (torch.Tensor) – \(\vec r_i\) tensor of shape (..., N, 3)
values (torch.Tensor) – \(v_i\) tensor of shape (..., N)
- Returns
tensor of shape (..., self.dim)
- Return type
torch.Tensor
Examples
>>> s = SphericalTensor(7, 1, -1)
>>> pos = torch.tensor([
...     [1.0, 0.0, 0.0],
...     [0.0, 1.0, 0.0],
... ])
>>> val = torch.tensor([
...     -1.0,
...     1.0,
... ])
>>> x = s.sum_of_diracs(pos, val)
>>> s.signal_xyz(x, torch.eye(3)).mul(10.0).round()
tensor([-10., 10., -0.])
>>> s.sum_of_diracs(torch.empty(1, 0, 2, 3), torch.empty(2, 0, 1)).shape
torch.Size([2, 0, 64])
>>> s.sum_of_diracs(torch.randn(1, 3, 2, 3), torch.randn(2, 1, 1)).shape
torch.Size([2, 3, 64])
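One common construction of \(\delta^L\) is assumed here rather than taken from e3nn's source: truncate the expansion of the delta function in an orthonormal spherical-harmonic basis at degree L. By the addition theorem, \(\sum_m Y^l_m(\hat r) Y^l_m(\hat x) = \frac{2l+1}{4\pi} P_l(\hat r \cdot \hat x)\), so the truncated delta depends only on the angle to \(\vec r_i\). The sketch also rescales it by \(4\pi/(L+1)^2\), again an assumption, so that an isolated peak has height \(v_i\). All helper names are hypothetical.

```python
import math


def legendre_all(lmax, t):
    """[P_0(t), ..., P_lmax(t)] via Bonnet's recurrence."""
    p = [1.0, t]
    for n in range(1, lmax):
        p.append(((2 * n + 1) * t * p[n] - n * p[n - 1]) / (n + 1))
    return p[: lmax + 1]


def approx_delta(lmax, cos_angle):
    """Truncated delta, rescaled by 4 pi / (lmax + 1)**2 so its peak is 1.

    sum_l (2l + 1) / (4 pi) * P_l(cos_angle) * 4 pi / (lmax + 1)**2
    """
    ps = legendre_all(lmax, cos_angle)
    return sum((2 * l + 1) * ps[l] for l in range(lmax + 1)) / (lmax + 1) ** 2


def sum_of_diracs_at(lmax, positions, values, x):
    """Evaluate f(x) = sum_i v_i delta^L(r_i) at a point x (all vectors nonzero)."""
    def unit(v):
        n = math.sqrt(sum(c * c for c in v))
        return [c / n for c in v]

    ux = unit(x)
    total = 0.0
    for r, v in zip(positions, values):
        cos_angle = sum(a * b for a, b in zip(unit(r), ux))
        total += v * approx_delta(lmax, cos_angle)
    return total


pos = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
val = [-1.0, 1.0]
for axis in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
    print(round(10 * sum_of_diracs_at(7, pos, val, axis)))
```

With this normalization the sketch gives the same rounded values as the example above (-10, 10, 0). That agreement is not a claim about how e3nn normalizes \(\delta^L\).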
- with_peaks_at(vectors, values=None)
Create a spherical tensor with peaks
The peaks are located at \(\vec r_i\) and have amplitude \(\|\vec r_i \|\)
- Parameters
vectors (torch.Tensor) – \(\vec r_i\) tensor of shape (N, 3)
values (torch.Tensor, optional) – value on the peak, tensor of shape (N)
- Returns
tensor of shape (self.dim,)
- Return type
torch.Tensor
Examples
>>> s = SphericalTensor(4, 1, -1)
>>> pos = torch.tensor([
...     [1.0, 0.0, 0.0],
...     [3.0, 4.0, 0.0],
... ])
>>> x = s.with_peaks_at(pos)
>>> s.signal_xyz(x, pos).long()
tensor([1, 5])
>>> val = torch.tensor([
...     -1.5,
...     2.0,
... ])
>>> x = s.with_peaks_at(pos, val)
>>> s.signal_xyz(x, pos)
tensor([-1.5000, 2.0000])
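The examples show the signal taking the requested value at each peak. One way to get that property is assumed here rather than taken from e3nn's source: kernel interpolation. Build f as a weighted sum of truncated deltas centered on the peak directions, then solve the small linear system \(\sum_j w_j k(\hat r_i, \hat r_j) = v_i\) for the weights. The kernel is the rescaled truncated delta \(k(\hat u, \hat v) = \sum_{l \le L} (2l+1) P_l(\hat u \cdot \hat v) / (L+1)^2\). Following the docstring, values default to \(\|\vec r_i\|\). Every helper name is hypothetical.

```python
import math


def legendre_all(lmax, t):
    """[P_0(t), ..., P_lmax(t)] via Bonnet's recurrence."""
    p = [1.0, t]
    for n in range(1, lmax):
        p.append(((2 * n + 1) * t * p[n] - n * p[n - 1]) / (n + 1))
    return p[: lmax + 1]


def kernel(lmax, u, v):
    """Truncated delta between unit vectors u and v (equals 1 when u == v)."""
    t = sum(a * b for a, b in zip(u, v))
    ps = legendre_all(lmax, t)
    return sum((2 * l + 1) * ps[l] for l in range(lmax + 1)) / (lmax + 1) ** 2


def solve(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def with_peaks_at_sketch(lmax, vectors, values=None):
    """Return f(x) = sum_j w_j kernel(x, r_j), with weights chosen so f(r_i) = v_i."""
    norms = [math.sqrt(sum(c * c for c in v)) for v in vectors]
    units = [[c / n for c in v] for v, n in zip(vectors, norms)]
    targets = norms if values is None else list(values)  # default: |r_i|
    gram = [[kernel(lmax, ui, uj) for uj in units] for ui in units]
    weights = solve(gram, targets)

    def f(x):
        n = math.sqrt(sum(c * c for c in x))
        ux = [c / n for c in x]
        return sum(w * kernel(lmax, ux, u) for w, u in zip(weights, units))

    return f


pos = [[1.0, 0.0, 0.0], [3.0, 4.0, 0.0]]
f = with_peaks_at_sketch(4, pos)
print([round(f(p), 6) for p in pos])
g = with_peaks_at_sketch(4, pos, [-1.5, 2.0])
print([round(g(p), 6) for p in pos])
```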