JIT - wrappers for TorchScript
Functions:
compile | Recursively compile a module and all submodules according to their decorators.
compile_mode | Decorator to set the compile mode of a module.
disable_e3nn_codegen | Context manager that disables the legacy PyTorch code generation used in e3nn.
get_compile_mode | Get the compilation mode of a module.
get_tracing_inputs | Get random tracing inputs for mod.
script | Script a module.
trace | Trace a module.
trace_module | Trace a module.
- e3nn.util.jit.compile(mod: Module, n_trace_checks: int = 1, script_options: dict = None, trace_options: dict = None, in_place: bool = True, recurse: bool = True)[source]
Recursively compile a module and all submodules according to their decorators.
(Sub)modules without decorators will be unaffected.
- Parameters:
mod (torch.nn.Module) – The module to compile. Its compiled submodules replace the originals in place (see in_place).
n_trace_checks (int, default = 1) – How many random example inputs to generate when tracing a module. Must be at least one in order to have a tracing input. Extra example inputs will be passed to torch.jit.trace to confirm that the traced compute graph doesn’t change.
script_options (dict, default = {}) – Extra kwargs for torch.jit.script.
trace_options (dict, default = {}) – Extra kwargs for torch.jit.trace.
in_place (bool, default True) – Whether to insert the recursively compiled submodules in-place, or do a deepcopy first.
recurse (bool, default True) – Whether to recurse through the module’s children before passing the parent to TorchScript.
- Return type:
Returns the compiled module.
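The traversal described above (children first, then the parent, with undecorated modules left alone) can be pictured with a small pure-Python model. This is only a sketch under stated assumptions: `Node`, `fake_compile`, and the `mode` field are hypothetical stand-ins, not e3nn or TorchScript objects, and where the sketch merely records a name the real function would hand the module to torch.jit.script or torch.jit.trace.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a module: a name, an optional mode, children."""
    name: str
    mode: Optional[str] = None  # 'script', 'trace', or None (undecorated)
    children: Dict[str, "Node"] = field(default_factory=dict)


def fake_compile(node: Node, log: List[str], recurse: bool = True) -> Node:
    """Visit children first, then 'compile' the parent if it has a mode."""
    if recurse:
        # Copy the items so the mapping can be updated while looping.
        for key, child in list(node.children.items()):
            node.children[key] = fake_compile(child, log, recurse=True)
    if node.mode is not None:
        # Assumption: the real function would call TorchScript at this point.
        log.append(f"{node.mode}:{node.name}")
    return node


tree = Node("net", "script", {
    "a": Node("a", "trace"),
    "b": Node("b"),  # no decorator, so it is left unaffected
})
log: List[str] = []
fake_compile(tree, log)
print(log)  # ['trace:a', 'script:net']
```

The child `a` is handled before its parent `net`, and `b` never appears in the log because it carries no mode.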
- e3nn.util.jit.compile_mode(mode: str)[source]
Decorator to set the compile mode of a module.
- Parameters:
mode (str) – ‘script’, ‘trace’, or None
- e3nn.util.jit.disable_e3nn_codegen()[source]
Context manager that disables the legacy PyTorch code generation used in e3nn.
- e3nn.util.jit.get_compile_mode(mod: Module) → str[source]
Get the compilation mode of a module.
- Parameters:
mod (torch.nn.Module)
- Return type:
‘script’, ‘trace’, or None if the module was not decorated with @compile_mode
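One way to picture how a decorator and a getter like these fit together is a class attribute that the decorator writes and the getter reads. The following is a minimal sketch of that general pattern, not e3nn's implementation: the names `compile_mode_sketch`, `get_compile_mode_sketch`, and `_MODE_ATTR` are hypothetical, and the validation of the mode string is an assumption made for illustration.

```python
from typing import Optional

_MODE_ATTR = "_compile_mode_sketch"  # hypothetical attribute name
_VALID_MODES = ("script", "trace", None)


def compile_mode_sketch(mode: Optional[str]):
    """Return a class decorator that records `mode` on the decorated class."""
    if mode not in _VALID_MODES:
        raise ValueError(f"invalid compile mode: {mode!r}")

    def decorator(cls):
        setattr(cls, _MODE_ATTR, mode)
        return cls

    return decorator


def get_compile_mode_sketch(obj) -> Optional[str]:
    """Read the recorded mode, or None if the class was never decorated."""
    return getattr(obj, _MODE_ATTR, None)


@compile_mode_sketch("script")
class Scripted:
    pass


class Plain:
    pass


print(get_compile_mode_sketch(Scripted()))  # script
print(get_compile_mode_sketch(Plain()))     # None
```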
- e3nn.util.jit.get_tracing_inputs(mod: Module, n: int = 1, device: device | None = None, dtype: dtype | None = None)[source]
Get random tracing inputs for mod.
First checks if mod has a _make_tracing_inputs method. If so, calls it with n as the single argument and returns its results.
Otherwise, attempts to infer the input signature of the module using e3nn.util._argtools._get_io_irreps.
- Parameters:
mod (torch.nn.Module)
n (int, default = 1) – A hint for how many inputs are wanted. Usually n inputs will be returned, but modules are not required to honor the hint.
device (torch.device) – The device to do tracing on. If None (default), will be guessed.
dtype (torch.dtype) – The dtype to trace with. If None (default), will be guessed.
- Returns:
Tracing inputs in the format of torch.jit.trace_module: dicts mapping method names like 'forward' to tuples of arguments.
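The shape of the tracing inputs can be illustrated without PyTorch. The sketch below is hypothetical throughout: `Scale` and `tracing_inputs_sketch` are made-up names, plain Python lists stand in for tensors, the result is assumed to be a list with one dict per example, and the fallback that infers the input signature is not modeled.

```python
import random
from typing import Any, Dict, List, Tuple

# Assumed container: one dict per example, method name -> tuple of arguments.
TracingInputs = List[Dict[str, Tuple[Any, ...]]]


class Scale:
    """Hypothetical module-like class that supplies its own tracing inputs."""

    def forward(self, x: List[float], factor: float) -> List[float]:
        return [factor * v for v in x]

    def _make_tracing_inputs(self, n: int) -> TracingInputs:
        # Each dict maps a method name to the positional arguments for it.
        return [
            {"forward": ([random.random() for _ in range(3)], 2.0)}
            for _ in range(n)
        ]


def tracing_inputs_sketch(mod: Any, n: int = 1) -> TracingInputs:
    """Use the module's own hook if it has one; inference is not modeled."""
    maker = getattr(mod, "_make_tracing_inputs", None)
    if maker is not None:
        return maker(n)
    raise NotImplementedError("signature inference is not part of this sketch")


mod = Scale()
examples = tracing_inputs_sketch(mod, n=2)
for example in examples:
    # Unpack the tuple stored under 'forward' as positional arguments.
    result = mod.forward(*example["forward"])
    print(len(result))  # 3
```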
- e3nn.util.jit.script(mod: Module, in_place: bool = True)[source]
Script a module.
Like torch.jit.script, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
- Return type:
Scripted module.
- e3nn.util.jit.trace(mod: Module, example_inputs: tuple = None, check_inputs: list = None, in_place: bool = True)[source]
Trace a module.
Identical signature to torch.jit.trace, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
example_inputs (tuple)
- Return type:
Traced module.
- e3nn.util.jit.trace_module(mod: Module, inputs: dict = None, check_inputs: list = None, in_place: bool = True)[source]
Trace a module.
Identical signature to torch.jit.trace_module, but first recursively compiles mod using compile.
- Parameters:
mod (torch.nn.Module)
inputs (dict)
- Return type:
Traced module.