TorchScript JIT Support

PyTorch provides two ways to compile code into TorchScript: tracing and scripting. Tracing follows the tensor operations on an example input, allowing complex Python control flow if that control flow does not depend on the data itself. Scripting compiles a subset of Python directly into TorchScript, allowing data-dependent control flow but only limited Python features.

This is a problem for e3nn, where many modules, such as e3nn.o3.TensorProduct, use significant Python control flow based on e3nn.o3.Irreps, as well as features like inheritance that are incompatible with scripting. Other modules like e3nn.nn.Gate, however, contain important but simple data-dependent control flow. Thus e3nn.nn.Gate needs to be scripted, even though it contains an e3nn.o3.TensorProduct that has to be traced.

To hide this complexity from the user and prevent difficult-to-understand errors, e3nn implements a wrapper for torch.jit, e3nn.util.jit, that recursively and automatically compiles submodules according to directions they provide. Using the @compile_mode decorator, modules can indicate whether they should be scripted, traced, or left alone.
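The difference between the two compilation routes can be seen with plain PyTorch, independent of e3nn. The following is a minimal sketch; the function name `halve_unless_large` and the threshold are illustrative only. Tracing records whichever branch the example input takes, while scripting keeps the branch itself:

```python
import warnings

import torch


def halve_unless_large(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the values in x.
    if torch.any(x > 7.0):
        return x
    return x * 0.5


small = torch.ones(3)
large = torch.full((3,), 10.0)

with warnings.catch_warnings():
    # Tracing warns that the tensor-to-bool conversion is not recorded.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(halve_unless_large, small)

scripted = torch.jit.script(halve_unless_large)

# The trace froze the branch taken for `small`, so `large` is halved too.
traced_out = traced(large)
# The scripted function re-evaluates the condition on every call.
scripted_out = scripted(large)
```

Here `scripted_out` matches the eager result for `large`, whereas `traced_out` silently follows the branch recorded for `small`. This is the kind of mismatch that the per-module compile modes are meant to avoid.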

Simple Example: Scripting

We define a simple module that includes data-dependent control flow:

import torch
from e3nn.o3 import Norm, Irreps

class MyModule(torch.nn.Module):
    def __init__(self, irreps_in) -> None:
        super().__init__()
        self.norm = Norm(irreps_in)

    def forward(self, x):
        norm = self.norm(x)
        if torch.any(norm > 7.):
            return norm
        else:
            return norm * 0.5

irreps = Irreps("2x0e + 1x1o")
mod = MyModule(irreps)

To compile it to TorchScript, we can try to use torch.jit.script:

try:
    mod_script = torch.jit.script(mod)
except Exception:
    print("Compilation failed!")

This fails because Norm is a subclass of e3nn.o3.TensorProduct and TorchScript doesn’t support inheritance. If we use e3nn.util.jit.script, on the other hand, it works:

from e3nn.util.jit import script, trace
mod_script = script(mod)

Internally, e3nn.util.jit.script recurses through the submodules of mod, compiling each in accordance with its @e3nn.util.jit.compile_mode decorator if it has one. In particular, Norm and other e3nn.o3.TensorProducts are marked with @compile_mode('trace'), so e3nn.util.jit constructs an example input for mod.norm, traces it, and replaces it with the traced TorchScript module. Then when the parent module mod is compiled inside e3nn.util.jit.script with torch.jit.script, the submodule mod.norm has already been compiled and is integrated without issue.

As expected, the scripted module and the original give the same results:

x = irreps.randn(2, -1)
assert torch.allclose(mod(x), mod_script(x))
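The children-first recursion described above can be sketched with plain PyTorch. This is a simplified toy, not e3nn's implementation: the names `toy_compile_mode`, `toy_compile`, `_TOY_MODE`, and the `toy_example_inputs` hook are all hypothetical. Unlike e3nn.util.jit, the toy does not generate example inputs automatically and it modifies the module in place.

```python
import torch

# Hypothetical attribute name used only in this sketch.
_TOY_MODE = "_toy_compile_mode"


def toy_compile_mode(mode: str):
    """Class decorator recording how a module asks to be compiled."""
    def decorator(cls):
        setattr(cls, _TOY_MODE, mode)
        return cls
    return decorator


def toy_compile(mod: torch.nn.Module) -> torch.nn.Module:
    """Compile children first, then the module itself (modifies mod in place)."""
    for name, child in list(mod.named_children()):
        setattr(mod, name, toy_compile(child))
    mode = getattr(mod, _TOY_MODE, None)
    if mode == "script":
        return torch.jit.script(mod)
    if mode == "trace":
        # The toy asks the module itself for example inputs (hypothetical hook).
        return torch.jit.trace(mod, mod.toy_example_inputs())
    return mod  # unmarked modules are left alone


@toy_compile_mode("trace")
class Scale(torch.nn.Module):
    def __init__(self, factors):
        super().__init__()
        self.factors = factors  # configuration, not data

    def forward(self, x):
        # Python loop over configuration; tracing unrolls it.
        for f in self.factors:
            x = x * f
        return x

    def toy_example_inputs(self):
        return (torch.randn(2, 3),)


@toy_compile_mode("script")
class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = Scale([2.0, 3.0])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.scale(x)
        if torch.any(y > 7.0):  # data-dependent branch, kept by scripting
            return y
        return y * 0.5


compiled = toy_compile(Parent())
```

By the time `Parent` reaches `torch.jit.script`, its `scale` child has already been replaced by a traced module, so the scripting pass only has to handle the parent's own `forward`.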

Mixing Tracing and Scripting

Say we define:

from e3nn.util.jit import compile_mode

@compile_mode('script')
class MyModule(torch.nn.Module):
    def __init__(self, irreps_in) -> None:
        super().__init__()
        self.norm = Norm(irreps_in)

    def forward(self, x):
        norm = self.norm(x)
        for row in norm:
            if torch.any(row > 0.1):
                return row
        return norm

class AnotherModule(torch.nn.Module):
    def __init__(self, irreps_in) -> None:
        super().__init__()
        self.mymod = MyModule(irreps_in)

    def forward(self, x):
        return self.mymod(x) + 3.

We then trace an instance of AnotherModule using e3nn.util.jit.trace:

mod2 = AnotherModule(irreps)
example_inputs = (irreps.randn(3, -1),)
mod2_traced = trace(
    mod2,
    example_inputs
)

Note that we marked MyModule with @compile_mode('script') because it contains control flow, and that the control flow is preserved even when called from the traced AnotherModule:

print(mod2_traced(torch.zeros(2, irreps.dim)))
print(mod2_traced(irreps.randn(3, -1)))
tensor([[3., 3., 3.],
        [3., 3., 3.]])
tensor([3.3305, 3.9519, 4.8591])

We can confirm that the submodule mymod was compiled as a script, but that mod2 was traced:

print(type(mod2_traced))
print(type(mod2_traced.mymod))
<class 'torch.jit._trace.TopLevelTracedModule'>
<class 'torch.jit._script.RecursiveScriptModule'>

Customizing Tracing Inputs

Submodules can also be compiled automatically using tracing if they are marked with @compile_mode('trace'). When submodules are compiled by tracing, it must be possible to generate plausible input examples on the fly.

These example inputs can be generated automatically based on the irreps_in of the module (the specifics are the same as for assert_equivariant). If this is not possible or would yield incorrect results, a module can define a _make_tracing_inputs method that generates example inputs of correct shape and type.

@compile_mode('trace')
class TracingModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, indexes: torch.LongTensor):
        return x[indexes].sum()

    # Because this module has no `irreps_in`, and because
    # `irreps_in` can't describe indexes, since it's a LongTensor,
    # we implement _make_tracing_inputs
    def _make_tracing_inputs(self, n: int):
        import random
        # The compiler asks for n example inputs ---
        # this is only a suggestion, the only requirement
        # is that at least one be returned.
        return [
            {
                'forward': (
                    torch.randn(5, random.randint(1, 3)),
                    torch.arange(3)
                )
            }
            for _ in range(n)
        ]

To recursively compile this module and its submodules in accordance with their @compile_mode annotations, we can use e3nn.util.jit.compile directly. This can be useful if the module you are compiling is annotated with @compile_mode and you don’t want to override that annotation by using trace or script:

from e3nn.util.jit import compile
mod3 = TracingModule()
mod3_traced = compile(mod3)
print(type(mod3_traced))
<class 'torch.jit._trace.TopLevelTracedModule'>

Deciding between 'script' and 'trace'

The easiest way to decide on a compile mode for your module is to try both. Tracing will usually generate warnings if it encounters dynamic control flow that it cannot fully capture, and scripting will raise compiler errors for features it does not support.

In general, any module that uses inheritance or control flow based on e3nn.o3.Irreps in forward() will have to be traced.
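One way to act on the "try both" advice is a small helper that attempts each route and summarizes the outcome. The sketch below uses only plain torch.jit calls; `try_both` and the two example modules are hypothetical names introduced for illustration, and for e3nn modules you would more likely call the e3nn.util.jit wrappers instead.

```python
import warnings

import torch


def try_both(mod: torch.nn.Module, example_inputs: tuple) -> dict:
    """Report how scripting and tracing each fare on `mod` (hypothetical helper)."""
    report = {}
    try:
        torch.jit.script(mod)
        report["script"] = "ok"
    except Exception as err:  # TorchScript raises several error types
        report["script"] = f"failed: {type(err).__name__}"
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.jit.trace(mod, example_inputs)
    n_tracer = sum(issubclass(w.category, torch.jit.TracerWarning) for w in caught)
    report["trace"] = "ok" if n_tracer == 0 else f"{n_tracer} tracer warning(s)"
    return report


class Branchy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: scriptable, but tracing cannot capture it.
        if torch.any(x > 0):
            return x
        return -x


class Pythonic(torch.nn.Module):
    def forward(self, x):
        # try/except is outside the TorchScript subset; tracing never sees it.
        try:
            return x.reshape(-1, 3)
        except RuntimeError:
            return x


branchy_report = try_both(Branchy(), (torch.ones(2, 3),))
pythonic_report = try_both(Pythonic(), (torch.ones(2, 3),))
```

A module whose report shows tracer warnings is a candidate for 'script', and one whose scripting attempt fails is a candidate for 'trace'. If both routes have problems, the module may need restructuring or the "unsupported" mode.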

Testing

A helper function is provided to unit test that auto-JITable modules (those annotated with @compile_mode) can be compiled:

from e3nn.util.test import assert_auto_jitable
assert_auto_jitable(mod2)
AnotherModule(
  original_name=AnotherModule
  (mymod): RecursiveScriptModule(
    original_name=MyModule
    (norm): Norm(
      original_name=Norm
      (tp): RecursiveScriptModule(
        original_name=TensorProduct
        (_compiled_main_left_right): RecursiveScriptModule(original_name=GraphModule)
        (_compiled_main_right): RecursiveScriptModule(original_name=tp_forward)
      )
    )
  )
)

By default, assert_auto_jitable will test traced modules to confirm that they reject input shapes that are likely incorrect. Specifically, it changes x.shape[-1] on the assumption that the final dimension is a network architecture constant. If this heuristic is wrong for your module (as it is for TracingModule above), it can be disabled:

assert_auto_jitable(mod3, strict_shapes=False)
TracingModule(original_name=TracingModule)

Compile mode "unsupported"

Sometimes you may write modules that use features unsupported by TorchScript regardless of whether you trace or script. To avoid cryptic errors from TorchScript if someone tries to compile a model containing such a module, the module can be marked with @compile_mode("unsupported"):

@compile_mode('unsupported')
class ChildMod(torch.nn.Module):
    pass

class Supermod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.child = ChildMod()

mod = Supermod()
script(mod)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[13], line 11
      8         self.child = ChildMod()
     10 mod = Supermod()
---> 11 script(mod)

File ~/checkouts/readthedocs.org/user_builds/e3nn/envs/latest/lib/python3.8/site-packages/e3nn/util/jit.py:266, in script(mod, in_place)
    263 setattr(mod, _E3NN_COMPILE_MODE, "script")
    265 # Compile
--> 266 out = compile(mod, in_place=in_place)
    268 # Restore old values, if we had them
    269 if old_mode is not None:

File ~/checkouts/readthedocs.org/user_builds/e3nn/envs/latest/lib/python3.8/site-packages/e3nn/util/jit.py:101, in compile(mod, n_trace_checks, script_options, trace_options, in_place)
     95 # == recurse to children ==
     96 # This allows us to trace compile submodules of modules we are going to script
     97 for submod_name, submod in mod.named_children():
     98     setattr(
     99         mod,
    100         submod_name,
--> 101         compile(
    102             submod,
    103             n_trace_checks=n_trace_checks,
    104             script_options=script_options,
    105             trace_options=trace_options,
    106             in_place=True,  # since we deepcopied the module above, we can do inplace
    107         ),
    108     )
    109 # == Compile this module now ==
    110 if mode == "script":

File ~/checkouts/readthedocs.org/user_builds/e3nn/envs/latest/lib/python3.8/site-packages/e3nn/util/jit.py:89, in compile(mod, n_trace_checks, script_options, trace_options, in_place)
     87 mode = get_compile_mode(mod)
     88 if mode == "unsupported":
---> 89     raise NotImplementedError(f"{type(mod).__name__} does not support TorchScript compilation")
     91 if not in_place:
     92     mod = copy.deepcopy(mod)

NotImplementedError: ChildMod does not support TorchScript compilation