Skip to content

Deduplicate random choice #3084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings

from pymc3.util import get_variable_name
from .dist_math import bound, factln, binomln, betaln, logpow
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
from .distribution import Discrete, draw_values, generate_samples
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp

Expand Down Expand Up @@ -710,19 +710,9 @@ def __init__(self, p, *args, **kwargs):
self.p = (p.T / tt.sum(p, -1)).T
self.mode = tt.argmax(p)

def _random(self, k, p, size=None):
if len(p.shape) > 1:
return np.asarray(
[np.random.choice(k, p=pp, size=size)
for pp in p]
)
else:
return np.asarray(np.random.choice(k, p=p, size=size))

def random(self, point=None, size=None):
p, k = draw_values([self.p, self.k], point=point, size=size)
return generate_samples(self._random,
k=k,
return generate_samples(random_choice,
p=p,
broadcast_shape=p.shape[:-1] or (1,),
dist_shape=self.shape,
Expand Down
27 changes: 27 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,30 @@ def impl(self, x):


i0e = I0e(upgrade_to_float, name='i0e')


def random_choice(*args, **kwargs):
"""Return draws from a categorial probability functions

Args:
p: array
Probability of each class
size: int
Number of draws to return
k: int
Number of bins

Returns:
random sample: array

"""
p = kwargs.pop('p')
size = kwargs.pop('size')
k = p.shape[-1]

if p.ndim > 1:
samples = np.row_stack([np.random.choice(k, p=p_) for p_ in p])
else:
samples = np.random.choice(k, p=p, size=size)
return samples

1 change: 0 additions & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def random(self, *args, **kwargs):
"Define a custom random method and pass it as kwarg random")



def draw_values(params, point=None, size=None):
"""
Draw (fix) parameter values. Handles a number of cases:
Expand Down
18 changes: 6 additions & 12 deletions pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pymc3.util import get_variable_name
from ..math import logsumexp
from .dist_math import bound
from .dist_math import bound, random_choice
from .distribution import Discrete, Distribution, draw_values, generate_samples
from .continuous import get_tau_sd, Normal

Expand Down Expand Up @@ -147,24 +147,18 @@ def logp(self, value):
broadcast_conditions=False)

def random(self, point=None, size=None):
def random_choice(*args, **kwargs):
w = kwargs.pop('w')
w /= w.sum(axis=-1, keepdims=True)
k = w.shape[-1]

if w.ndim > 1:
return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
else:
return np.random.choice(k, p=w, *args, **kwargs)

w = draw_values([self.w], point=point)[0]
comp_tmp = self._comp_samples(point=point, size=None)
if np.asarray(self.shape).size == 0:
distshape = np.asarray(np.broadcast(w, comp_tmp).shape)[..., :-1]
else:
distshape = np.asarray(self.shape)

# Normalize inputs
w /= w.sum(axis=-1, keepdims=True)

w_samples = generate_samples(random_choice,
w=w,
p=w,
broadcast_shape=w.shape[:-1] or (1,),
dist_shape=distshape,
size=size).squeeze()
Expand Down