From 8c535d86cc19e6c10f0264f78b3a41eb1c3ec747 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Jul 2023 13:59:15 +0200 Subject: [PATCH 1/4] Remove newlines from Constant __str__ --- pytensor/graph/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 90d054f6ba..0f96556475 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -773,7 +773,7 @@ def signature(self): return (self.type, self.data) def __str__(self): - data_str = str(self.data) + data_str = str(self.data).replace("\n", "") if len(data_str) > 20: data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip() From f5be5627a398bf3fde0922de6b60b19380bb4f1a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Jul 2023 20:00:41 +0200 Subject: [PATCH 2/4] Remove unused variable --- pytensor/graph/rewriting/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index 8bf8de87bb..63cc436396 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -45,7 +45,6 @@ def rewrite_graph( return_fgraph = False if isinstance(graph, FunctionGraph): - outputs: Sequence[Variable] = graph.outputs fgraph = graph return_fgraph = True else: From 2f0ed25b7ac31334aee49e3c4a16c33094cb8b90 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Jul 2023 15:42:49 +0200 Subject: [PATCH 3/4] Fix bug in Dimshuffles created by Elemwise --- pytensor/tensor/elemwise.py | 7 +++++-- tests/tensor/test_elemwise.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 6d19579030..0687d30f10 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -130,6 +130,10 @@ def __init__(self, input_broadcastable, new_order): super().__init__([self.c_func_file], self.c_func_name) self.input_broadcastable = tuple(input_broadcastable) + if not all(isinstance(bs, (bool, np.bool_)) for bs in self.input_broadcastable): + raise ValueError( + f"input_broadcastable must be boolean, {self.input_broadcastable}" + ) self.new_order = tuple(new_order) self.inplace = True @@ -411,10 +415,9 @@ def get_output_info(self, dim_shuffle, *inputs): if not difference: args.append(input) else: - # TODO: use LComplete instead args.append( dim_shuffle( - tuple(1 if s == 1 else None for s in input.type.shape), + input.type.broadcastable, ["x"] * difference + list(range(length)), )(input) ) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 40e7db879c..fa820f062a 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -188,6 +188,12 @@ def test_static_shape(self): y = x.dimshuffle([0, 1, "x"]) assert y.type.shape == (1, 2, 1) + def test_valid_input_broadcastable(self): + assert DimShuffle([True, False], (1, 0)).input_broadcastable == (True, False) + + with pytest.raises(ValueError, match="input_broadcastable must be boolean"): + DimShuffle([None, None], (1, 0)) + class TestBroadcast: # this is to allow other types to reuse this class to test their ops From 250fb200e2653b9ad02b8181544e32c350768f92 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Jul 2023 19:03:41 +0200 Subject: [PATCH 4/4] Fix type check in local_pow_specialize --- pytensor/tensor/rewriting/math.py | 5 ++++- tests/tensor/rewriting/test_math.py | 14 +++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b0d124f1c8..46b98a1575 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2072,7 +2072,10 @@ def local_pow_specialize(fgraph, node): rval = [reciprocal(sqr(xsym))] if rval: rval[0] = cast(rval[0], odtype) - assert rval[0].type == node.outputs[0].type, (rval, node.outputs) + assert rval[0].type.is_super(node.outputs[0].type), ( + rval[0].type, + node.outputs[0].type, + ) return rval else: return False diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index f69879a51d..4e9d143a8e 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -96,7 +96,7 @@ perform_sigm_times_exp, simplify_mul, ) -from pytensor.tensor.shape import Reshape, Shape_i +from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape from pytensor.tensor.type import ( TensorType, cmatrix, @@ -1671,6 +1671,18 @@ def test_local_pow_specialize(): assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal) utt.assert_allclose(f(val_no0), val_no0 ** (-0.5)) + twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX) + f = function([v], v**twos, mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 + # Depending on the mode the SpecifyShape is lifted or not + if topo[0].op == sqr: + assert isinstance(topo[1].op, SpecifyShape) + else: + assert isinstance(topo[0].op, SpecifyShape) + assert topo[1].op == sqr + utt.assert_allclose(f(val), val**twos) + def test_local_pow_specialize_device_more_aggressive_on_cpu(): mode = config.mode