From 3982ad954a876dea0f02de011dd2eaa7e2cd0c77 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jul 2023 11:34:26 +0200 Subject: [PATCH 1/4] Use fstrings in Alloc.c_code --- pytensor/tensor/basic.py | 44 +++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 1f43d5531d..5837fc3435 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1494,36 +1494,30 @@ def c_code(self, node, name, inp, out, sub): # Initialize shape for i, shp_i in enumerate(inp[1:]): - code += """ - shape[%(i)s] = ((dtype_%(shp_i)s*) PyArray_DATA(%(shp_i)s))[0]; - """ % dict( - i=i, shp_i=shp_i - ) + code += f""" + shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0]; + """ - code += """ - int need_new_out = (NULL == %(zz)s); - for (int i = 0; i < %(ndim)s; i++) - need_new_out = (need_new_out - || (PyArray_DIMS(%(zz)s)[i] != shape[i])); + code += f""" + int need_new_out = (NULL == {zz}); + for (int i = 0; i < {ndim}; i++) + need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i])); if (need_new_out) - { - Py_XDECREF(%(zz)s); - %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s, - shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s)); - if (!%(zz)s) - { + {{ + Py_XDECREF({zz}); + {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE((PyArrayObject*) py_{vv})); + if (!{zz}) + {{ PyErr_SetString(PyExc_MemoryError, "alloc failed"); - %(fail)s - } - } + {fail} + }} + }} // This function takes care of broadcasting - if (PyArray_CopyInto(%(zz)s, %(vv)s) == -1) - %(fail)s - """ % dict( - vv=vv, ndim=ndim, zz=zz, fail=fail - ) + if (PyArray_CopyInto({zz}, {vv}) == -1) + {fail} + """ return code @@ -1568,7 +1562,7 @@ def grad(self, inputs, grads): for idx, axis in enumerate(axis_kept): new_order[axis] = idx gx = gx.dimshuffle(new_order) - # Dimshuffle to add back the broadcasted dims + # Dimshuffle to add back the broadcasted dims # The *elements* of the output are not connected to # the inputs that specify the shape. 
If you grow the # shape by epsilon, the existing elements do not From 38c183cf437df37dd864b101b887141177c961a2 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jul 2023 11:25:47 +0200 Subject: [PATCH 2/4] Fix segfault bug in Alloc C implementation --- pytensor/tensor/basic.py | 4 ++-- tests/tensor/test_basic.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5837fc3435..b840476ded 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1506,7 +1506,7 @@ def c_code(self, node, name, inp, out, sub): if (need_new_out) {{ Py_XDECREF({zz}); - {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE((PyArrayObject*) py_{vv})); + {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv})); if (!{zz}) {{ PyErr_SetString(PyExc_MemoryError, "alloc failed"); @@ -1522,7 +1522,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (2,) + return (3,) def infer_shape(self, fgraph, node, input_shapes): return [node.inputs[1:]] diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index bb2e918278..b1a0e4b440 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -14,7 +14,7 @@ from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.ops import DeepCopyOp from pytensor.gradient import grad, hessian -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, equal_computations from pytensor.graph.op import Op from pytensor.graph.replace import clone_replace from pytensor.misc.safe_asarray import _asarray @@ -720,6 +720,7 @@ class TestAlloc: shared = staticmethod(pytensor.shared) allocs = [Alloc()] * 3 + def setup_method(self): self.rng = np.random.default_rng(seed=utt.fetch_seed()) @@ -851,6 +852,19 @@ def test_static_shape(self): with pytest.raises(ValueError, match=msg): at.alloc(x, 3, 1, 6) + def test_alloc_of_view_linker(self): + """Check we can allocate a new array properly in the C linker when input is a view.""" + x_v = vector("x", shape=(None,)) + dim_len = scalar("dim_len", dtype=int) + out = alloc(specify_shape(x_v, (1,)), 5, dim_len) + + f = pytensor.function([x_v, dim_len], out, mode=Mode("c")) + assert equal_computations( + f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)] + ) + + np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3))) + def test_infer_shape(): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): From 8f5d5fa9e711dabedd443c508752619981c05eec Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jul 2023 12:23:14 +0200 Subject: [PATCH 3/4] Rename Elemwise.check_runtime_broadcast --- tests/link/jax/test_elemwise.py | 4 ++-- tests/link/numba/test_elemwise.py | 4 ++-- tests/tensor/test_elemwise.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py index dab0a750b9..7d13d1aaf7 100644 --- a/tests/link/jax/test_elemwise.py +++ b/tests/link/jax/test_elemwise.py @@ -18,8 +18,8 @@ from tests.tensor.test_elemwise import TestElemwise -def test_elemwise_runtime_shape_error(): - TestElemwise.check_runtime_shapes_error(get_mode("JAX")) +def test_elemwise_runtime_broadcast(): + TestElemwise.check_runtime_broadcast(get_mode("JAX")) def test_jax_Dimshuffle(): diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 
87e5968d4e..ffc7967c83 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -122,8 +122,8 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults") -def test_elemwise_runtime_shape_error(): - TestElemwise.check_runtime_shapes_error(get_mode("NUMBA")) +def test_elemwise_runtime_broadcast(): + TestElemwise.check_runtime_broadcast(get_mode("NUMBA")) def test_elemwise_speed(benchmark): diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 52f34c7be0..9f0c3a5976 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -751,7 +751,7 @@ def test_input_dimensions_overflow(self): g(*[np.zeros(2**11, config.floatX) for i in range(6)]) @staticmethod - def check_runtime_shapes_error(mode): + def check_runtime_broadcast(mode): """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" x_v = matrix("x") m_v = vector("m") @@ -777,15 +777,15 @@ def check_runtime_shapes_error(mode): with pytest.raises((ValueError, TypeError)): f(x, m) - def test_runtime_shapes_error_python(self): - self.check_runtime_shapes_error(Mode(linker="py")) + def test_runtime_broadcast_python(self): + self.check_runtime_broadcast(Mode(linker="py")) @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test.", ) - def test_runtime_shapes_error_c(self): - self.check_runtime_shapes_error(Mode(linker="c")) + def test_runtime_broadcast_c(self): + self.check_runtime_broadcast(Mode(linker="c")) def test_str(self): op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) From ca9082abb8c932b69c1129752d6409fce270a7db Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 19 Jul 2023 12:01:47 +0200 Subject: [PATCH 4/4] Forbid runtime broadcasting in Alloc --- pytensor/link/jax/dispatch/tensor_basic.py | 3 +- pytensor/link/numba/dispatch/tensor_basic.py | 10 ++++- pytensor/tensor/basic.py | 47 +++++++++++++++++--- tests/link/jax/test_tensor_basic.py | 7 +++ tests/link/numba/test_tensor_basic.py | 6 +++ tests/tensor/test_basic.py | 43 +++++++++++++++++- 6 files changed, 107 insertions(+), 9 deletions(-) diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index a6b9ee988a..bae86452a7 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -51,9 +51,10 @@ def allocempty(*shape): @jax_funcify.register(Alloc) -def jax_funcify_Alloc(op, **kwargs): +def jax_funcify_Alloc(op, node, **kwargs): def alloc(x, *shape): res = jnp.broadcast_to(x, shape) + Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape) return res return alloc diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index aa063e46dc..6bbf0cd0c7 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -78,16 +78,24 @@ def numba_funcify_Alloc(op, node, **kwargs): " " * 4, ) + check_runtime_broadcast = [] + for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]): + if val_static_dim is None: + check_runtime_broadcast.append( + f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")' + ) + check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4) + alloc_def_src = f""" def alloc(val, {", ".join(shape_var_names)}): val_np = 
np.asarray(val) {shapes_to_items_src} scalar_shape = {create_tuple_string(shape_var_item_names)} +{check_runtime_broadcast_src} res = np.empty(scalar_shape, dtype=val_np.dtype) res[...] = val_np return res """ - alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) return numba_basic.numba_njit(alloc_fn) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index b840476ded..f77ebe7687 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1431,6 +1431,12 @@ class Alloc(COp): __props__ = () + _runtime_broadcast_error_msg = ( + "Runtime broadcasting not allowed. " + "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. " + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." + ) + def make_node(self, value, *shape): value = as_tensor_variable(value) shape, static_shape = infer_static_shape(shape) @@ -1468,10 +1474,21 @@ def make_node(self, value, *shape): otype = TensorType(dtype=value.dtype, shape=combined_static_shape) return Apply(self, [value] + shape, [otype()]) + @staticmethod + def _check_runtime_broadcast(node, value, shape): + value_static_shape = node.inputs[0].type.shape + for v_static_dim, value_dim, out_dim in zip( + value_static_shape[::-1], value.shape[::-1], shape[::-1] + ): + if v_static_dim is None and value_dim == 1 and out_dim != 1: + raise ValueError(Alloc._runtime_broadcast_error_msg) + def perform(self, node, inputs, out_): (out,) = out_ v = inputs[0] sh = tuple([int(i) for i in inputs[1:]]) + self._check_runtime_broadcast(node, v, sh) + if out[0] is None or out[0].shape != sh: if v.size == 1 and v.item() == 0: out[0] = np.zeros(sh, dtype=v.dtype) @@ -1484,12 +1501,19 @@ def perform(self, node, inputs, out_): def c_code(self, node, name, inp, out, sub): vv = inp[0] - ndim = len(inp[1:]) (zz,) = out fail = sub["fail"] + v_static_shape = node.inputs[0].type.shape + o_static_shape = node.outputs[0].type.shape + v_ndim = len(v_static_shape) + o_ndim = len(o_static_shape) + assert o_ndim == len(inp[1:]) + + # Declare variables code = f""" - npy_intp shape[{ndim}]; + npy_intp shape[{o_ndim}]; + int need_new_out; """ # Initialize shape @@ -1498,15 +1522,26 @@ def c_code(self, node, name, inp, out, sub): shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0]; """ + # Add checks for runtime broadcasting + for i, v_static_dim in enumerate(v_static_shape[::-1]): + if v_static_dim is None: + code += f""" + if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1) + {{ + PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}"); + {fail} + }} + """ + code += f""" - int need_new_out = (NULL == {zz}); - for (int i = 0; i < {ndim}; i++) + need_new_out = (NULL == {zz}); + for (int i = 0; i < {o_ndim}; i++) need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i])); if (need_new_out) {{ Py_XDECREF({zz}); - {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv})); + {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv})); if (!{zz}) {{ PyErr_SetString(PyExc_MemoryError, "alloc failed"); @@ -1522,7 +1557,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (3,) + return (4,) def infer_shape(self, fgraph, node, input_shapes): return [node.inputs[1:]] diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py index 29dfd152e3..f837c8224a 100644 --- 
a/tests/link/jax/test_tensor_basic.py +++ b/tests/link/jax/test_tensor_basic.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from pytensor.compile import get_mode + jax = pytest.importorskip("jax") import jax.errors @@ -12,6 +14,7 @@ from pytensor.graph.op import get_test_value from pytensor.tensor.type import iscalar, matrix, scalar, vector from tests.link.jax.test_basic import compare_jax_and_py +from tests.tensor.test_basic import TestAlloc def test_jax_Alloc(): @@ -50,6 +53,10 @@ def compare_shape_dtype(x, y): compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) +def test_alloc_runtime_broadcast(): + TestAlloc.check_runtime_broadcast(get_mode("JAX")) + + def test_jax_MakeVector(): x = at.make_vector(1, 2, 3) x_fg = FunctionGraph([], [x]) diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 047bc18a98..cd7a540c28 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -5,6 +5,7 @@ import pytensor.tensor as at import pytensor.tensor.basic as atb from pytensor import config, function +from pytensor.compile import get_mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph @@ -15,6 +16,7 @@ compare_shape_dtype, set_test_value, ) +from tests.tensor.test_basic import TestAlloc rng = np.random.default_rng(42849) @@ -45,6 +47,10 @@ def test_Alloc(v, shape): assert numba_res.shape == shape +def test_alloc_runtime_broadcast(): + TestAlloc.check_runtime_broadcast(get_mode("NUMBA")) + + def test_AllocEmpty(): x = at.empty((2, 3), dtype="float32") x_fg = FunctionGraph([], [x]) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index b1a0e4b440..fb68f2d939 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -720,6 +720,38 @@ class TestAlloc: shared = staticmethod(pytensor.shared) allocs = [Alloc()] * 3 + @staticmethod + def check_allocs_in_fgraph(fgraph, n): + assert ( + len([node for node in fgraph.apply_nodes if isinstance(node.op, Alloc)]) + == n + ) + + @staticmethod + def check_runtime_broadcast(mode): + """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" + floatX = config.floatX + x_v = vector("x", shape=(None,)) + + out = alloc(x_v, 5, 3) + f = pytensor.function([x_v], out, mode=mode) + TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) + + np.testing.assert_array_equal( + f(x=np.zeros((3,), dtype=floatX)), + np.zeros((5, 3), dtype=floatX), + ) + with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): + f(x=np.zeros((1,), dtype=floatX)) + + out = alloc(specify_shape(x_v, (1,)), 5, 3) + f = pytensor.function([x_v], out, mode=mode) + TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1) + + np.testing.assert_array_equal( + f(x=np.zeros((1,), dtype=floatX)), + np.zeros((5, 3), dtype=floatX), + ) def setup_method(self): self.rng = np.random.default_rng(seed=utt.fetch_seed()) @@ -854,6 +886,8 @@ def test_static_shape(self): def test_alloc_of_view_linker(self): """Check we can allocate a new array properly in the C linker when input is a view.""" + floatX = config.floatX + x_v = vector("x", shape=(None,)) dim_len = scalar("dim_len", dtype=int) out = alloc(specify_shape(x_v, (1,)), 5, dim_len) @@ -863,7 +897,14 @@ def test_alloc_of_view_linker(self): f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)] ) - np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3))) + 
np.testing.assert_array_equal( + f(x=np.zeros((1,), dtype=floatX), dim_len=3), + np.zeros((5, 3), dtype=floatX), + ) + + @pytest.mark.parametrize("mode", (Mode("py"), Mode("c"))) + def test_runtime_broadcast(self, mode): + self.check_runtime_broadcast(mode) def test_infer_shape():
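
Illustration of the user-facing effect of the last patch, mirroring the new
TestAlloc.check_runtime_broadcast test. This is a minimal sketch and not part of the
patches themselves: it assumes a PyTensor build with this series applied, and the
names x, f and g exist only for the example.

    # Sketch only: assumes these patches are applied; names are illustrative.
    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor import config

    x = pt.vector("x", shape=(None,))  # length unknown at compile time

    # Alloc the value into a (5, 3) output.
    f = pytensor.function([x], pt.alloc(x, 5, 3))
    print(f(np.zeros(3, dtype=config.floatX)).shape)  # (5, 3): value already matches

    try:
        # Length-1 at runtime, but the dimension is not typed as broadcastable.
        f(np.zeros(1, dtype=config.floatX))
    except ValueError as err:
        print(err)  # "Runtime broadcasting not allowed. ..."

    # Broadcasting is still fine when the length-1 dimension is statically known:
    g = pytensor.function([x], pt.alloc(pt.specify_shape(x, (1,)), 5, 3))
    print(g(np.zeros(1, dtype=config.floatX)).shape)  # (5, 3)

Only silent runtime broadcasting of dimensions whose static length is unknown is
rejected; declaring the length-1 dimension in the type (e.g. via specify_shape or
specify_broadcastable) keeps the old broadcasting behaviour.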