test - helpers for unit testing
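Functions:
- assert_auto_jitable – Assert that submodule func is automatically JITable.
- assert_equivariant – Assert that func is equivariant.
- assert_normalized – Assert that func is normalized.
- equivariance_error – Get the maximum equivariance error for func over ntrials trials.
- format_equivariance_error – Format the dictionary returned by equivariance_error into a readable string.
- random_irreps – Generate random irreps parameters for testing.
- Set the random seeds to try to get some reproducibility.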
- e3nn.util.test.assert_auto_jitable(func, error_on_warnings: bool = True, n_trace_checks: int = 2, strict_shapes: bool = True)
Assert that submodule func is automatically JITable.
- Parameters:
func (Callable) – The function to trace.
error_on_warnings (bool) – If True (default), TracerWarnings emitted by torch.jit.trace will be treated as errors.
n_trace_checks (int) – The number of randomly generated inputs passed to torch.jit.trace as test inputs.
strict_shapes (bool) – Test that the traced function errors on inputs with feature dimensions that don't match the input irreps.
- Returns:
The traced TorchScript function.
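The strict_shapes check comes down to a simple rule. A function built for given input irreps should reject any input whose feature dimension differs from the irreps dimension, which is the sum of mul × (2l + 1) over the irreps entries. The pure-Python sketch below shows only that rule and leaves out TorchScript. irreps_dim, checked_identity and fails_on_wrong_dim are hypothetical helpers, not e3nn functions. The sketch also assumes irreps are written as strings such as "2x0e + 1x1o".

```python
import re

# Matches one term of an irreps string such as "2x0e" or "1x1o" (assumed format).
_TERM = re.compile(r"^(\d+)x(\d+)([eo])$")


def irreps_dim(irreps: str) -> int:
    """Total feature dimension: sum of mul * (2l + 1) over all terms."""
    dim = 0
    for term in irreps.split("+"):
        term = term.strip()
        if not term:
            continue  # the empty string describes empty irreps
        match = _TERM.match(term)
        if match is None:
            raise ValueError(f"cannot parse irreps term {term!r}")
        mul, l = int(match.group(1)), int(match.group(2))
        dim += mul * (2 * l + 1)
    return dim


def checked_identity(irreps_in: str):
    """Hypothetical function that enforces its input feature dimension."""
    expected = irreps_dim(irreps_in)

    def func(features: list[float]) -> list[float]:
        if len(features) != expected:
            raise ValueError(f"expected {expected} features, got {len(features)}")
        return list(features)

    return func


def fails_on_wrong_dim(func, irreps_in: str) -> bool:
    """True if func rejects an input one feature longer than the irreps dimension."""
    bad_input = [0.0] * (irreps_dim(irreps_in) + 1)
    try:
        func(bad_input)
    except ValueError:
        return True
    return False


f = checked_identity("2x0e + 1x1o")
print(irreps_dim("2x0e + 1x1o"))             # 5
print(fails_on_wrong_dim(f, "2x0e + 1x1o"))  # True
```

A function that silently accepts any length, such as lambda x: x, would make fails_on_wrong_dim return False. That is the situation strict_shapes is meant to catch.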
- e3nn.util.test.assert_equivariant(func, args_in=None, irreps_in=None, irreps_out=None, tolerance=None, **kwargs) -> dict
Assert that func is equivariant.
- Parameters:
args_in (list or None) – the original input arguments for the function. If None and the function has irreps_in consisting only of o3.Irreps and 'cartesian', random test inputs will be generated.
irreps_in (object) – see equivariance_error
irreps_out (object) – see equivariance_error
tolerance (float or None) – the threshold below which the equivariance error must fall. If None (the default), FLOAT_TOLERANCE[torch.get_default_dtype()] is used.
**kwargs (kwargs) – passed through to equivariance_error.
- Returns:
The same as equivariance_error.
- Return type:
a dictionary mapping tuples (parity_k, did_translate) to errors
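When args_in is None, the test inputs are generated from irreps_in. The stdlib sketch below shows the idea with a hypothetical helper, random_args. It makes several assumptions. Irreps are given as simplified lists of (mul, l) pairs instead of e3nn.o3.Irreps. Values are drawn from a standard normal distribution. A 'cartesian' entry produces 3D points. Following the text above, an entry of None rules out automatic generation.

```python
import random

Irreps = list[tuple[int, int]]  # assumed simplified encoding: [(mul, l), ...]


def dim(irreps: Irreps) -> int:
    # Each (mul, l) entry contributes mul copies of a (2l + 1)-dimensional irrep.
    return sum(mul * (2 * l + 1) for mul, l in irreps)


def random_args(irreps_in, batch: int = 4, seed: int = 0) -> list[list[list[float]]]:
    """Generate one batch of standard-normal inputs per entry of irreps_in."""
    rng = random.Random(seed)
    args = []
    for spec in irreps_in:
        if spec is None:
            raise ValueError("cannot generate inputs for an argument with irreps None")
        width = 3 if spec == "cartesian" else dim(spec)  # 'cartesian' means a 3D point
        args.append([[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)])
    return args


args = random_args([[(2, 0), (1, 1)], "cartesian"], batch=4)
print([(len(a), len(a[0])) for a in args])  # [(4, 5), (4, 3)]
```

The generated list can then go to an equivariance check, and the resulting errors are compared against tolerance.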
- e3nn.util.test.assert_normalized(func: Module, irreps_in=None, irreps_out=None, normalization: str = 'component', n_input: int = 10000, n_weight: int | None = None, weights: Iterable[Parameter] | None = None, atol: float = 0.1) -> None
Assert that func is normalized.
See https://docs.e3nn.org/en/stable/guide/normalization.html for more information on the normalization scheme.
atol, n_input, and n_weight may need to be significantly higher for the statistics to converge and the test to pass.
- Parameters:
func (torch.nn.Module) – the module to test
irreps_in (object) – see equivariance_error
irreps_out (object) – see equivariance_error
normalization (str, default "component") – one of “component” or “norm”. Note that this applies to both the inputs and the outputs; if you need separate normalizations for input and output, please file a feature request.
n_input (int, default 10_000) – the number of input samples to use for each weight initialization
n_weight (int, default 20) – the number of weight initializations to sample
weights (optional iterable of parameters) – the weights to reinitialize n_weight times. If None (default), func.parameters() will be used.
atol (float, default 0.1) – tolerance for checking moments. Higher values let the test pass with fewer samples, which keeps its computational cost from exploding.
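The moment test for "component" normalization works as follows. Draw unit-variance inputs, then check that the second moment of every output component lies within atol of 1. The sketch below follows that recipe under simplifying assumptions. It uses a fixed function with no weights to resample, and inputs come from a standard normal. The helper names are hypothetical. It is not e3nn's implementation, which also reinitializes the weights n_weight times and supports "norm" normalization.

```python
import math
import random


def component_second_moments(func, in_dim: int, n_input: int = 10_000, seed: int = 0) -> list[float]:
    """Estimate E[y_i^2] for each output component y_i = func(x)_i, with x_j ~ N(0, 1)."""
    rng = random.Random(seed)
    sums = None
    for _ in range(n_input):
        x = [rng.gauss(0.0, 1.0) for _ in range(in_dim)]
        y = func(x)
        if sums is None:
            sums = [0.0] * len(y)
        for i, yi in enumerate(y):
            sums[i] += yi * yi
    return [s / n_input for s in sums]


def is_component_normalized(func, in_dim: int, n_input: int = 10_000, atol: float = 0.1) -> bool:
    # Under "component" normalization every output component should have second moment ~1.
    return all(abs(m - 1.0) <= atol for m in component_second_moments(func, in_dim, n_input))


def normalized(x):
    # Dividing a sum of two unit-variance inputs by sqrt(2) restores unit variance.
    return [(x[0] + x[1]) / math.sqrt(2.0), x[2]]


def unnormalized(x):
    return [x[0] + x[1], x[2]]  # first component has second moment ~2


print(is_component_normalized(normalized, 3))    # True
print(is_component_normalized(unnormalized, 3))  # False
```

Each estimated moment has a standard error of roughly sqrt(2 / n_input). This explains why a tight atol requires a large n_input.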
- e3nn.util.test.equivariance_error(func, args_in, irreps_in=None, irreps_out=None, ntrials: int = 1, do_parity: bool = True, do_translation: bool = True, transform_dtype=torch.float64)
Get the maximum equivariance error for func over ntrials trials.
Each trial randomizes the equivariant transformation tested.
- Parameters:
func (callable) – the function to test
args_in (list) – the original inputs to pass to func.
irreps_in (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the input irreps for each of the arguments in args_in. If left as the default of None, get_io_irreps will be used to try to infer them. If a sequence is provided, its elements may also be the string 'cartesian', which denotes that the corresponding input should be treated as cartesian points in 3D, or None, which indicates that the argument should not be transformed.
irreps_out (list of e3nn.o3.Irreps or e3nn.o3.Irreps) – the output irreps for each of the return values of func. Accepts the same kinds of values as irreps_in.
ntrials (int) – run this many trials with random transforms
do_parity (bool) – whether to test parity
do_translation (bool) – whether to test translation for 'cartesian' inputs
- Returns:
dictionary mapping tuples (parity_k, did_translate) to an array of errors, each entry being the largest error over all trials for that output, in order.
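The quantity measured here is the largest gap between func(D x) and D func(x) across random transformations D. The pure-Python sketch below computes it for a deliberately narrow case, so read it as an illustration rather than e3nn's implementation. func takes one argument, a list of l = 1 odd-parity vectors ("1o"), and returns vectors of the same type. D is a uniformly random rotation R, multiplied by -1 when parity_k = 1. Translation of 'cartesian' inputs is not covered. All helper names are hypothetical.

```python
import math
import random

Vec = list[float]


def random_rotation(rng: random.Random) -> list[Vec]:
    """Uniformly random 3x3 rotation matrix built from a normalized Gaussian quaternion."""
    w, x, y, z = (rng.gauss(0.0, 1.0) for _ in range(4))
    n = math.sqrt(w * w + x * x + y * y + z * z)
    w, x, y, z = w / n, x / n, y / n, z / n
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]


def transform(rot: list[Vec], sign: float, v: Vec) -> Vec:
    # Action of sign * R on one vector; sign = -1 adds inversion (parity) for "1o" vectors.
    return [sign * sum(rot[i][j] * v[j] for j in range(3)) for i in range(3)]


def equivariance_error(func, points: list[Vec], ntrials: int = 1, do_parity: bool = True, seed: int = 0):
    """Max |func(D x) - D func(x)| over trials, keyed by (parity_k, did_translate)."""
    rng = random.Random(seed)
    errors = {}
    for parity_k in ((0, 1) if do_parity else (0,)):
        sign = -1.0 if parity_k else 1.0
        worst = 0.0
        for _ in range(ntrials):
            rot = random_rotation(rng)
            lhs = func([transform(rot, sign, p) for p in points])  # f(D x)
            rhs = [transform(rot, sign, o) for o in func(points)]  # D f(x)
            diff = max(abs(a - b) for u, v in zip(lhs, rhs) for a, b in zip(u, v))
            worst = max(worst, diff)
        errors[(parity_k, False)] = [worst]  # translation is not tested in this sketch
    return errors


def scale_by_norm(points):
    # |x|^2 x is rotation-equivariant and odd under inversion, like a "1o" vector.
    out = []
    for p in points:
        n2 = sum(c * c for c in p)
        out.append([n2 * c for c in p])
    return out


def drop_z(points):
    # Zeroing the z component singles out an axis, so this is not equivariant.
    return [[p[0], p[1], 0.0] for p in points]


pts = [[1.0, 2.0, 3.0], [-0.5, 0.25, 1.5]]
print(equivariance_error(scale_by_norm, pts, ntrials=3))  # round-off-level errors
print(equivariance_error(drop_z, pts, ntrials=3))         # errors of order 1
```

Keeping the maximum over trials, rather than the mean, means that a single bad random transform is enough to expose a non-equivariant function.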
- e3nn.util.test.format_equivariance_error(errors: dict) -> str
Format the dictionary returned by equivariance_error into a readable string.
- Parameters:
errors (dict) – A dictionary of errors returned by equivariance_error.
- Returns:
A string.
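The exact layout that format_equivariance_error produces is not given here. As an illustration only, a hypothetical formatter for a dictionary of the shape described above, mapping (parity_k, did_translate) to a list of per-output errors, could print one line per key.

```python
def format_errors(errors: dict[tuple[int, bool], list[float]]) -> str:
    """Render each (parity_k, did_translate) key with its per-output errors, one line per key."""
    lines = []
    for (parity_k, did_translate), errs in sorted(errors.items()):
        values = ", ".join(f"{e:.3e}" for e in errs)
        lines.append(f"(parity_k={parity_k}, did_translate={did_translate}) -> [{values}]")
    return "\n".join(lines)


print(format_errors({(1, True): [2e-3, 0.0], (0, False): [1.5e-7]}))
# (parity_k=0, did_translate=False) -> [1.500e-07]
# (parity_k=1, did_translate=True) -> [2.000e-03, 0.000e+00]
```

Sorting the keys keeps the output stable across runs, so a failing test always prints the transformations in the same order.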
- e3nn.util.test.random_irreps(n: int = 1, lmax: int = 4, mul_min: int = 0, mul_max: int = 5, len_min: int = 0, len_max: int = 4, clean: bool = False, allow_empty: bool = True)
Generate random irreps parameters for testing.
- Parameters:
n (int, optional) – How many to generate; defaults to 1.
lmax (int, optional) – The maximum L to generate (inclusive); defaults to 4.
mul_min (int, optional) – The smallest multiplicity to generate; defaults to 0.
mul_max (int, optional) – The largest multiplicity to generate; defaults to 5.
len_min (int, optional) – The smallest number of irreps to generate; defaults to 0.
len_max (int, optional) – The largest number of irreps to generate; defaults to 4.
clean (bool, optional) – If True, only o3.Irreps objects will be returned. If False (the default), e3nn.o3.Irreps-like objects such as strings and lists of tuples can be returned.
allow_empty (bool, optional) – Whether to allow generating empty e3nn.o3.Irreps.
- Returns:
An irreps-like object if n == 1, or a list of them if n > 1.
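The stdlib sketch below shows the kind of output these parameters describe. It draws irreps in string form, which is one of the Irreps-like forms mentioned above. random_irreps_strings is a hypothetical stand-in, not e3nn's random_irreps. It only returns strings, roughly the clean=False case, and its seed parameter is an addition for reproducibility. Its handling of allow_empty is also an assumption: any draw whose total dimension is zero gets resampled.

```python
import random


def random_irreps_strings(
    n: int = 1,
    lmax: int = 4,
    mul_min: int = 0,
    mul_max: int = 5,
    len_min: int = 0,
    len_max: int = 4,
    allow_empty: bool = True,
    seed: int = 0,
):
    """Return one irreps string (n == 1) or a list of n strings such as "3x1o + 0x2e"."""
    if not allow_empty and (len_max < 1 or mul_max < 1):
        raise ValueError("cannot avoid empty irreps with these bounds")
    rng = random.Random(seed)
    out = []
    while len(out) < n:
        # Each term is (multiplicity, l, parity) with parity "e" (even) or "o" (odd).
        terms = [
            (rng.randint(mul_min, mul_max), rng.randint(0, lmax), rng.choice("eo"))
            for _ in range(rng.randint(len_min, len_max))
        ]
        # Assumption: "empty" means total dimension zero; resample such draws if disallowed.
        if not allow_empty and sum(mul for mul, _, _ in terms) == 0:
            continue
        out.append(" + ".join(f"{mul}x{l}{p}" for mul, l, p in terms))
    return out[0] if n == 1 else out


print(random_irreps_strings(n=3, lmax=2, seed=0))
```

Allowing multiplicity 0 and zero-length irreps by default exercises edge cases that hand-written test irreps often miss.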