From 0f4fc07272e69a43eca00eb468cd5954812b0f79 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Sun, 15 Jul 2018 09:18:11 -0700 Subject: [PATCH] Parametrize shape and size tests --- pymc3/tests/test_distributions_random.py | 106 +++++++++++------------ 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index bf573081a3..af87785db5 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -145,73 +145,73 @@ def sample_random_variable(random_variable, size): except AttributeError: return random_variable.distribution.random(size=size) - def test_scalar_parameter_shape(self): + @pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str) + def test_scalar_parameter_shape(self, size): rv = self.get_random_variable(None) - for size in (None, 5, (4, 5)): - if size is None: - expected = 1, - else: - expected = np.atleast_1d(size).tolist() - actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape - assert tuple(expected) == actual + if size is None: + expected = 1, + else: + expected = np.atleast_1d(size).tolist() + actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape + assert tuple(expected) == actual - def test_scalar_shape(self): + @pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str) + def test_scalar_shape(self, size): shape = 10 rv = self.get_random_variable(shape) - for size in (None, 5, (4, 5)): - if size is None: - expected = [] - else: - expected = np.atleast_1d(size).tolist() - expected.append(shape) - actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape - assert tuple(expected) == actual - def test_parameters_1d_shape(self): + if size is None: + expected = [] + else: + expected = np.atleast_1d(size).tolist() + expected.append(shape) + actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape + assert tuple(expected) == actual + + @pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str) + def test_parameters_1d_shape(self, size): rv = self.get_random_variable(self.shape, with_vector_params=True) - for size in (None, 5, (4, 5)): - if size is None: - expected = [] - else: - expected = np.atleast_1d(size).tolist() - expected.append(self.shape) - actual = self.sample_random_variable(rv, size).shape - assert tuple(expected) == actual + if size is None: + expected = [] + else: + expected = np.atleast_1d(size).tolist() + expected.append(self.shape) + actual = self.sample_random_variable(rv, size).shape + assert tuple(expected) == actual - def test_broadcast_shape(self): + @pytest.mark.parametrize('size', [None, 5, (4, 5)], ids=str) + def test_broadcast_shape(self, size): broadcast_shape = (2 * self.shape, self.shape) rv = self.get_random_variable(broadcast_shape, with_vector_params=True) - for size in (None, 5, (4, 5)): - if size is None: - expected = [] - else: - expected = np.atleast_1d(size).tolist() - expected.extend(broadcast_shape) - actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape - assert tuple(expected) == actual + if size is None: + expected = [] + else: + expected = np.atleast_1d(size).tolist() + expected.extend(broadcast_shape) + actual = np.atleast_1d(self.sample_random_variable(rv, size)).shape + assert tuple(expected) == actual - def test_different_shapes_and_sample_sizes(self): - shapes = [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)] + @pytest.mark.parametrize('shape', [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)], ids=str) + def test_different_shapes_and_sample_sizes(self, shape): prefix = self.distribution.__name__ expected = [] actual = [] - for shape in shapes: - rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape)) - for size in (None, 1, 5, (4, 5)): - if size is None: + rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape)) + for size in (None, 1, 5, (4, 5)): + if size is None: + s = [] + else: + try: + s = list(size) + except TypeError: + s = [size] + if s == [1]: s = [] - else: - try: - s = list(size) - except TypeError: - s = [size] - if s == [1]: - s = [] - if shape not in ((), (1,)): - s.extend(shape) - e = tuple(s) - a = self.sample_random_variable(rv, size).shape - assert e == a + if shape not in ((), (1,)): + s.extend(shape) + e = tuple(s) + a = self.sample_random_variable(rv, size).shape + assert e == a class TestNormal(BaseTestCases.BaseTestCase):