diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 8b15ad4dda..c0ddf7e894 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -555,12 +555,15 @@ def test_outputs_consistency(self): def test_explicit_input_from_constant(self): x = pt.dscalar("x") - y = constant(1.0, name="y") + y = constant(1.0, dtype=x.type.dtype, name="y") test_ofg = OpFromGraph([x, y], [x + y]) out = test_ofg(x, y) assert out.eval({x: 5}) == 6 + out = test_ofg(x, x) + assert out.eval({x: 5}) == 10 + def test_explicit_input_from_shared(self): x = pt.dscalar("x") y = shared(1.0, name="y") @@ -576,7 +579,10 @@ def test_explicit_input_from_shared(self): out = test_ofg(x, y) assert out.eval({x: 5}) == 6 y.set_value(2.0) - assert out.eval({x: 6}) + assert out.eval({x: 6}) == 8 + + out = test_ofg(y, y) + assert out.eval() == 4 @config.change_flags(floatX="float64")