diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index f0dd06d52f..86db441ae1 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -712,6 +712,7 @@ def __init__(self, p, *args, **kwargs): def random(self, point=None, size=None): p, k = draw_values([self.p, self.k], point=point, size=size) + return generate_samples(random_choice, p=p, broadcast_shape=p.shape[:-1] or (1,), diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index bbe7ddbb9a..ec87d86573 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -302,7 +302,8 @@ def random_choice(*args, **kwargs): k = p.shape[-1] if p.ndim > 1: - samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p]) + # If a 2d vector of probabilities is passed return a sample for each row of categorical probability + samples = np.array([np.random.choice(k, p=p_) for p_ in p]) else: samples = np.random.choice(k, p=p, size=size) return samples diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index af87785db5..b249b8d836 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -194,8 +194,7 @@ def test_broadcast_shape(self, size): @pytest.mark.parametrize('shape', [(), (1,), (1, 1), (1, 2), (10, 10, 1), (10, 10, 2)], ids=str) def test_different_shapes_and_sample_sizes(self, shape): prefix = self.distribution.__name__ - expected = [] - actual = [] + rv = self.get_random_variable(shape, name='%s_%s' % (prefix, shape)) for size in (None, 1, 5, (4, 5)): if size is None: @@ -402,6 +401,11 @@ class TestCategorical(BaseTestCases.BaseTestCase): def get_random_variable(self, shape, with_vector_params=False, **kwargs): # don't transform categories return super(TestCategorical, self).get_random_variable(shape, with_vector_params=False, **kwargs) + def test_probability_vector_shape(self): + """Check that if a 2d array of probabilities are passed to categorical correct shape is returned""" + p = np.ones((10, 5)) + assert pm.Categorical.dist(p=p).random().shape == (10,) + class TestScalarParameterSamples(SeededTest): def test_bounded(self):