math

e3nn.math.direct_sum(*matrices)

Direct sum of matrices: places them along the diagonal of a block-diagonal matrix.
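As an illustration of the block-diagonal layout, here is a minimal plain-Python sketch. It is not e3nn's implementation, and `direct_sum_sketch` is a hypothetical name. Matrices are assumed to be lists of rows, and each one is copied into its own diagonal block of an otherwise zero matrix. In PyTorch, `torch.block_diag` builds the same layout from tensors.

```python
def direct_sum_sketch(*matrices):
    """Place each matrix (a list of rows) in its own diagonal block of a zero matrix."""
    n_rows = sum(len(m) for m in matrices)
    n_cols = sum(len(m[0]) if m else 0 for m in matrices)
    out = [[0.0] * n_cols for _ in range(n_rows)]
    r0 = c0 = 0  # top-left corner of the next block
    for m in matrices:
        rows = len(m)
        cols = len(m[0]) if rows else 0
        for i in range(rows):
            for j in range(cols):
                out[r0 + i][c0 + j] = m[i][j]
        r0 += rows
        c0 += cols
    return out


a = [[1.0, 2.0],
     [3.0, 4.0]]
b = [[5.0]]
print(direct_sum_sketch(a, b))
# [[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 5.0]]
```

Blocks need not be square: the row and column offsets advance independently.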

e3nn.math.orthonormalize(original, eps)

Orthonormalize a list of vectors.

Parameters
  • original (torch.Tensor) – list of the original vectors \(x\)

  • eps (float) – a small tolerance used to decide when a vector is numerically zero

Returns

  • final (torch.Tensor) – list of orthonormalized vectors \(y\)

  • matrix (torch.Tensor) – the matrix \(A\) such that \(y = A x\)
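To show how the two return values relate, here is a minimal Gram-Schmidt sketch in plain Python. It tracks, next to each orthonormal vector \(y\), the coefficients that express it as a combination of the inputs \(x\), so that \(y = A x\). This is an assumed illustration rather than e3nn's code: the function name `orthonormalize_sketch`, the list-of-lists representation, and the rule "drop a vector whose norm after projection is at most eps" are all hypothetical.

```python
import math


def orthonormalize_sketch(original, eps=1e-9):
    """Gram-Schmidt sketch returning (final, matrix) with final = matrix @ original.

    `original` is a list of vectors (lists of floats). A vector whose norm
    after projection is at most `eps` is treated as dependent and dropped.
    """
    n = len(original)
    final = []   # orthonormal vectors y
    matrix = []  # rows of A: coefficients of each y in terms of the inputs x
    for i, x in enumerate(original):
        v = list(x)
        coeffs = [0.0] * n
        coeffs[i] = 1.0  # invariant: v == sum_j coeffs[j] * original[j]
        for y, a_row in zip(final, matrix):
            c = sum(vk * yk for vk, yk in zip(v, y))  # component of v along y
            v = [vk - c * yk for vk, yk in zip(v, y)]
            coeffs = [ck - c * ak for ck, ak in zip(coeffs, a_row)]
        norm = math.sqrt(sum(vk * vk for vk in v))
        if norm > eps:
            final.append([vk / norm for vk in v])
            matrix.append([ck / norm for ck in coeffs])
    return final, matrix


xs = [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0], [2.0, 2.0, 0.0]]  # third is parallel to the first
ys, A = orthonormalize_sketch(xs)
print(len(ys))  # 2: the dependent vector is dropped
```

Because the coefficients are updated with the same operations as the vectors, each row of `A` stays consistent with its `y` without ever inverting a matrix.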

e3nn.math.complete_basis()

e3nn.math.soft_one_hot_linspace(x: torch.Tensor, start, end, number, basis=None, cutoff=None)

Projection onto a basis of functions.

Returns a set of \(\{y_i(x)\}_{i=1}^N\),

\[y_i(x) = \frac{1}{Z} f_i(x)\]

where \(x\) is the input and \(f_i\) is the \(i\)-th basis function. \(Z\) is a constant chosen (if possible) such that

\[\langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1\]

The last plot below illustrates this normalization. Note that the bessel basis cannot be normalized this way.
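To make the role of \(Z\) concrete, the sketch below builds unnormalized Gaussian bumps \(f_i\) on \(N\) evenly spaced centres in [start, end], then estimates \(Z\) so that the average of \(\sum_i y_i(x)^2\) equals 1. It assumes the average \(\langle \cdot \rangle_x\) is taken over uniformly spaced samples of [start, end] and that each Gaussian's width equals the centre spacing. These choices, and the helper names `gaussian_basis` and `normalization_constant`, are hypothetical; e3nn's actual centres, widths, and constants may differ.

```python
import math


def gaussian_basis(x, start, end, number):
    """Unnormalized f_i(x): Gaussians centred on `number` (>= 2) evenly spaced points in [start, end]."""
    step = (end - start) / (number - 1)
    centers = [start + i * step for i in range(number)]
    return [math.exp(-((x - c) / step) ** 2) for c in centers]


def normalization_constant(start, end, number, samples=1000):
    """Estimate Z such that the mean over sampled x of sum_i (f_i(x) / Z)^2 is 1."""
    xs = [start + (end - start) * k / (samples - 1) for k in range(samples)]
    mean_sq = sum(sum(f * f for f in gaussian_basis(x, start, end, number)) for x in xs) / samples
    return math.sqrt(mean_sq)


start, end, number = -0.5, 1.5, 4
Z = normalization_constant(start, end, number)
y = [f / Z for f in gaussian_basis(0.3, start, end, number)]  # y_i(0.3) for i = 1..N
```

Dividing by the square root of the mean squared norm is what makes \(\langle \sum_i y_i(x)^2 \rangle_x\) come out to 1 on the sampled points.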

Parameters
  • x (torch.Tensor) – tensor of shape \((...)\)

  • start (float) – minimum value spanned by the basis

  • end (float) – maximum value spanned by the basis

  • number (int) – number of basis functions \(N\)

  • basis ({'gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel'}) – choice of basis family; note that, due to the \(1/x\) term, the bessel basis does not satisfy the normalization of the other basis choices

  • cutoff (bool) – if cutoff=True, then \(f_i(x) \approx 0\) for all \(i\) and for all \(x\) outside the interval (start, end)

Returns

tensor of shape \((..., N)\)

Return type

torch.Tensor

Examples

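import torch
import matplotlib.pyplot as plt
from e3nn.math import soft_one_hot_linspace

bases = ['gaussian', 'cosine', 'smooth_finite', 'fourier', 'bessel']
x = torch.linspace(-1.0, 2.0, 100)
# One row per basis family; left column with cutoff, right column without.
fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)

for axs, b in zip(axss, bases):
    for ax, c in zip(axs, [True, False]):
        plt.sca(ax)
        # Output has shape (100, 4): each of the 4 basis functions is drawn as one curve.
        plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c))
        # Dash-dotted vertical lines mark start=-0.5 and end=1.5.
        plt.plot([-0.5]*2, [-2, 2], 'k-.')
        plt.plot([1.5]*2, [-2, 2], 'k-.')
        plt.title(f"{b}" + (" with cutoff" if c else ""))

plt.ylim(-1, 1.5)
plt.tight_layout()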
(Figure: the four basis functions of each basis family, with and without cutoff.)
# Reuses `bases`, `x`, and the imports above; plots sum_i y_i(x)^2 to check the normalization.
fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)

for axs, b in zip(axss, bases):
    for ax, c in zip(axs, [True, False]):
        plt.sca(ax)
        # Square each basis value and sum over the last (basis) dimension.
        plt.plot(x, soft_one_hot_linspace(x, -0.5, 1.5, number=4, basis=b, cutoff=c).pow(2).sum(-1))
        plt.plot([-0.5]*2, [-2, 2], 'k-.')
        plt.plot([1.5]*2, [-2, 2], 'k-.')
        plt.title(f"{b}" + (" with cutoff" if c else ""))

plt.ylim(0, 2)
plt.tight_layout()
(Figure: \(\sum_i y_i(x)^2\) for each basis family, with and without cutoff.)
e3nn.math.soft_unit_step(x)

Smooth (\(C^\infty\)) version of the unit step function:

\[x \mapsto \theta(x) e^{-1/x}\]
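The formula can be evaluated directly. The scalar plain-Python sketch below (the name `soft_unit_step_sketch` is hypothetical, not the e3nn function) returns 0 for \(x \le 0\) and \(e^{-1/x}\) otherwise. The result is \(C^\infty\) because every derivative of \(e^{-1/x}\) tends to 0 as \(x \to 0^+\), matching the zero function on the left.

```python
import math


def soft_unit_step_sketch(x):
    """theta(x) * exp(-1/x): exactly 0 for x <= 0, rising smoothly toward 1 as x grows."""
    return math.exp(-1.0 / x) if x > 0 else 0.0


print([round(soft_unit_step_sketch(v), 4) for v in (-1.0, 0.0, 0.5, 1.0, 10.0)])
# [0.0, 0.0, 0.1353, 0.3679, 0.9048]
```

Unlike a sigmoid, this step is exactly zero for every non-positive input, which is why it can serve as a hard-support yet smooth switch.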
Parameters

x (torch.Tensor) – tensor of shape \((...)\)

Returns

tensor of shape \((...)\)

Return type

torch.Tensor

Examples

import torch
import matplotlib.pyplot as plt
from e3nn.math import soft_unit_step

x = torch.linspace(-1.0, 10.0, 1000)
plt.plot(x, soft_unit_step(x))
(Figure: soft_unit_step(x) for x from -1 to 10.)