JIT - wrappers for TorchScript
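Functions:
- compile – Recursively compile a module and all submodules according to their decorators.
- compile_mode – Decorator to set the compile mode of a module.
- get_compile_mode – Get the compilation mode of a module.
- get_tracing_inputs – Get random tracing inputs for mod.
- script – Script a module.
- trace – Trace a module (like torch.jit.trace).
- trace_module – Trace a module (like torch.jit.trace_module).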
- e3nn.util.jit.compile(mod: torch.nn.modules.module.Module, n_trace_checks: int = 1, script_options: Optional[dict] = None, trace_options: Optional[dict] = None, in_place: bool = True)
Recursively compile a module and all submodules according to their decorators.
(Sub)modules without decorators will be unaffected.
- Parameters
mod (torch.nn.Module) – The module to compile. Its submodules are replaced in place by their compiled versions.
n_trace_checks (int, default = 1) – How many random example inputs to generate when tracing a module. Must be at least 1 so that there is an input to trace with. Any extra example inputs are passed to torch.jit.trace to confirm that the traced compute graph doesn't change.
script_options (dict, optional) – Extra kwargs for torch.jit.script. None (the default) means no extra kwargs.
trace_options (dict, optional) – Extra kwargs for torch.jit.trace. None (the default) means no extra kwargs.
- Returns
The compiled module.
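The recursion described above can be pictured with a small, framework-free sketch. FakeModule and compile_sketch are hypothetical names, not part of e3nn or PyTorch, and the sketch assumes children are compiled before their parent (a post-order walk), so a parent marked 'script' already holds compiled children when it is processed. Real compilation is replaced by setting a compiled_as tag.

```python
from typing import Dict, List, Optional


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module; not part of e3nn or PyTorch."""

    def __init__(self, name: str, mode: Optional[str] = None,
                 children: Optional[Dict[str, "FakeModule"]] = None):
        self.name = name
        self.mode = mode  # 'script', 'trace', or None (undecorated)
        self.children = dict(children or {})
        self.compiled_as: Optional[str] = None  # set when this module is "compiled"


def compile_sketch(mod: FakeModule, log: List[str]) -> FakeModule:
    """Post-order walk: compile every child first, then this module if decorated."""
    for child_name, child in list(mod.children.items()):
        # Put the (possibly) compiled child back under the same name, in place.
        mod.children[child_name] = compile_sketch(child, log)
    if mod.mode in ("script", "trace"):
        # Stands in for torch.jit.script / torch.jit.trace_module.
        mod.compiled_as = mod.mode
        log.append(f"{mod.mode} {mod.name}")
    return mod  # undecorated modules come back unchanged


log: List[str] = []
inner = FakeModule("inner", mode="trace")
plain = FakeModule("plain")  # no decorator: left alone
root = FakeModule("root", mode="script", children={"inner": inner, "plain": plain})
compile_sketch(root, log)
print(log)  # ['trace inner', 'script root']
```

The undecorated "plain" child passes through untouched, matching the rule that (sub)modules without decorators are unaffected.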
- e3nn.util.jit.compile_mode(mode: str)
Decorator to set the compile mode of a module.
- Parameters
mode (str or None) – The compile mode: ‘script’, ‘trace’, or None.
- e3nn.util.jit.get_compile_mode(mod: torch.nn.modules.module.Module) → str
Get the compilation mode of a module.
- Parameters
mod (torch.nn.Module) – The module to inspect.
- Returns
‘script’, ‘trace’, or None if the module was not decorated with @compile_mode.
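compile_mode and get_compile_mode work as a pair: the decorator records a mode on a class, and the getter reads it back, falling back to None. The following is a minimal stand-alone sketch of that pattern. The attribute name _sketch_compile_mode and the ValueError on an unknown mode are assumptions for illustration, not e3nn's actual internals.

```python
from typing import Any, Optional

_MODE_ATTR = "_sketch_compile_mode"  # hypothetical attribute name, not e3nn's
_VALID_MODES = ("script", "trace", None)


def compile_mode_sketch(mode: Optional[str]):
    """Class decorator that records a compile mode on the decorated class."""
    if mode not in _VALID_MODES:
        raise ValueError(f"invalid compile mode: {mode!r}")

    def decorator(cls: type) -> type:
        setattr(cls, _MODE_ATTR, mode)
        return cls

    return decorator


def get_compile_mode_sketch(obj: Any) -> Optional[str]:
    """Read the recorded mode; None if the object's class was never decorated."""
    return getattr(obj, _MODE_ATTR, None)


@compile_mode_sketch("script")
class ScriptMe:
    pass


@compile_mode_sketch("trace")
class TraceMe:
    pass


class LeaveMe:
    pass


print(get_compile_mode_sketch(ScriptMe()))  # script
print(get_compile_mode_sketch(TraceMe()))   # trace
print(get_compile_mode_sketch(LeaveMe()))   # None
```

Storing the mode on the class rather than on each instance means every instance of a decorated class reports the same mode without any per-object setup.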
- e3nn.util.jit.get_tracing_inputs(mod: torch.nn.modules.module.Module, n: int = 1, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None)
Get random tracing inputs for mod.
First checks whether mod has a _make_tracing_inputs method. If so, calls it with n as the single argument and returns its result. Otherwise, attempts to infer the input signature of the module using e3nn.util._argtools._get_io_irreps.
- Parameters
mod (torch.nn.Module) – The module to generate tracing inputs for.
n (int, default = 1) – A hint for how many inputs are wanted. Usually n inputs are returned, but modules are not required to return exactly n.
device (torch.device) – The device to do tracing on. If None (default), it is guessed.
dtype (torch.dtype) – The dtype to trace with. If None (default), it is guessed.
- Returns
Tracing inputs in the format of torch.jit.trace_module: dicts mapping method names like 'forward' to tuples of arguments.
- Return type
list of dict
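The lookup order and the return format can be sketched without PyTorch. get_tracing_inputs_sketch and infer_from_signature are hypothetical names. The fallback here simply builds floats, whereas the real function infers the input signature with _get_io_irreps and would produce tensors. The custom module deliberately ignores n to show that a module may return a different number of examples.

```python
from typing import Any, Callable, Dict, List, Tuple

# One dict per example; each maps a method name to a tuple of positional arguments.
TracingInputs = List[Dict[str, Tuple[Any, ...]]]


def get_tracing_inputs_sketch(
    mod: Any, n: int, infer: Callable[[Any, int], TracingInputs]
) -> TracingInputs:
    """Prefer the module's own _make_tracing_inputs(n); otherwise fall back to `infer`."""
    maker = getattr(mod, "_make_tracing_inputs", None)
    if maker is not None:
        return maker(n)  # the module decides; it may return more or fewer than n
    return infer(mod, n)


def infer_from_signature(mod: Any, n: int) -> TracingInputs:
    # Hypothetical fallback; e3nn instead infers the input signature with
    # _get_io_irreps and would build random tensors, not floats.
    return [{"forward": (float(i),)} for i in range(n)]


class WithCustomInputs:
    def _make_tracing_inputs(self, n: int) -> TracingInputs:
        # This module always supplies exactly two examples, whatever n is.
        return [{"forward": (1.0, 2.0)}, {"forward": (3.0, 4.0)}]


class WithoutCustomInputs:
    pass


print(get_tracing_inputs_sketch(WithCustomInputs(), 1, infer_from_signature))
# [{'forward': (1.0, 2.0)}, {'forward': (3.0, 4.0)}]
print(get_tracing_inputs_sketch(WithoutCustomInputs(), 3, infer_from_signature))
# [{'forward': (0.0,)}, {'forward': (1.0,)}, {'forward': (2.0,)}]
```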
- e3nn.util.jit.script(mod: torch.nn.modules.module.Module, in_place: bool = True)
Script a module.
Like torch.jit.script, but first recursively compiles mod using compile.
- Parameters
mod (torch.nn.Module) – The module to script.
- Returns
The scripted module.
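The "compile first, then script" order can be shown with injected stand-ins. This is a minimal sketch: script_sketch is a hypothetical name, and the two recording functions only log their calls in place of e3nn.util.jit.compile and torch.jit.script.

```python
from typing import Any, Callable, List


def script_sketch(
    mod: Any,
    compile_fn: Callable[[Any], Any],
    jit_script: Callable[[Any], Any],
) -> Any:
    """Compile decorated submodules first, then script the resulting top-level module."""
    mod = compile_fn(mod)   # per the docs above, this step is e3nn.util.jit.compile
    return jit_script(mod)  # and this one is torch.jit.script


calls: List[str] = []


def recording_compile(mod: Any) -> Any:
    calls.append("compile")
    return mod


def recording_script(mod: Any) -> Any:
    calls.append("script")
    return ("scripted", mod)


result = script_sketch("my_module", recording_compile, recording_script)
print(calls)   # ['compile', 'script']
print(result)  # ('scripted', 'my_module')
```

Compiling first means any submodule marked 'trace' has already been turned into a TorchScript module by the time the top-level module is scripted.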
- e3nn.util.jit.trace(mod: torch.nn.modules.module.Module, example_inputs: Optional[tuple] = None, check_inputs: Optional[list] = None, in_place: bool = True)
Trace a module.
Identical signature to torch.jit.trace, but first recursively compiles mod using compile.
- Parameters
mod (torch.nn.Module) – The module to trace.
example_inputs (tuple) – As in torch.jit.trace.
check_inputs (list of tuple) – As in torch.jit.trace.
- Returns
The traced module.
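trace takes plain argument tuples, while trace_module and get_tracing_inputs use dicts keyed by method name. The sketch below shows how the two formats line up, assuming that tracing a module with trace targets its forward method. to_trace_module_args is a hypothetical helper, not part of e3nn or PyTorch.

```python
from typing import Any, Dict, List, Optional, Tuple

Args = Tuple[Any, ...]


def to_trace_module_args(
    example_inputs: Args,
    check_inputs: Optional[List[Args]] = None,
) -> Tuple[Dict[str, Args], Optional[List[Dict[str, Args]]]]:
    """Re-express trace-style positional inputs as trace_module-style {method: args} dicts."""
    inputs = {"forward": example_inputs}  # tracing a module with trace() targets forward
    checks = None if check_inputs is None else [{"forward": c} for c in check_inputs]
    return inputs, checks


inputs, checks = to_trace_module_args((1.0, 2.0), [(3.0, 4.0), (5.0, 6.0)])
print(inputs)  # {'forward': (1.0, 2.0)}
print(checks)  # [{'forward': (3.0, 4.0)}, {'forward': (5.0, 6.0)}]
```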
- e3nn.util.jit.trace_module(mod: torch.nn.modules.module.Module, inputs: Optional[dict] = None, check_inputs: Optional[list] = None, in_place: bool = True)
Trace a module.
Identical signature to torch.jit.trace_module, but first recursively compiles mod using compile.
- Parameters
mod (torch.nn.Module) – The module to trace.
inputs (dict) – As in torch.jit.trace_module.
check_inputs (list of dict) – As in torch.jit.trace_module.
- Returns
The traced module.
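The list returned by get_tracing_inputs is already in trace_module's format. Following the n_trace_checks description under compile, the sketch below assumes that the first example is used to trace and that all examples are used as check_inputs. split_tracing_inputs is a hypothetical helper, and "energy" is a hypothetical second method name included to show that one dict can cover several methods.

```python
from typing import Any, Dict, List, Tuple

MethodInputs = Dict[str, Tuple[Any, ...]]


def split_tracing_inputs(examples: List[MethodInputs]) -> Tuple[MethodInputs, List[MethodInputs]]:
    """Pick the dict to trace with (`inputs`) and the dicts to re-check against (`check_inputs`)."""
    if not examples:
        raise ValueError("need at least one tracing input")
    for example in examples:
        for method_name, args in example.items():
            if not isinstance(method_name, str) or not isinstance(args, tuple):
                raise TypeError(f"expected {{method_name: tuple_of_args}}, got {method_name!r}: {args!r}")
    # Trace with the first example; check the traced graph against all of them.
    return examples[0], examples


examples = [
    {"forward": (1.0,), "energy": (1.0, 2.0)},  # more than one method can be traced
    {"forward": (2.0,), "energy": (3.0, 4.0)},
]
inputs, check_inputs = split_tracing_inputs(examples)
print(inputs)             # {'forward': (1.0,), 'energy': (1.0, 2.0)}
print(len(check_inputs))  # 2
```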