-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add model_to_minibatch
transformation to convert all pm.Data
to pm.Minibatch
#7785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,17 +14,21 @@ | |
from collections.abc import Sequence | ||
|
||
from pytensor import Variable, clone_replace | ||
from pytensor.compile import SharedVariable | ||
from pytensor.graph import ancestors | ||
from pytensor.graph.fg import FunctionGraph | ||
|
||
from pymc.data import MinibatchOp | ||
from pymc.data import Minibatch, MinibatchOp | ||
from pymc.model.core import Model | ||
from pymc.model.fgraph import ( | ||
ModelObservedRV, | ||
ModelVar, | ||
extract_dims, | ||
fgraph_from_model, | ||
model_from_fgraph, | ||
model_observed_rv, | ||
) | ||
from pymc.pytensorf import toposort_replace | ||
|
||
ModelVariable = Variable | str | ||
|
||
|
@@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l | |
return [model[var] if isinstance(var, str) else var for var in vars_seq] | ||
|
||
|
||
def model_to_minibatch(model: Model, batch_size: int) -> Model: | ||
"""Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" | ||
from pymc.variational.minibatch_rv import create_minibatch_rv | ||
|
||
fgraph, memo = fgraph_from_model(model, inlined_views=True) | ||
|
||
# obs_rvs, data_vars = model.rvs_to_values.items() | ||
|
||
data_vars = [ | ||
memo[datum].owner.inputs[0] | ||
for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) | ||
if isinstance(datum, SharedVariable) | ||
] | ||
|
||
minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) | ||
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} | ||
assert 0 | ||
# Add total_size to all observed RVs | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should only add to those that depend on the minibatch data, no? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The correct thing would be a dim analysis like we do for MarginalModel to confirm the first dim of the data maps to the first dim of the observed RVs, which is when the rewrite is valid. We may not want to do that, but we should be clear about the assumptions in the docstrings. An example where the minibatch rewrite will fail / do the wrong thing is if you transpose the data before you use it in the observations. |
||
total_size = data_vars[0].get_value().shape[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Total size can be symbolic, I think? |
||
for obs_var in model.observed_RVs: | ||
model_var = memo[obs_var] | ||
var = model_var.owner.inputs[0] | ||
var.name = model_var.name | ||
dims = extract_dims(model_var) | ||
|
||
new_rv = create_minibatch_rv(var, total_size=total_size) | ||
new_rv.name = var.name | ||
|
||
replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) | ||
|
||
# old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths | ||
toposort_replace(fgraph, tuple(replacements.items())) | ||
# new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] | ||
|
||
# fgraph = FunctionGraph(outputs=new_outs, clone=False) | ||
# fgraph._coords = old_coords # type: ignore[attr-defined] | ||
# fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] | ||
|
||
return model_from_fgraph(fgraph, mutate_fgraph=True) | ||
|
||
|
||
def remove_minibatched_nodes(model: Model) -> Model: | ||
"""Remove all uses of pm.Minibatch in the Model.""" | ||
fgraph, _ = fgraph_from_model(model) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a
model.data_vars
. You should, however, allow users to specify which data vars are to be minibatched (defaulting to all is fine). Alternatively, we could restrict this to models with dims, and the user has to tell us which dim is being minibatched? That makes the graph analysis easier.