From b4fcb38e63eb58f3376e7694be5450c4c4222d1a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 3 Oct 2022 11:42:06 +0100 Subject: [PATCH 1/3] Make elwise assertion util kwargs kw-only --- array_api_tests/test_operators_and_elementwise_functions.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index d4349372..c7cf33b4 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -107,7 +107,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: # floats are used internally for optimisation or legacy reasons. -def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: +def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: """Wraps math.isclose with very generous defaults. This is useful for many floating-point operations where the spec does not @@ -137,6 +137,7 @@ def unary_assert_against_refimpl( in_: Array, res: Array, refimpl: Callable[[T], T], + *, expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, @@ -184,6 +185,7 @@ def binary_assert_against_refimpl( right: Array, res: Array, refimpl: Callable[[T, T], T], + *, expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", @@ -234,6 +236,7 @@ def right_scalar_assert_against_refimpl( right: Scalar, res: Array, refimpl: Callable[[T, T], T], + *, expr_template: str = None, res_stype: Optional[ScalarType] = None, left_sym: str = "x1", @@ -486,6 +489,7 @@ def binary_param_assert_against_refimpl( res: Array, op_sym: str, refimpl: Callable[[T, T], T], + *, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, strict_check: Optional[bool] = None, From 5b26f9da3f517c472ee49fc1913c7aaa32852307 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 3 Oct 2022 13:09:50 +0100 Subject: [PATCH 2/3] Move format kwargs to the end of elwise assert utils --- .../test_operators_and_elementwise_functions.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index c7cf33b4..53993741 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -138,10 +138,10 @@ def unary_assert_against_refimpl( res: Array, refimpl: Callable[[T], T], *, - expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, filter_: Callable[[Scalar], bool] = default_filter, strict_check: Optional[bool] = None, + expr_template: Optional[str] = None, ): if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") @@ -186,13 +186,13 @@ def binary_assert_against_refimpl( res: Array, refimpl: Callable[[T, T], T], *, - expr_template: Optional[str] = None, res_stype: Optional[ScalarType] = None, + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", - filter_: Callable[[Scalar], bool] = default_filter, - strict_check: Optional[bool] = None, + expr_template: Optional[str] = None, ): if expr_template is None: expr_template = func_name + "({}, {})={}" @@ -237,12 +237,12 @@ def right_scalar_assert_against_refimpl( res: Array, refimpl: Callable[[T, T], T], *, - expr_template: str = None, res_stype: Optional[ScalarType] = None, - left_sym: str = "x1", - res_name: str = "out", filter_: Callable[[Scalar], bool] = default_filter, strict_check: Optional[bool] = None, + left_sym: str = "x1", + res_name: str = "out", + expr_template: str = None, ): if filter_(right): return # short-circuit here as there will be nothing to test From a943a5172c0eaa8fd83a46583e238e4564cbbd9d Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 4 Oct 2022 13:22:29 +0100 Subject: [PATCH 3/3] Document major elwise testing helpers --- ...est_operators_and_elementwise_functions.py | 157 +++++++++++++++--- 1 file changed, 132 insertions(+), 25 deletions(-) diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 53993741..967a43a6 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,3 +1,6 @@ +""" +Test element-wise functions/operators against reference implementations. +""" import math import operator from enum import Enum, auto @@ -82,31 +85,6 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n -# This module tests elementwise functions/operators against a reference -# implementation. We iterate through the input array(s) and resulting array, -# casting the indexed arrays to Python scalars and calculating the expected -# output with `refimpl` function. -# -# This is finicky to refactor, but possible and ultimately worthwhile - hence -# why these *_assert_again_refimpl() utilities exist. -# -# Values which are special-cased are generated and passed, but are filtered by -# the `filter_` callable before they can be asserted against `refimpl`. We -# automatically generate tests for special cases in the special_cases/ dir. We -# still pass them here so as to ensure their presence doesn't affect the outputs -# respective to non-special-cased elements. -# -# By default, results are casted to scalars the same way that the inputs are. -# You can specify a cast via `res_stype, i.e. when a function accepts numerical -# inputs but returns boolean arrays. -# -# By default, floating-point functions/methods are loosely asserted against. Use -# `strict_check=True` when they should be strictly asserted against, i.e. -# when a function should return intergrals. Likewise, use `strict_check=False` -# when integer function/methods should be loosely asserted against, i.e. when -# floats are used internally for optimisation or legacy reasons. - - def isclose(a: float, b: float, *, rel_tol: float = 0.25, abs_tol: float = 1) -> bool: """Wraps math.isclose with very generous defaults. @@ -143,6 +121,125 @@ def unary_assert_against_refimpl( strict_check: Optional[bool] = None, expr_template: Optional[str] = None, ): + """ + Assert unary element-wise results are as expected. + + We iterate through every element in the input and resulting arrays, casting + the respective elements (0-D arrays) to Python scalars, and assert against + the expected output specified by the passed reference implementation, e.g. + + >>> x = xp.asarray([[0, 1], [2, 4]]) + >>> out = xp.square(x) + >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2) + + is equivalent to + + >>> for idx in np.ndindex(x.shape): + ... expected = int(x[idx]) ** 2 + ... assert int(out[idx]) == expected + + Casting + ------- + + The input scalar type is inferred from the input array's dtype like so + + Array dtypes | Python builtin type + ----------------- | --------------------- + xp.bool | bool + xp.int*, xp.uint* | int + xp.float* | float + xp.complex* | complex + + If res_stype=None (the default), the result scalar type is the same as the + input scalar type. We can also specify the result scalar type ourselves, e.g. + + >>> x = xp.asarray([42., xp.inf]) + >>> out = xp.isinf(x) # should be [False, True] + >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool) + + Filtering special-cased values + ------------------------------ + + Values which are special-cased can be present in the input array, but get + filtered before they can be asserted against refimpl. + + If filter_=default_filter (the default), all non-finite and floating zero + values are filtered, e.g. + + >>> unary_assert_against_refimpl('sin', x, out, math.sin) + + is equivalent to + + >>> for idx in np.ndindex(x.shape): + ... at_x = float(x[idx]) + ... if math.isfinite(at_x) or at_x != 0: + ... expected = math.sin(at_x) + ... assert math.isclose(float(out[idx]), expected) + + We can also specify the filter function ourselves, e.g. + + >>> def sqrt_filter(s: float) -> bool: + ... return math.isfinite(s) and s >= 0 + >>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter) + + is equivalent to + + >>> for idx in np.ndindex(x.shape): + ... at_x = float(x[idx]) + ... if math.isfinite(s) and s >=0: + ... expected = math.sin(at_x) + ... assert math.isclose(float(out[idx]), expected) + + Note we leave special-cased values in the input arrays, so as to ensure + their presence doesn't affect the outputs respective to non-special-cased + elements. We specifically test special case bevaiour in test_special_cases.py. + + Assertion strictness + -------------------- + + If strict_check=None (the default), integer elements are strictly asserted + against, and floating elements are loosely asserted against, e.g. + + >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2) + + is equivalent to + + >>> for idx in np.ndindex(x.shape): + ... expected = in_stype(x[idx]) ** 2 + ... if in_stype == int: + ... assert int(out[idx]) == expected + ... else: # in_stype == float + ... assert math.isclose(float(out[idx]), expected) + + Specifying strict_check as True or False will assert strictly/loosely + respectively, regardless of dtype. This is useful for testing functions that + have definitive outputs for floating inputs, i.e. rounding functions. + + Expressions in errors + --------------------- + + Assertion error messages include an expression, by default using func_name + like so + + >>> x = xp.asarray([42., xp.inf]) + >>> out = xp.isinf(x) + >>> out + [False, False] + >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool) + AssertionError: out[1]=False, but should be isinf(x[1])=True ... + + We can specify the expression template ourselves, e.g. + + >>> x = xp.asarray(True) + >>> out = xp.logical_not(x) + >>> out + True + >>> unary_assert_against_refimpl( + ... 'logical_not', x, out, expr_template='(not {})={}' + ... ) + AssertionError: out=True, but should be (not True)=False ... + + """ if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") if expr_template is None: @@ -194,6 +291,11 @@ def binary_assert_against_refimpl( res_name: str = "out", expr_template: Optional[str] = None, ): + """ + Assert binary element-wise results are as expected. + + See unary_assert_against_refimpl for more information. + """ if expr_template is None: expr_template = func_name + "({}, {})={}" in_stype = dh.get_scalar_type(left.dtype) @@ -244,6 +346,11 @@ def right_scalar_assert_against_refimpl( res_name: str = "out", expr_template: str = None, ): + """ + Assert binary element-wise results from scalar operands are as expected. + + See unary_assert_against_refimpl for more information. + """ if filter_(right): return # short-circuit here as there will be nothing to test in_stype = dh.get_scalar_type(left.dtype)