test - helpers for unit testing

Functions:

assert_auto_jitable(func[, ...])

Assert that submodule func is automatically JITable.

assert_equivariant(func[, args_in, ...])

Assert that func is equivariant.

assert_normalized(func[, irreps_in, ...])

Assert that func is normalized.

equivariance_error(func, args_in[, ...])

Get the maximum equivariance error for func over ntrials random trials.

format_equivariance_error(errors)

Format the dictionary returned by equivariance_error into a readable string.

random_irreps([n, lmax, mul_min, mul_max, ...])

Generate random irreps parameters for testing.

set_random_seeds()

Set the random seeds to try to get some reproducibility.

e3nn.util.test.assert_auto_jitable(func, error_on_warnings=True, n_trace_checks=2, strict_shapes=True)

Assert that submodule func is automatically JITable.

Parameters
  • func (Callable) – The function to trace.

  • error_on_warnings (bool) – If True (default), TracerWarnings emitted by torch.jit.trace will be treated as errors.

  • n_trace_checks (int) – the number of randomly generated argument sets passed to torch.jit.trace as test inputs (default 2).

  • strict_shapes (bool) – Test that the traced function errors on inputs with feature dimensions that don’t match the input irreps.

Returns

The traced TorchScript function.

e3nn.util.test.assert_equivariant(func, args_in=None, irreps_in=None, irreps_out=None, tolerance=None, **kwargs) -> dict

Assert that func is equivariant.

Parameters
  • func (callable) – the function to test.

  • args_in (list or None) – the original input arguments for the function. If None and the function has irreps_in consisting only of o3.Irreps and 'cartesian', random test inputs will be generated.

  • irreps_in (object) – see equivariance_error

  • irreps_out (object) – see equivariance_error

  • tolerance (float or None) – the threshold below which the equivariance error must fall. If None (the default), FLOAT_TOLERANCE[torch.get_default_dtype()] is used.

  • **kwargs (kwargs) – passed through to equivariance_error.

Returns

The same as equivariance_error.

Return type

a dictionary mapping tuples (parity_k, did_translate) to errors
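Conceptually, the assertion reduces to comparing every entry of that dictionary against tolerance. The stdlib-only sketch below illustrates that step; it is not e3nn's implementation. check_equivariance_errors is a hypothetical name, and plain lists of floats stand in for the per-output error arrays.

```python
def check_equivariance_errors(errors, tolerance):
    """Raise AssertionError if any per-output error exceeds tolerance.

    Hypothetical helper: `errors` maps (parity_k, did_translate) to a list
    of per-output errors, mirroring the structure described above.
    """
    problems = {
        case: errs
        for case, errs in errors.items()
        if max(errs, default=0.0) > tolerance
    }
    if problems:
        report = "\n".join(
            f"{case}: largest error {max(errs):.3e}" for case, errs in problems.items()
        )
        raise AssertionError(f"equivariance error above tolerance {tolerance:g}:\n{report}")
    # On success, hand the dictionary back unchanged.
    return errors


errors = {(0, False): [1e-7, 3e-8], (1, False): [2e-7, 5e-8]}
check_equivariance_errors(errors, tolerance=1e-5)  # passes and returns errors
```

Returning the dictionary on success mirrors the documented return value, so a caller can still inspect the errors after the assertion passes.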

e3nn.util.test.assert_normalized(func: Module, irreps_in=None, irreps_out=None, normalization: str = 'component', n_input: int = 10000, n_weight: Optional[int] = None, weights: Optional[Iterable[Parameter]] = None, atol: float = 0.1) -> None

Assert that func is normalized.

See https://docs.e3nn.org/en/stable/guide/normalization.html for more information on the normalization scheme.

atol, n_input, and n_weight may need to be significantly higher in order to converge the statistics to pass the test.

Parameters
  • func (torch.nn.Module) – the module to test

  • irreps_in (object) – see equivariance_error

  • irreps_out (object) – see equivariance_error

  • normalization (str, default "component") – one of “component” or “norm”. Note that this is defined for both the inputs and the outputs; if you need separate normalizations for input and output, please file a feature request.

  • n_input (int, default 10_000) – the number of input samples to use for each weight initialization

  • n_weight (int, default 20) – the number of weight initializations to sample

  • weights (optional iterable of parameters) – the weights to reinitialize n_weight times. If None (default), func.parameters() will be used.

  • atol (float, default 0.1) – the absolute tolerance for checking the moments. A larger tolerance keeps this test affordable, because a tighter one needs more samples (larger n_input and n_weight) before the statistics converge.
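As a rough illustration of the “component” scheme: inputs are drawn so that every component has unit second moment, and a normalized module should produce outputs whose components also have mean square close to 1. The sketch below assumes the test compares that mean square, averaged over inputs and weight re-initializations, to 1 within atol, and that weights are re-initialized from a standard normal. toy_linear and second_moments are hypothetical helpers, and the small sample counts are chosen only so it runs quickly with the standard library.

```python
import math
import random


def toy_linear(weights, x):
    """y_j = sum_i W[j][i] * x_i / sqrt(fan_in); the 1/sqrt(fan_in) factor
    is what keeps E[y_j**2] = 1 when E[x_i**2] = 1 and E[W**2] = 1."""
    scale = 1.0 / math.sqrt(len(x))
    return [scale * sum(w * xi for w, xi in zip(row, x)) for row in weights]


def second_moments(dim_in, dim_out, n_input, n_weight, seed=0):
    """Average y_j**2 over n_weight weight draws and n_input inputs per draw."""
    rng = random.Random(seed)
    totals = [0.0] * dim_out
    for _ in range(n_weight):
        # Assumption: each re-initialization samples weights from N(0, 1).
        weights = [[rng.gauss(0.0, 1.0) for _ in range(dim_in)] for _ in range(dim_out)]
        for _ in range(n_input):
            # "component" normalization: every input component has E[x_i**2] = 1.
            x = [rng.gauss(0.0, 1.0) for _ in range(dim_in)]
            for j, y in enumerate(toy_linear(weights, x)):
                totals[j] += y * y
    return [t / (n_weight * n_input) for t in totals]


moments = second_moments(dim_in=64, dim_out=3, n_input=100, n_weight=100)
print(moments)  # each entry should lie within atol = 0.1 of 1.0
```

For this toy layer, the spread of the averaged moment across weight draws shrinks roughly like sqrt(2 / (fan_in * n_weight)), which illustrates why n_weight and n_input may need to grow before a tight atol is reliably met.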

e3nn.util.test.equivariance_error(func, args_in, irreps_in=None, irreps_out=None, ntrials=1, do_parity=True, do_translation=True, transform_dtype=torch.float64)

Get the maximum equivariance error for func over ntrials random trials.

Each trial randomizes the equivariant transformation tested.

Parameters
  • func (callable) – the function to test

  • args_in (list) – the original inputs to pass to func.

  • irreps_in (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the input irreps for each of the arguments in args_in. If left as the default of None, get_io_irreps will be used to try to infer them. If a sequence is provided, valid elements are also the string 'cartesian', which denotes that the corresponding input should be dealt with as cartesian points in 3D, and None, which indicates that the argument should not be transformed.

  • irreps_out (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the output irreps for each of the return values of func. Accepts the same kinds of values as irreps_in.

  • ntrials (int) – run this many trials with random transforms

  • do_parity (bool) – whether to test parity

  • do_translation (bool) – whether to test translation for 'cartesian' inputs

Returns

A dictionary mapping tuples (parity_k, did_translate) to an array of errors, one entry per output of func, in order; each entry is the largest error over all trials.
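The shape of that dictionary can be illustrated without e3nn. The stdlib-only sketch below is hypothetical (equivariance_errors, random_rotation, transform, centroid and shifted_centroid are illustration names, not e3nn API) and handles only a function from a list of 3D points to a list of 3D points, treating both as 'cartesian'. For each (parity_k, did_translate) case it applies p -> (-1)^parity_k R p + t with a random rotation R (plus a random translation t when did_translate is True), compares transform-then-apply with apply-then-transform, and keeps the largest error over ntrials.

```python
import math
import random


def random_rotation(rng):
    """Uniformly random 3x3 rotation matrix, built from a random unit quaternion."""
    u1, u2, u3 = rng.random(), rng.random(), rng.random()
    a, b = math.sqrt(1.0 - u1), math.sqrt(u1)
    w, x = a * math.sin(2 * math.pi * u2), a * math.cos(2 * math.pi * u2)
    y, z = b * math.sin(2 * math.pi * u3), b * math.cos(2 * math.pi * u3)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def transform(points, rot, sign, shift):
    """Apply p -> sign * R p + shift to every 3D point."""
    return [
        [sign * sum(rot[i][j] * p[j] for j in range(3)) + shift[i] for i in range(3)]
        for p in points
    ]


def equivariance_errors(func, points, ntrials=3, do_parity=True, do_translation=True, seed=0):
    """Hypothetical analogue: {(parity_k, did_translate): [max error per output]}."""
    rng = random.Random(seed)
    errors = {}
    for parity_k in ((0, 1) if do_parity else (0,)):
        for did_translate in ((False, True) if do_translation else (False,)):
            worst = 0.0
            for _ in range(ntrials):
                rot = random_rotation(rng)
                sign = (-1) ** parity_k
                shift = [rng.gauss(0.0, 1.0) for _ in range(3)] if did_translate else [0.0] * 3
                lhs = func(transform(points, rot, sign, shift))  # transform, then apply
                rhs = transform(func(points), rot, sign, shift)  # apply, then transform
                err = max(abs(a - b) for p, q in zip(lhs, rhs) for a, b in zip(p, q))
                worst = max(worst, err)
            # A single output here, so the per-output error list has length 1.
            errors[(parity_k, did_translate)] = [worst]
    return errors


def centroid(points):
    """Mean point: commutes with rotation, inversion and translation."""
    n = len(points)
    return [[sum(p[i] for p in points) / n for i in range(3)]]


def shifted_centroid(points):
    """Adds a fixed offset along x, which breaks rotation equivariance."""
    (c,) = centroid(points)
    return [[c[0] + 1.0, c[1], c[2]]]


rng = random.Random(1)
pts = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(5)]
print(equivariance_errors(centroid, pts))          # every entry at rounding level
print(equivariance_errors(shifted_centroid, pts))  # clearly nonzero
```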

e3nn.util.test.format_equivariance_error(errors: dict) -> str

Format the dictionary returned by equivariance_error into a readable string.

Parameters

errors (dict) – A dictionary of errors returned by equivariance_error.

Return type

A string.
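A hypothetical formatter for a dictionary of that shape, one line per (parity_k, did_translate) case. It only illustrates the data layout; e3nn's actual output format may differ, and format_errors is not part of its API.

```python
def format_errors(errors):
    """Render {(parity_k, did_translate): [error per output]} one case per line."""
    lines = []
    for (parity_k, did_translate), errs in sorted(errors.items()):
        per_output = ", ".join(f"{e:.2e}" for e in errs)
        lines.append(f"(parity_k={parity_k}, did_translate={did_translate}) -> {per_output}")
    return "\n".join(lines)


print(format_errors({(1, True): [3e-7], (0, False): [1.5e-7, 2e-8]}))
# (parity_k=0, did_translate=False) -> 1.50e-07, 2.00e-08
# (parity_k=1, did_translate=True) -> 3.00e-07
```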

e3nn.util.test.random_irreps(n: int = 1, lmax: int = 4, mul_min: int = 0, mul_max: int = 5, len_min: int = 0, len_max: int = 4, clean: bool = False, allow_empty: bool = True)

Generate random irreps parameters for testing.

Parameters
  • n (int, optional) – How many to generate; defaults to 1.

  • lmax (int, optional) – The maximum L to generate (inclusive); defaults to 4.

  • mul_min (int, optional) – The smallest multiplicity to generate, defaults to 0.

  • mul_max (int, optional) – The largest multiplicity to generate, defaults to 5.

  • len_min (int, optional) – The smallest number of irreps to generate, defaults to 0.

  • len_max (int, optional) – The largest number of irreps to generate, defaults to 4.

  • clean (bool, optional) – If True, only o3.Irreps objects will be returned. If False (the default), e3nn.o3.Irreps-like objects such as strings and lists of tuples may also be returned.

  • allow_empty (bool, optional) – Whether to allow generating empty e3nn.o3.Irreps; defaults to True.

Return type

An irreps-like object if n == 1, or a list of them if n > 1.
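The parameters above can be read as bounds on a simple sampling procedure. The sketch below is a stdlib-only stand-in, not e3nn's random_irreps: random_irreps_sketch is a hypothetical name, it always returns strings in the "3x0e + 2x1o" notation (roughly the clean=False case), and it assumes "empty" means a total multiplicity of zero.

```python
import random


def random_irreps_sketch(n=1, lmax=4, mul_min=0, mul_max=5, len_min=0, len_max=4,
                         allow_empty=True, rng=None):
    """Hypothetical stand-in for random_irreps that returns irreps strings.

    Each term is '<mul>x<L><p>' with mul in [mul_min, mul_max], L in [0, lmax]
    and parity p either 'e' (even) or 'o' (odd).
    """
    rng = rng or random.Random()
    if not allow_empty and (mul_max < 1 or len_max < 1):
        raise ValueError("these bounds can only produce empty irreps")

    def one():
        while True:
            terms = [
                (rng.randint(mul_min, mul_max), rng.randint(0, lmax), rng.choice("eo"))
                for _ in range(rng.randint(len_min, len_max))
            ]
            # Assumption: "empty" means total multiplicity zero (dimension 0).
            if allow_empty or sum(mul for mul, _, _ in terms) > 0:
                return " + ".join(f"{mul}x{l}{p}" for mul, l, p in terms)

    out = [one() for _ in range(n)]
    return out[0] if n == 1 else out


print(random_irreps_sketch(n=3, rng=random.Random(0)))
```

Mirroring the documented return type, a single call with n == 1 yields one irreps-like value, and larger n yields a list.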

e3nn.util.test.set_random_seeds()

Set the random seeds to try to get some reproducibility.
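A minimal sketch of the reseeding idea, assuming only Python's random module needs seeding. A helper for a PyTorch library would presumably also seed torch; that part is omitted so the example runs with the standard library alone, and set_random_seeds_sketch is a hypothetical name.

```python
import random


def set_random_seeds_sketch(seed=0):
    """Hypothetical stand-in: reseed Python's global random number generator."""
    random.seed(seed)


set_random_seeds_sketch()
first = [random.random() for _ in range(3)]
set_random_seeds_sketch()
second = [random.random() for _ in range(3)]
print(first == second)  # True: the same seed reproduces the same draws
```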