JIT - wrappers for TorchScript

Functions:

compile(mod[, n_trace_checks, ...])

Recursively compile a module and all submodules according to their decorators.

compile_mode(mode)

Decorator to set the compile mode of a module.

disable_e3nn_codegen()

Context manager that disables the legacy PyTorch code generation used in e3nn.

get_compile_mode(mod)

Get the compilation mode of a module.

get_tracing_inputs(mod[, n, device, dtype])

Get random tracing inputs for mod.

prepare(func[, allow_autograd])

Function transform that prepares an e3nn module for torch.compile.

script(mod[, in_place])

Script a module.

simplify(module)

Recursively searches for registered modules to simplify with torch.fx.symbolic_trace to support compiling with the PyTorch Dynamo compiler.

simplify_if_compile(module)

Decorator to register a module for symbolic simplification.

trace(mod[, example_inputs, check_inputs, ...])

Trace a module.

trace_module(mod[, inputs, check_inputs, ...])

Trace a module.

e3nn.util.jit.compile(mod: Module, n_trace_checks: int = 1, script_options: dict = None, trace_options: dict = None, in_place: bool = True, recurse: bool = True)

Recursively compile a module and all submodules according to their decorators.

(Sub)modules without decorators will be unaffected.

Parameters:
  • mod (torch.nn.Module) – The module to compile. Its submodules are compiled and swapped in place (on a deep copy when in_place is False).

  • n_trace_checks (int, default = 1) – How many random example inputs to generate when tracing a module. Must be at least 1 so that there is a tracing input. Extra example inputs are passed to torch.jit.trace to confirm that the traced compute graph does not change.

  • script_options (dict, default = {}) – Extra kwargs for torch.jit.script.

  • trace_options (dict, default = {}) – Extra kwargs for torch.jit.trace.

  • in_place (bool, default = True) – Whether to insert the recursively compiled submodules in-place, or do a deepcopy first.

  • recurse (bool, default = True) – Whether to recurse through the module’s children before passing the parent to TorchScript.

Return type:

Compiled module.

e3nn.util.jit.compile_mode(mode: str)

Decorator to set the compile mode of a module.

Parameters:

mode (str) – ‘script’, ‘trace’, or None

e3nn.util.jit.disable_e3nn_codegen()

Context manager that disables the legacy PyTorch code generation used in e3nn.

e3nn.util.jit.get_compile_mode(mod: Module) → str

Get the compilation mode of a module.

Parameters:

mod (torch.nn.Module)

Return type:

‘script’, ‘trace’, or None if the module was not decorated with @compile_mode

e3nn.util.jit.get_tracing_inputs(mod: Module, n: int = 1, device: device | None = None, dtype: dtype | None = None)

Get random tracing inputs for mod.

First checks if mod has a _make_tracing_inputs method. If so, calls it with n as the single argument and returns its results.

Otherwise, attempts to infer the input signature of the module using e3nn.util._argtools._get_io_irreps.

Parameters:
  • mod (torch.nn.Module)

  • n (int, default = 1) – A hint for how many inputs are wanted. Usually n inputs are returned, but modules are not required to return exactly n.

  • device (torch.device) – The device to do tracing on. If None (default), will be guessed.

  • dtype (torch.dtype) – The dtype to trace with. If None (default), will be guessed.

Returns:

Tracing inputs in the format of torch.jit.trace_module: dicts mapping method names like 'forward' to tuples of arguments.

Return type:

list of dict

e3nn.util.jit.prepare(func: Callable[[...], Module], allow_autograd: bool = True) → Callable[[...], Module]

Function transform that prepares an e3nn module for torch.compile.

Parameters:
  • func (ModuleFactory) – A function that creates an nn.Module

  • allow_autograd (bool, optional) – Force the Inductor compiler to inline calls to torch.autograd.grad. Defaults to True.

Returns:

A decorated function that creates a torch.compile-compatible module.

Return type:

ModuleFactory

e3nn.util.jit.script(mod: Module, in_place: bool = True)

Script a module.

Like torch.jit.script, but first recursively compiles mod using compile.

Parameters:

mod (torch.nn.Module)

Return type:

Scripted module.

e3nn.util.jit.simplify(module: Module) → Module

Recursively searches for registered modules to simplify with torch.fx.symbolic_trace to support compiling with the PyTorch Dynamo compiler.

Modules are registered with the simplify_if_compile decorator.

Parameters:

module (nn.Module) – the module to simplify

Returns:

the simplified module

Return type:

nn.Module

e3nn.util.jit.simplify_if_compile(module: Module) → Module

Decorator to register a module for symbolic simplification.

The decorated module will be simplified using torch.fx.symbolic_trace, so it must not contain any dynamic control flow; see:

https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing

Parameters:

module (nn.Module) – the module to register

Returns:

registered module

Return type:

nn.Module

e3nn.util.jit.trace(mod: Module, example_inputs: tuple = None, check_inputs: list = None, in_place: bool = True)

Trace a module.

Identical signature to torch.jit.trace, but first recursively compiles mod using compile.

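Parameters:
  • mod (torch.nn.Module)

  • example_inputs (tuple)

  • check_inputs (list)

  • in_place (bool, default = True)

Return type: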

Traced module.

e3nn.util.jit.trace_module(mod: Module, inputs: dict = None, check_inputs: list = None, in_place: bool = True)

Trace a module.

Identical signature to torch.jit.trace_module, but first recursively compiles mod using compile.

Parameters:
  • mod (torch.nn.Module)

  • inputs (dict)

  • check_inputs (list)

  • in_place (bool, default = True)

Return type:

Traced module.