Default moment for CustomDist provided with a dist function #6873
Changes from 36 commits
@@ -25,14 +25,16 @@
from pytensor import tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.basic import Node, Variable
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
from pytensor.graph.basic import Node, Variable, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.rewriting.basic import GraphRewriter, in2out
from pytensor.graph.utils import MetaType
from pytensor.scan.op import Scan
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.variable import TensorVariable
@@ -83,6 +85,59 @@
PLATFORM = sys.platform


class MomentRewrite(GraphRewriter):
    def rewrite_moment_scan_node(self, node):
        if not isinstance(node.op, Scan):
            return

        node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
        op = node.op

        local_fgraph_topo = io_toposort(node_inputs, node_outputs)

        replace_with_moment = []
        to_replace_set = set()

        for nd in local_fgraph_topo:
            if nd not in to_replace_set and isinstance(
                nd.op, (RandomVariable, SymbolicRandomVariable)
            ):
                replace_with_moment.append(nd.out)
                to_replace_set.add(nd)
        givens = {}
        if len(replace_with_moment) > 0:
            for item in replace_with_moment:
                givens[item] = moment(item)
        else:
            return
        op_outs = clone_replace(node_outputs, replace=givens)

        nwScan = Scan(
            node_inputs,
            op_outs,
            op.info,
            mode=op.mode,
            profile=op.profile,
            truncate_gradient=op.truncate_gradient,
            name=op.name,
            allow_gc=op.allow_gc,
        )
        nw_node = nwScan(*(node.inputs), return_list=True)[0].owner
        return nw_node

    def add_requirements(self, fgraph):
        fgraph.attach_feature(ReplaceValidate())

    def apply(self, fgraph):
        for node in fgraph.toposort():
            if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
                fgraph.replace(node.out, moment(node.out))
            elif isinstance(node.op, Scan):
                new_node = self.rewrite_moment_scan_node(node)
                if new_node is not None:
                    fgraph.replace_all(tuple(zip(node.outputs, new_node.outputs)))


class _Unpickling:
    pass
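For context, here is a minimal standalone sketch (not part of this diff) of the graph surgery that `MomentRewrite` performs: collect the random draws in topological order and substitute each one with its moment, so the resulting graph evaluates deterministically. It assumes `moment` is importable from `pymc.distributions.distribution`, the module this diff modifies.

```python
import pymc as pm
from pymc.distributions.distribution import moment
from pytensor.graph import clone_replace
from pytensor.graph.basic import io_toposort
from pytensor.tensor.random.op import RandomVariable

# A small graph containing one random draw.
y = pm.Normal.dist(1.0, 2.0) + 5.0

# Collect random-variable outputs in topological order, as the rewrite does.
rv_outs = [
    node.out for node in io_toposort([], [y]) if isinstance(node.op, RandomVariable)
]

# Swap each draw for its moment; evaluating the graph no longer samples.
(y_det,) = clone_replace([y], replace={rv: moment(rv) for rv in rv_outs})
print(y_det.eval())  # 6.0, since the moment of Normal(1, 2) is 1
```

The `GraphRewriter` in the diff does the same substitution through `fgraph.replace`, and additionally recurses into `Scan` inner graphs via `rewrite_moment_scan_node`.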
@@ -601,6 +656,20 @@ def update(self, node: Node):
        return updates


def dist_moment(rv, *args):
    node = rv.owner
    op = node.op

    rv_out_idx = node.outputs.index(rv)

    fgraph = op.fgraph.clone()
    replace_moments = MomentRewrite()
    replace_moments.rewrite(fgraph)
    # Replace dummy inner inputs by outer inputs
    fgraph.replace_all(tuple(zip(op.inner_inputs, node.inputs)), import_missing=True)

    moment = fgraph.outputs[rv_out_idx]
    return moment


class _CustomSymbolicDist(Distribution):
    rv_type = CustomSymbolicDistRV
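As a usage-level illustration (a hedged sketch, not taken from the PR's tests) of what `dist_moment` provides: a `CustomDist` defined only through a `dist` function now gets a default moment derived from its generative graph, so a finite initial point is available without the user passing `moment=`.

```python
import pymc as pm

def custom_dist(mu, sigma, size):
    # Generative graph built from an existing PyMC distribution.
    return 5.0 + pm.Normal.dist(mu, sigma, size=size)

with pm.Model() as model:
    x = pm.CustomDist("x", 0.0, 1.0, dist=custom_dist)  # no moment= passed

# The default moment rewrites the dist graph (Normal draw -> its mean),
# so the initial point for x is a deterministic value rather than a sample.
print(model.initial_point())
```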
@@ -623,12 +692,7 @@ def dist(
            logcdf = default_not_implemented(class_name, "logcdf")

        if moment is None:
            moment = functools.partial(
                default_moment,
                rv_name=class_name,
                has_fallback=True,
                ndim_supp=ndim_supp,
            )
            moment = dist_moment

Review discussion on this hunk:

- Remove this if statement (see other comment about default).
- @ricardoV94 Hm, I think we need this condition to avoid overriding a moment provided by the user. How can we avoid it without the condition?
- More specifically, if I try to remove the if statement
- There are two levels. We dispatch the general moment to the parent base class. Whenever a new subclass is created here and the user provided a moment, we register it on the subclass (so it has preference). If the user didn't provide anything, we don't register and the parent class one will be used.

        return super().dist(
            dist_params,
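To make the reviewer's "two levels" concrete, here is a toy sketch using plain `functools.singledispatch` (not PyMC's actual registration code): a registration on a subclass takes precedence over the parent-class default, which is why a user-provided moment can win without an `if` guard once the fallback lives on the base `rv_type`.

```python
from functools import singledispatch

class ParentRV: ...          # stands in for the base symbolic RV type
class UserRV(ParentRV): ...  # stands in for the per-CustomDist subclass

@singledispatch
def _moment(op):
    raise NotImplementedError

@_moment.register(ParentRV)
def _parent_moment(op):
    return "default moment derived from the dist graph"

# Registered only when the user passed moment=; otherwise the parent applies.
@_moment.register(UserRV)
def _user_moment(op):
    return "user-provided moment"

print(_moment(UserRV()))    # user-provided moment (subclass has preference)
print(_moment(ParentRV()))  # default moment derived from the dist graph
```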
@@ -685,9 +749,19 @@ def custom_dist_logp(op, values, size, *params, **kwargs):
        def custom_dist_logcdf(op, value, size, *params, **kwargs):
            return logcdf(value, *params[: len(dist_params)])

        @_moment.register(rv_type)
        def custom_dist_get_moment(op, rv, size, *params):
            return moment(rv, size, *params[: len(params)])
        if moment is not None:

            @_moment.register(rv_type)
            def custom_dist_get_moment(op, rv, size, *params):
                return moment(
                    rv,
                    size,
                    *[
                        p
                        for p in params
                        if not isinstance(p.type, (RandomType, RandomGeneratorType))
                    ],
                )

        @_change_dist_size.register(rv_type)
        def change_custom_symbolic_dist_size(op, rv, new_size, expand):
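When a `moment` is supplied by the user, the registration above forwards the symbolic `rv`, the `size`, and only the tensor-valued parameters, filtering out RNG inputs via the `RandomType`/`RandomGeneratorType` check. A hedged sketch of the resulting callback signature (names are illustrative):

```python
import pymc as pm

def custom_dist(mu, sigma, size):
    return 5.0 + pm.Normal.dist(mu, sigma, size=size)

def custom_moment(rv, size, mu, sigma):
    # Only the distribution parameters arrive here; rng/generator inputs
    # on the underlying symbolic RV are stripped by the filter above.
    return mu + 5.0

with pm.Model():
    x = pm.CustomDist("x", 0.0, 1.0, dist=custom_dist, moment=custom_moment)
```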