diff --git a/pymc3/distributions/discrete.py b/pymc3/distributions/discrete.py index c72af9b34a..e3462e6b8c 100644 --- a/pymc3/distributions/discrete.py +++ b/pymc3/distributions/discrete.py @@ -48,7 +48,7 @@ def __init__(self, n, p, *args, **kwargs): self.mode = tt.cast(tround(n * p), self.dtype) def random(self, point=None, size=None, repeat=None): - n, p = draw_values([self.n, self.p], point=point) + n, p = draw_values([self.n, self.p], point=point, size=size) return generate_samples(stats.binom.rvs, n=n, p=p, dist_shape=self.shape, size=size) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index c0627254fb..0e7ba4f3c9 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -186,7 +186,7 @@ def __init__(self, logp, shape=(), dtype=None, testval=0, *args, **kwargs): self.logp = logp -def draw_values(params, point=None): +def draw_values(params, point=None, size=None): """ Draw (fix) parameter values. Handles a number of cases: @@ -215,10 +215,10 @@ def draw_values(params, point=None): for name, node in named_nodes.items(): if not isinstance(node, (tt.sharedvar.SharedVariable, tt.TensorConstant)): - givens[name] = (node, _draw_value(node, point=point)) + givens[name] = (node, _draw_value(node, point=point, size=size)) values = [] for param in params: - values.append(_draw_value(param, point=point, givens=givens.values())) + values.append(_draw_value(param, point=point, givens=givens.values(), size=size)) return values @@ -247,7 +247,7 @@ def _compile_theano_function(param, vars, givens=None): allow_input_downcast=True) -def _draw_value(param, point=None, givens=None): +def _draw_value(param, point=None, givens=None, size=None): """Draw a random value from a distribution or return a constant. Parameters @@ -263,6 +263,8 @@ def _draw_value(param, point=None, givens=None): givens : dict, optional A dictionary from theano variables to their values. These values are used to evaluate `param` if it is a theano variable. + size: int, optional + Number of sampling """ if isinstance(param, numbers.Number): return param @@ -276,7 +278,7 @@ def _draw_value(param, point=None, givens=None): if point and hasattr(param, 'model') and param.name in point: return point[param.name] elif hasattr(param, 'random') and param.random is not None: - return param.random(point=point, size=None) + return param.random(point=point, size=size).mean(axis=0) else: if givens: variables, values = list(zip(*givens))