From 9481b20eeb13bd6ca984b3ec4f06fd138f892668 Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Tue, 26 Jun 2018 07:02:10 -0700 Subject: [PATCH 1/2] Remove duplication of categorical random method --- pymc3/distributions/discrete.py | 14 ++------------ pymc3/distributions/dist_math.py | 26 ++++++++++++++++++++++++++ pymc3/distributions/distribution.py | 1 - pymc3/distributions/mixture.py | 18 ++++++------------ 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index 801386a097..f0dd06d52f 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -5,7 +5,7 @@ import warnings from pymc3.util import get_variable_name -from .dist_math import bound, factln, binomln, betaln, logpow +from .dist_math import bound, factln, binomln, betaln, logpow, random_choice from .distribution import Discrete, draw_values, generate_samples from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp @@ -710,19 +710,9 @@ def __init__(self, p, *args, **kwargs): self.p = (p.T / tt.sum(p, -1)).T self.mode = tt.argmax(p) - def _random(self, k, p, size=None): - if len(p.shape) > 1: - return np.asarray( - [np.random.choice(k, p=pp, size=size) - for pp in p] - ) - else: - return np.asarray(np.random.choice(k, p=p, size=size)) - def random(self, point=None, size=None): p, k = draw_values([self.p, self.k], point=point, size=size) - return generate_samples(self._random, - k=k, + return generate_samples(random_choice, p=p, broadcast_shape=p.shape[:-1] or (1,), dist_shape=self.shape, diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index 7671f5b7fa..d3413aad0c 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -280,3 +280,29 @@ def impl(self, x): i0e = I0e(upgrade_to_float, name='i0e') + + +def random_choice(*args, **kwargs): + """Return draws from a categorial probability functions + + Args: + p: array + Probability of each class + size: int + Number of draws to return + k: int + Number of bins + + Returns: + random sample: array + + """ + p = kwargs.pop('p') + size = kwargs.pop('size') + k = p.shape[-1] + + if p.ndim > 1: + samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p]) + else: + samples = np.random.choice(k, p=p, size=size) + return samples \ No newline at end of file diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index c5b8040ecb..63a0cea363 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -214,7 +214,6 @@ def random(self, *args, **kwargs): "Define a custom random method and pass it as kwarg random") - def draw_values(params, point=None, size=None): """ Draw (fix) parameter values. Handles a number of cases: diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py index 1bdc3e8e9d..38f34d6c0a 100644 --- a/pymc3/distributions/mixture.py +++ b/pymc3/distributions/mixture.py @@ -3,7 +3,7 @@ from pymc3.util import get_variable_name from ..math import logsumexp -from .dist_math import bound +from .dist_math import bound, random_choice from .distribution import Discrete, Distribution, draw_values, generate_samples from .continuous import get_tau_sd, Normal @@ -147,24 +147,18 @@ def logp(self, value): broadcast_conditions=False) def random(self, point=None, size=None): - def random_choice(*args, **kwargs): - w = kwargs.pop('w') - w /= w.sum(axis=-1, keepdims=True) - k = w.shape[-1] - - if w.ndim > 1: - return np.row_stack([np.random.choice(k, p=w_) for w_ in w]) - else: - return np.random.choice(k, p=w, *args, **kwargs) - w = draw_values([self.w], point=point)[0] comp_tmp = self._comp_samples(point=point, size=None) if np.asarray(self.shape).size == 0: distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1] else: distshape = np.asarray(self.shape) + + # Normalize inputs + w /= w.sum(axis=-1, keepdims=True) + w_samples = generate_samples(random_choice, - w=w, + p=w, broadcast_shape=w.shape[:-1] or (1,), dist_shape=distshape, size=size).squeeze() From 217b70e6351a84aba19fd42f5ba5002d025254cc Mon Sep 17 00:00:00 2001 From: Ravin Kumar Date: Wed, 27 Jun 2018 21:08:09 -0700 Subject: [PATCH 2/2] Add newline to dist_math --- pymc3/distributions/dist_math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc3/distributions/dist_math.py b/pymc3/distributions/dist_math.py index d3413aad0c..bbe7ddbb9a 100644 --- a/pymc3/distributions/dist_math.py +++ b/pymc3/distributions/dist_math.py @@ -305,4 +305,5 @@ def random_choice(*args, **kwargs): samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p]) else: samples = np.random.choice(k, p=p, size=size) - return samples \ No newline at end of file + return samples +