JIT - wrappers for TorchScript

Functions:

compile(mod[, n_trace_checks, ...])

Recursively compile a module and all submodules according to their decorators.

compile_mode(mode)

Decorator to set the compile mode of a module.

get_compile_mode(mod)

Get the compilation mode of a module.

get_tracing_inputs(mod[, n, device, dtype])

Get random tracing inputs for mod.

script(mod[, in_place])

Script a module.

trace(mod[, example_inputs, check_inputs, ...])

Trace a module.

trace_module(mod[, inputs, check_inputs, ...])

Trace a module.

e3nn.util.jit.compile(mod: Module, n_trace_checks: int = 1, script_options: Optional[dict] = None, trace_options: Optional[dict] = None, in_place: bool = True)

Recursively compile a module and all submodules according to their decorators.

(Sub)modules without decorators will be unaffected.

Parameters
  • mod (torch.nn.Module) – The module to compile. Its submodules are compiled and replaced in place.

  • n_trace_checks (int, default = 1) – How many random example inputs to generate when tracing a module. Must be at least 1 so that there is an input to trace with. The extra example inputs are passed to torch.jit.trace to confirm that the traced compute graph doesn’t change.

  • script_options (dict, default = {}) – Extra kwargs for torch.jit.script.

  • trace_options (dict, default = {}) – Extra kwargs for torch.jit.trace.

Returns

The compiled module.
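To make the recursion and the role of n_trace_checks concrete, here is a toy model in plain Python. It is not e3nn's implementation: ToyModule, Compiled, toy_compile and make_inputs are hypothetical names, a `mode` class attribute stands in for the @compile_mode decorator, and "compiling" only wraps a module in a Compiled record instead of calling torch.jit.script or torch.jit.trace. Compiling children before their parent is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Optional


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module that only tracks children."""

    mode: Optional[str] = None  # set on subclasses; None means "undecorated"

    def __init__(self, **children: "ToyModule") -> None:
        self.children: Dict[str, ToyModule] = dict(children)


class Compiled(ToyModule):
    """Hypothetical record of how a module was 'compiled'."""

    def __init__(self, how: str, n_checks: int = 0) -> None:
        super().__init__()
        self.how = how
        self.n_checks = n_checks  # extra inputs used only to check the trace


def toy_compile(
    mod: ToyModule,
    n_trace_checks: int = 1,
    make_inputs: Callable[[ToyModule, int], List[tuple]] = lambda m, n: [()] * n,
) -> ToyModule:
    if n_trace_checks < 1:
        raise ValueError("n_trace_checks must be at least 1")
    # Compile children first and replace them on the parent in place.
    for name, child in list(mod.children.items()):
        mod.children[name] = toy_compile(child, n_trace_checks, make_inputs)
    if mod.mode == "script":
        return Compiled("script")
    if mod.mode == "trace":
        inputs = make_inputs(mod, n_trace_checks)
        # The first input is traced; the remaining ones serve as checks.
        return Compiled("trace", n_checks=len(inputs) - 1)
    return mod  # undecorated modules are returned unchanged


class Leaf(ToyModule):
    mode = "trace"


class Root(ToyModule):  # undecorated
    pass


root = Root(leaf=Leaf())
out = toy_compile(root, n_trace_checks=3)
print(out is root, out.children["leaf"].how, out.children["leaf"].n_checks)
# True trace 2
```

The undecorated Root comes back as the same object, while its decorated child is replaced, which is the behavior described by "(Sub)modules without decorators will be unaffected."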

e3nn.util.jit.compile_mode(mode: str)

Decorator to set the compile mode of a module.

Parameters

mode (str) – ‘script’, ‘trace’, or None

e3nn.util.jit.get_compile_mode(mod: Module) → str

Get the compilation mode of a module.

Parameters

mod (torch.nn.Module) – The module to inspect.

Returns

‘script’, ‘trace’, or None if the module was not decorated with @compile_mode.
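The decorator and its getter work as a pair: one records a mode on a class, the other reads it back. The following is a minimal plain-Python sketch of that idea, not e3nn's code. The attribute name _toy_compile_mode and the functions toy_compile_mode and toy_get_compile_mode are hypothetical, and e3nn may store the mode differently.

```python
from typing import Optional

_VALID_MODES = ("script", "trace", None)
_MODE_ATTR = "_toy_compile_mode"  # hypothetical attribute name


def toy_compile_mode(mode: Optional[str]):
    """Class decorator that records a compile mode on the decorated class."""
    if mode not in _VALID_MODES:
        raise ValueError(f"invalid compile mode: {mode!r}")

    def decorator(cls):
        setattr(cls, _MODE_ATTR, mode)
        return cls

    return decorator


def toy_get_compile_mode(mod) -> Optional[str]:
    """Return the recorded mode, or None if the class was never decorated."""
    return getattr(mod, _MODE_ATTR, None)


@toy_compile_mode("script")
class Scripted:
    pass


class Plain:
    pass


print(toy_get_compile_mode(Scripted()))  # script
print(toy_get_compile_mode(Plain()))  # None
```

Because the mode is stored on the class in this sketch, subclasses of a decorated class inherit its mode unless they are decorated themselves.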

e3nn.util.jit.get_tracing_inputs(mod: Module, n: int = 1, device: Optional[device] = None, dtype: Optional[dtype] = None)

Get random tracing inputs for mod.

First checks if mod has a _make_tracing_inputs method. If so, calls it with n as the single argument and returns its results.

Otherwise, attempts to infer the input signature of the module using e3nn.util._argtools._get_io_irreps.

Parameters
  • mod (torch.nn.Module) – The module to generate tracing inputs for.

  • n (int, default = 1) – A hint for how many inputs are wanted. Modules usually return n inputs, but they are not required to.

  • device (torch.device) – The device to do tracing on. If None (default), will be guessed.

  • dtype (torch.dtype) – The dtype to trace with. If None (default), will be guessed.

Returns

Tracing inputs in the format of torch.jit.trace_module: dicts mapping method names like 'forward' to tuples of arguments.

Return type

list of dict
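The dispatch described above (use the module's own _make_tracing_inputs hook if it has one, otherwise fall back to inference) and the returned format can be sketched in plain Python. This is an illustration only: toy_get_tracing_inputs, WithHook and WithoutHook are hypothetical, plain floats stand in for tensors, and the fallback here simply assumes forward takes one float, whereas the real function infers the signature from irreps via e3nn.util._argtools._get_io_irreps.

```python
import random
from typing import Any, Dict, List, Tuple


def toy_get_tracing_inputs(mod: Any, n: int = 1) -> List[Dict[str, Tuple]]:
    """Prefer the module's own hook; otherwise fall back to a generic guess."""
    hook = getattr(mod, "_make_tracing_inputs", None)
    if callable(hook):
        return hook(n)  # n is passed as the single argument
    # Hypothetical fallback: assume forward takes one float and build
    # n random inputs in trace_module format (method name -> arg tuple).
    return [{"forward": (random.random(),)} for _ in range(n)]


class WithHook:
    def _make_tracing_inputs(self, n: int) -> List[Dict[str, Tuple]]:
        return [{"forward": (float(i),)} for i in range(n)]


class WithoutHook:
    pass


print(toy_get_tracing_inputs(WithHook(), n=2))
# [{'forward': (0.0,)}, {'forward': (1.0,)}]
print(len(toy_get_tracing_inputs(WithoutHook(), n=3)))  # 3
```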

e3nn.util.jit.script(mod: Module, in_place: bool = True)

Script a module.

Like torch.jit.script, but first recursively compiles mod using compile.

Parameters

mod (torch.nn.Module) – The module to script.

Returns

The scripted module.
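The two-step order (compile decorated submodules, then script the top-level module) can be traced with a small plain-Python log. This is a hypothetical sketch: Node, toy_compile and toy_script are illustrative names, nothing is actually scripted or traced, and each step only appends a line to a log.

```python
from typing import List, Optional


class Node:
    """Hypothetical stand-in for an nn.Module: a name, a mode, and children."""

    def __init__(self, name: str, mode: Optional[str] = None,
                 children: Optional[List["Node"]] = None) -> None:
        self.name = name
        self.mode = mode
        self.children = children or []


def toy_compile(node: Node, log: List[str]) -> None:
    # Children are handled before their parent; undecorated nodes log nothing.
    for child in node.children:
        toy_compile(child, log)
    if node.mode is not None:
        log.append(f"{node.mode} {node.name}")


def toy_script(root: Node) -> List[str]:
    """Compile decorated submodules first, then script the top-level module."""
    log: List[str] = []
    for child in root.children:
        toy_compile(child, log)
    log.append(f"script {root.name}")
    return log


model = Node("model", children=[Node("a", "trace"), Node("b"), Node("c", "script")])
print(toy_script(model))  # ['trace a', 'script c', 'script model']
```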

e3nn.util.jit.trace(mod: Module, example_inputs: Optional[tuple] = None, check_inputs: Optional[list] = None, in_place: bool = True)

Trace a module.

Identical signature to torch.jit.trace, but first recursively compiles mod using compile.

Parameters

mod (torch.nn.Module) – The module to trace.

Returns

The traced module.

e3nn.util.jit.trace_module(mod: Module, inputs: Optional[dict] = None, check_inputs: Optional[list] = None, in_place: bool = True)

Trace a module.

Identical signature to torch.jit.trace_module, but first recursively compiles mod using compile.

Parameters

mod (torch.nn.Module) – The module to trace.

Returns

The traced module.
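trace and trace_module differ mainly in how inputs are passed: torch.jit.trace takes a tuple of positional arguments for forward (and check_inputs as a list of such tuples), while torch.jit.trace_module takes a dict mapping method names to argument tuples (and check_inputs as a list of such dicts), the same format get_tracing_inputs returns. The hypothetical helper below converts trace-style arguments into trace_module style, assuming only forward is traced. As far as we know, this mirrors how PyTorch delegates torch.jit.trace to torch.jit.trace_module for an nn.Module, but that internal detail may vary between versions. Plain floats stand in for tensors.

```python
from typing import Any, Dict, List, Optional, Tuple

Args = Tuple[Any, ...]


def to_trace_module_args(
    example_inputs: Args,
    check_inputs: Optional[List[Args]] = None,
) -> Tuple[Dict[str, Args], Optional[List[Dict[str, Args]]]]:
    """Rewrite trace-style arguments as trace_module-style ones for forward."""
    inputs = {"forward": example_inputs}
    checks = None if check_inputs is None else [{"forward": c} for c in check_inputs]
    return inputs, checks


inputs, checks = to_trace_module_args((1.0, 2.0), check_inputs=[(3.0, 4.0)])
print(inputs)  # {'forward': (1.0, 2.0)}
print(checks)  # [{'forward': (3.0, 4.0)}]
```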