test - helpers for unit testing
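Functions:
| assert_auto_jitable | Assert that submodule func is automatically JITable. |
| assert_equivariant | Assert that func is equivariant. |
| assert_normalized | Assert that func is normalized. |
| assert_torch_compile | Assert that func is torch.compile(fullgraph=True) |
| equivariance_error | Get the maximum equivariance error for func over ntrials. |
| format_equivariance_error | Format the dictionary returned by equivariance_error into a readable string. |
| random_irreps | Generate random irreps parameters for testing. |
| set_random_seeds | Set the random seeds to try to get some reproducibility. |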
- e3nn.util.test.assert_auto_jitable(func, error_on_warnings: bool = True, n_trace_checks: int = 2, strict_shapes: bool = True)
Assert that submodule func is automatically JITable.
- Parameters:
func (Callable) – The function to trace.
error_on_warnings (bool) – If True (default), TracerWarnings emitted by torch.jit.trace will be treated as errors.
n_trace_checks (int) – When arguments are being automatically generated, this many random arguments will be generated as test inputs for torch.jit.trace.
strict_shapes (bool) – Test that the traced function errors on inputs with feature dimensions that don’t match the input irreps.
- Returns:
The traced TorchScript function.
- e3nn.util.test.assert_equivariant(func, args_in=None, irreps_in=None, irreps_out=None, tolerance=None, **kwargs) -> dict
Assert that func is equivariant.
- Parameters:
args_in (list or None) – the original input arguments for the function. If None and the function has irreps_in consisting only of o3.Irreps and 'cartesian', random test inputs will be generated.
irreps_in (object) – see equivariance_error
irreps_out (object) – see equivariance_error
tolerance (float or None) – the threshold below which the equivariance error must fall. If None (the default), FLOAT_TOLERANCE[torch.get_default_dtype()] is used.
**kwargs (kwargs) – passed through to equivariance_error.
- Returns:
The same as equivariance_error
- Return type:
a dictionary mapping tuples (parity_k, did_translate) to errors
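The tolerance rule above can be illustrated with a small standalone sketch. It is not the e3nn implementation: toy_assert_below_tolerance is a hypothetical helper, and it assumes each dictionary value is a single float rather than an array of per-output errors.

```python
def toy_assert_below_tolerance(errors, tolerance):
    """Raise AssertionError if any error is above `tolerance`.

    `errors` maps (parity_k, did_translate) tuples to floats.
    Hypothetical helper for illustration only; not part of e3nn.
    """
    failures = {key: err for key, err in errors.items() if err > tolerance}
    if failures:
        raise AssertionError(
            f"equivariance error above tolerance {tolerance}: {failures}"
        )
    # Like the documented function, hand the error dictionary back on success.
    return errors


ok_errors = {(0, False): 1e-9, (1, False): 3e-9}
checked = toy_assert_below_tolerance(ok_errors, tolerance=1e-6)
```

Passing a dictionary with any entry above the tolerance raises an AssertionError that names the offending (parity_k, did_translate) keys.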
- e3nn.util.test.assert_normalized(func: Module, irreps_in=None, irreps_out=None, normalization: str = 'component', n_input: int = 10000, n_weight: int | None = None, weights: Iterable[Parameter] | None = None, atol: float = 0.1) -> None
Assert that func is normalized.
See https://docs.e3nn.org/en/stable/guide/normalization.html for more information on the normalization scheme.
atol, n_input, and n_weight may need to be significantly higher in order to converge the statistics to pass the test.
- Parameters:
func (torch.nn.Module) – the module to test
irreps_in (object) – see equivariance_error
irreps_out (object) – see equivariance_error
normalization (str, default "component") – one of “component” or “norm”. Note that this is defined for both the inputs and the outputs; if you need separate normalizations for input and output, please file a feature request.
n_input (int, default 10_000) – the number of input samples to use for each weight init
n_weight (int, default 20) – the number of weight initializations to sample
weights (optional iterable of parameters) – the weights to reinitialize n_weight times. If None (default), func.parameters() will be used.
atol (float, default 0.1) – tolerance for checking moments. Higher values for this prevent explosive computational costs for this test.
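To make the roles of n_input and atol concrete, here is a minimal pure-Python sketch of a statistical normalization check. It assumes that "component" normalization means every output component has a second moment close to 1 for standard-normal inputs; the helper names (toy_component_second_moments, normalized_sum, plain_sum) are hypothetical, and the real function also resamples weights, which this sketch omits.

```python
import math
import random


def toy_component_second_moments(func, dim_in, n_input=10_000, seed=0):
    """Estimate E[y_i^2] for each output component of `func`.

    Inputs are drawn with independent standard-normal components.
    Hypothetical helper for illustration only; not part of e3nn.
    """
    rng = random.Random(seed)
    sums = None
    for _ in range(n_input):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim_in)]
        y = func(x)
        if sums is None:
            sums = [0.0] * len(y)
        for i, c in enumerate(y):
            sums[i] += c * c
    return [s / n_input for s in sums]


def normalized_sum(x):
    # Dividing by sqrt(len(x)) keeps the output variance at 1.
    return [sum(x) / math.sqrt(len(x))]


def plain_sum(x):
    # Without the rescaling the output variance grows with len(x).
    return [sum(x)]


atol = 0.1
good = toy_component_second_moments(normalized_sum, dim_in=4)
bad = toy_component_second_moments(plain_sum, dim_in=4)
good_is_normalized = all(abs(m - 1.0) < atol for m in good)
bad_is_normalized = all(abs(m - 1.0) < atol for m in bad)
```

With fewer samples the estimate becomes noisier, which is why a small n_input may require a looser atol for the check to pass.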
- e3nn.util.test.assert_torch_compile(compile_mode: str, func: Callable, *args, **kwargs) -> None
Assert that func can be compiled with torch.compile(fullgraph=True)
- Parameters:
func (Callable that is a functools.partial(torch.nn.Module))
*args (func's forward positional arguments)
**kwargs (func's forward keyword arguments)
- e3nn.util.test.equivariance_error(func, args_in, irreps_in=None, irreps_out=None, ntrials: int = 1, do_parity: bool = True, do_translation: bool = True, transform_dtype=torch.float64)
Get the maximum equivariance error for func over ntrials.
Each trial randomizes the equivariant transformation tested.
- Parameters:
func (callable) – the function to test
args_in (list) – the original inputs to pass to func.
irreps_in (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the input irreps for each of the arguments in args_in. If left as the default of None, get_io_irreps will be used to try to infer them. If a sequence is provided, valid elements are also the string 'cartesian', which denotes that the corresponding input should be dealt with as cartesian points in 3D, and None, which indicates that the argument should not be transformed.
irreps_out (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the out irreps for each of the return values of func. Accepts similar values to irreps_in.
ntrials (int) – run this many trials with random transforms
do_parity (bool) – whether to test parity
do_translation (bool) – whether to test translation for 'cartesian' inputs
- Returns:
dictionary mapping tuples (parity_k, did_translate) to an array of errors, each entry the biggest over all trials for that output, in order.
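The idea behind this measurement can be sketched in plain Python for the simplest case: a function that maps one 3D vector to one 3D vector. The sketch compares "transform, then apply func" against "apply func, then transform" for random rotations, optionally combined with inversion. It is an illustration under stated assumptions, not the e3nn implementation: the helper names are hypothetical, translations are not tested (did_translate is always False), and each dictionary value is a single float instead of an array.

```python
import math
import random


def random_rotation(rng):
    """Rotation matrix built from a random unit quaternion."""
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    n = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / n for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def apply_transform(rot, sign, v):
    """Rotate `v` and multiply by `sign` (-1.0 adds an inversion)."""
    return [sign * sum(rot[i][j] * v[j] for j in range(3)) for i in range(3)]


def toy_equivariance_error(func, x, ntrials=1, do_parity=True, seed=0):
    """Largest |func(T x) - T func(x)| component over `ntrials` random T."""
    rng = random.Random(seed)
    errors = {}
    for parity_k in ([0, 1] if do_parity else [0]):
        sign = -1.0 if parity_k else 1.0
        worst = 0.0
        for _ in range(ntrials):
            rot = random_rotation(rng)
            func_then_t = apply_transform(rot, sign, func(x))
            t_then_func = func(apply_transform(rot, sign, x))
            err = max(abs(a - b) for a, b in zip(func_then_t, t_then_func))
            worst = max(worst, err)
        errors[(parity_k, False)] = worst
    return errors


def scale(v):
    # Scaling commutes with rotations and inversion: equivariant.
    return [2.0 * c for c in v]


def shift(v):
    # Adding a fixed vector picks a preferred direction: not equivariant.
    return [v[0] + 1.0, v[1], v[2]]


x = [0.3, -1.2, 0.7]
good = toy_equivariance_error(scale, x, ntrials=5)
bad = toy_equivariance_error(shift, x, ntrials=5)
```

Here good holds errors at floating-point noise level, while bad holds errors of order one, which is the kind of gap a tolerance check is meant to separate.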
- e3nn.util.test.format_equivariance_error(errors: dict) -> str
Format the dictionary returned by equivariance_error into a readable string.
- Parameters:
errors (dict) – A dictionary of errors returned by equivariance_error.
- Returns:
A string.
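As a rough picture of what such formatting involves, the sketch below turns an error dictionary keyed by (parity_k, did_translate) into one line per entry. The output layout is an assumption made for illustration and may not match what the library prints; toy_format_errors is a hypothetical name, and values are assumed to be plain floats.

```python
def toy_format_errors(errors):
    """Render {(parity_k, did_translate): error} as one line per entry.

    Hypothetical layout for illustration only; not the e3nn output format.
    """
    lines = []
    for (parity_k, did_translate), err in sorted(errors.items()):
        lines.append(
            f"(parity_k={parity_k}, did_translate={did_translate})"
            f" -> max error {err:.3e}"
        )
    return "\n".join(lines)


report = toy_format_errors({(1, True): 2.0e-3, (0, False): 1.5e-7})
```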
- e3nn.util.test.random_irreps(n: int = 1, lmax: int = 4, mul_min: int = 0, mul_max: int = 5, len_min: int = 0, len_max: int = 4, clean: bool = False, allow_empty: bool = True)
Generate random irreps parameters for testing.
- Parameters:
n (int, optional) – How many to generate; defaults to 1.
lmax (int, optional) – The maximum L to generate (inclusive); defaults to 4.
mul_min (int, optional) – The smallest multiplicity to generate, defaults to 0.
mul_max (int, optional) – The largest multiplicity to generate, defaults to 5.
len_min (int, optional) – The smallest number of irreps to generate, defaults to 0.
len_max (int, optional) – The largest number of irreps to generate, defaults to 4.
clean (bool, optional) – If True, only o3.Irreps objects will be returned. If False (the default), e3nn.o3.Irreps-like objects like strings and lists of tuples can be returned.
allow_empty (bool, optional) – Whether to allow generating empty e3nn.o3.Irreps.
- Returns:
An irreps-like object if n == 1 or a list of them if n > 1
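How the bounds above interact can be seen in a small stdlib-only sketch that draws a random irreps string such as "3x1o + 2x0e". It is a hypothetical stand-in, not the library's generator: toy_random_irreps and its seed argument are assumptions, it always returns a string, and it does not model the clean or allow_empty options.

```python
import random


def toy_random_irreps(lmax=4, mul_min=0, mul_max=5, len_min=0, len_max=4, seed=None):
    """Return a random irreps-like string, e.g. "3x1o + 2x0e".

    Hypothetical helper for illustration only; not part of e3nn.
    """
    rng = random.Random(seed)
    n_terms = rng.randint(len_min, len_max)  # both bounds inclusive
    terms = []
    for _ in range(n_terms):
        mul = rng.randint(mul_min, mul_max)  # multiplicity
        l = rng.randint(0, lmax)             # degree, lmax inclusive
        parity = rng.choice("eo")            # even or odd
        terms.append(f"{mul}x{l}{parity}")
    # Zero terms gives the empty string, i.e. empty irreps.
    return " + ".join(terms)


samples = [
    toy_random_irreps(lmax=2, mul_min=1, mul_max=3, len_min=1, len_max=3, seed=s)
    for s in range(20)
]
```

Setting len_min to 1, as in the usage lines, is one way to rule out empty results in this sketch.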