JIT - wrappers for TorchScript
Functions:
- compile – Recursively compile a module and all submodules according to their decorators.
- compile_mode – Decorator to set the compile mode of a module.
- disable_e3nn_codegen – Context manager that disables the legacy PyTorch code generation used in e3nn.
- get_compile_mode – Get the compilation mode of a module.
- get_tracing_inputs – Get random tracing inputs for mod.
- prepare – Function transform that prepares an e3nn module for torch.compile.
- script – Script a module.
- simplify – Recursively searches for registered modules to simplify with torch.fx.symbolic_trace.
- simplify_if_compile – Decorator to register a module for symbolic simplification.
- trace – Trace a module (like torch.jit.trace).
- trace_module – Trace a module (like torch.jit.trace_module).
- e3nn.util.jit.compile(mod: Module, n_trace_checks: int = 1, script_options: dict = None, trace_options: dict = None, in_place: bool = True, recurse: bool = True)
Recursively compile a module and all submodules according to their decorators.
(Sub)modules without decorators are left unchanged.
- Parameters:
mod (torch.nn.Module) – The module to compile. With in_place=True (the default), its submodules are replaced in place by their compiled versions.
n_trace_checks (int, default = 1) – How many random example inputs to generate when tracing a module. Must be at least one so that there is a tracing input. Extra example inputs are passed to torch.jit.trace to confirm that the traced compute graph doesn't change.
script_options (dict, default = {}) – Extra kwargs for torch.jit.script.
trace_options (dict, default = {}) – Extra kwargs for torch.jit.trace.
in_place (bool, default True) – Whether to insert the recursively compiled submodules in place, or to make a deepcopy first.
recurse (bool, default True) – Whether to recurse through the module's children before passing the parent to TorchScript.
- Returns:
The compiled module.
- e3nn.util.jit.compile_mode(mode: str)
Decorator to set the compile mode of a module.
- Parameters:
mode (str) – ‘script’, ‘trace’, or None
- e3nn.util.jit.disable_e3nn_codegen()
Context manager that disables the legacy PyTorch code generation used in e3nn.
- e3nn.util.jit.get_compile_mode(mod: Module) → str
Get the compilation mode of a module.
- Parameters:
mod (torch.nn.Module)
- Return type:
‘script’, ‘trace’, or None if the module was not decorated with @compile_mode
- e3nn.util.jit.get_tracing_inputs(mod: Module, n: int = 1, device: device | None = None, dtype: dtype | None = None)
Get random tracing inputs for mod.
First checks whether mod has a _make_tracing_inputs method. If so, calls it with n as the single argument and returns its results. Otherwise, attempts to infer the input signature of the module using e3nn.util._argtools._get_io_irreps.
- Parameters:
mod (torch.nn.Module)
n (int, default = 1) – A hint for how many inputs are wanted. Usually n inputs are returned, but modules are not required to return exactly n.
device (torch.device) – The device to do tracing on. If None (default), it will be guessed.
dtype (torch.dtype) – The dtype to trace with. If None (default), it will be guessed.
- Returns:
Tracing inputs in the format of torch.jit.trace_module: dicts mapping method names like 'forward' to tuples of arguments.
- Return type:
list of dict
- e3nn.util.jit.prepare(func: Callable[[...], Module], allow_autograd: bool = True) → Callable[[...], Module]
Function transform that prepares an e3nn module for torch.compile.
- Parameters:
func (ModuleFactory) – A function that creates an nn.Module.
allow_autograd (bool, optional) – Force the inductor compiler to inline calls to torch.autograd.grad. Defaults to True.
- Returns:
Decorated function that creates a torch.compile compatible module
- Return type:
ModuleFactory
- e3nn.util.jit.script(mod: Module, in_place: bool = True)
Script a module.
Like torch.jit.script, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
- Return type:
Scripted module.
- e3nn.util.jit.simplify(module: Module) → Module
Recursively searches for registered modules to simplify with torch.fx.symbolic_trace, to support compiling with the PyTorch Dynamo compiler.
Modules are registered with the simplify_if_compile decorator.
- Parameters:
module (nn.Module) – the module to simplify
- Returns:
the simplified module
- Return type:
nn.Module
- e3nn.util.jit.simplify_if_compile(module: Module) → Module
Decorator to register a module for symbolic simplification.
The decorated module will be simplified using torch.fx.symbolic_trace. This constrains the module to have no dynamic control flow; see https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing
- Parameters:
module (nn.Module) – the module to register
- Returns:
registered module
- Return type:
nn.Module
- e3nn.util.jit.trace(mod: Module, example_inputs: tuple = None, check_inputs: list = None, in_place: bool = True)
Trace a module.
Identical signature to torch.jit.trace, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
example_inputs (tuple)
- Return type:
Traced module.
- e3nn.util.jit.trace_module(mod: Module, inputs: dict = None, check_inputs: list = None, in_place: bool = True)
Trace a module.
Identical signature to torch.jit.trace_module, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
inputs (dict)
- Return type:
Traced module.