Skip to content

Commit ce59552

Browse files
committed
Do not monkey-patch Ipython pretty representation on model variables
1 parent 8ebb61e commit ce59552

File tree

4 files changed

+42
-100
lines changed

4 files changed

+42
-100
lines changed

pymc/distributions/distribution.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import contextvars
15-
import functools
1615
import re
1716
import sys
18-
import types
1917
import warnings
2018

2119
from abc import ABCMeta
@@ -53,7 +51,6 @@
5351
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
5452
from pymc.logprob.basic import logp
5553
from pymc.logprob.rewriting import logprob_rewrites_db
56-
from pymc.printing import str_for_dist
5754
from pymc.pytensorf import (
5855
collect_default_updates_inner_fgraph,
5956
constant_fold,
@@ -514,12 +511,6 @@ def __new__(
514511
default_transform=default_transform,
515512
initval=initval,
516513
)
517-
518-
# add in pretty-printing support
519-
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
520-
rv_out._repr_latex_ = types.MethodType(
521-
functools.partial(str_for_dist, formatting="latex"), rv_out
522-
)
523514
return rv_out
524515

525516
@classmethod

pymc/model/core.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import functools
1717
import sys
1818
import threading
19-
import types
2019
import warnings
2120

2221
from collections.abc import Iterable, Sequence
@@ -491,13 +490,6 @@ def __init__(
491490
self._dim_lengths = {}
492491
self.add_coords(coords)
493492

494-
from pymc.printing import str_for_model
495-
496-
self.str_repr = types.MethodType(str_for_model, self)
497-
self._repr_latex_ = types.MethodType(
498-
functools.partial(str_for_model, formatting="latex"), self
499-
)
500-
501493
@classmethod
502494
def get_context(
503495
cls, error_if_none: bool = True, allow_block_model_access: bool = False
@@ -2014,6 +2006,19 @@ def to_graphviz(
20142006
dpi=dpi,
20152007
)
20162008

2009+
def _repr_pretty_(self, p, cycle):
2010+
from pymc.printing import str_for_model
2011+
2012+
output = str_for_model(self)
2013+
# Find newlines and replace them with p.break_()
2014+
# (see IPython.lib.pretty._repr_pprint)
2015+
lines = output.splitlines()
2016+
with p.group():
2017+
for idx, output_line in enumerate(lines):
2018+
if idx:
2019+
p.break_()
2020+
p.text(output_line)
2021+
20172022

20182023
class BlockModelAccess(Model):
20192024
"""Can be used to prevent user access to Model contexts."""
@@ -2240,19 +2245,6 @@ def Deterministic(name, var, model=None, dims=None):
22402245
var = var.copy(model.name_for(name))
22412246
model.deterministics.append(var)
22422247
model.add_named_variable(var, dims)
2243-
2244-
from pymc.printing import str_for_potential_or_deterministic
2245-
2246-
var.str_repr = types.MethodType(
2247-
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
2248-
)
2249-
var._repr_latex_ = types.MethodType(
2250-
functools.partial(
2251-
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
2252-
),
2253-
var,
2254-
)
2255-
22562248
return var
22572249

22582250

@@ -2365,16 +2357,4 @@ def normal_logp(value, mu, sigma):
23652357
model.potentials.append(var)
23662358
model.add_named_variable(var, dims)
23672359

2368-
from pymc.printing import str_for_potential_or_deterministic
2369-
2370-
var.str_repr = types.MethodType(
2371-
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
2372-
)
2373-
var._repr_latex_ = types.MethodType(
2374-
functools.partial(
2375-
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
2376-
),
2377-
var,
2378-
)
2379-
23802360
return var

pymc/printing.py

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535

3636

3737
def str_for_dist(
38-
dist: TensorVariable, formatting: str = "plain", include_params: bool = True
38+
dist: TensorVariable,
39+
formatting: str = "plain",
40+
include_params: bool = True,
41+
model: Model | None = None,
3942
) -> str:
4043
"""Make a human-readable string representation of a Distribution in a model.
4144
@@ -47,12 +50,12 @@ def str_for_dist(
4750
dist.owner.op, "extended_signature", None
4851
):
4952
dist_args = [
50-
_str_for_input_var(x, formatting=formatting)
53+
_str_for_input_var(x, formatting=formatting, model=model)
5154
for x in dist.owner.op.dist_params(dist.owner)
5255
]
5356
else:
5457
dist_args = [
55-
_str_for_input_var(x, formatting=formatting)
58+
_str_for_input_var(x, formatting=formatting, model=model)
5659
for x in dist.owner.inputs
5760
if not isinstance(x.type, RandomType | NoneTypeT)
5861
]
@@ -106,7 +109,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
106109
including parameter values.
107110
"""
108111
# Wrap functions to avoid confusing typecheckers
109-
sfd = partial(str_for_dist, formatting=formatting, include_params=include_params)
112+
sfd = partial(str_for_dist, formatting=formatting, include_params=include_params, model=model)
110113
sfp = partial(
111114
str_for_potential_or_deterministic, formatting=formatting, include_params=include_params
112115
)
@@ -169,18 +172,14 @@ def str_for_potential_or_deterministic(
169172
return rf"{print_name} ~ {dist_name}"
170173

171174

172-
def _str_for_input_var(var: Variable, formatting: str) -> str:
175+
def _str_for_input_var(var: Variable, formatting: str, model: Model | None = None) -> str:
173176
# Avoid circular import
174177
from pymc.distributions.distribution import SymbolicRandomVariable
175178

176179
def _is_potential_or_deterministic(var: Variable) -> bool:
177-
if not hasattr(var, "str_repr"):
178-
return False
179-
try:
180-
return var.str_repr.__func__.func is str_for_potential_or_deterministic
181-
except AttributeError:
182-
# in case other code overrides str_repr, fallback
180+
if model is None:
183181
return False
182+
return var in model.deterministics or var in model.potentials
184183

185184
if isinstance(var, Constant | SharedVariable):
186185
return _str_for_constant(var, formatting)
@@ -190,18 +189,18 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
190189
# show the names for RandomVariables, Deterministics, and Potentials, rather
191190
# than the full expression
192191
assert isinstance(var, TensorVariable)
193-
return _str_for_input_rv(var, formatting)
192+
return _str_for_input_rv(var, formatting, model=model)
194193
elif isinstance(var.owner.op, DimShuffle):
195-
return _str_for_input_var(var.owner.inputs[0], formatting)
194+
return _str_for_input_var(var.owner.inputs[0], formatting, model=model)
196195
else:
197196
return _str_for_expression(var, formatting)
198197

199198

200-
def _str_for_input_rv(var: TensorVariable, formatting: str) -> str:
199+
def _str_for_input_rv(var: TensorVariable, formatting: str, model: Model | None = None) -> str:
201200
_str = (
202201
var.name
203202
if var.name is not None
204-
else str_for_dist(var, formatting=formatting, include_params=True)
203+
else str_for_dist(var, formatting=formatting, include_params=True, model=model)
205204
)
206205
if "latex" in formatting:
207206
return _latex_text_format(_latex_escape(_str.strip("$")))
@@ -277,37 +276,6 @@ def _latex_escape(text: str) -> str:
277276
return text.replace("$", r"\$")
278277

279278

280-
def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
281-
"""Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
282-
# we know that our str_repr does not recurse, so we can ignore cycle
283-
try:
284-
if not hasattr(obj, "str_repr"):
285-
raise AttributeError
286-
output = obj.str_repr()
287-
# Find newlines and replace them with p.break_()
288-
# (see IPython.lib.pretty._repr_pprint)
289-
lines = output.splitlines()
290-
with p.group():
291-
for idx, output_line in enumerate(lines):
292-
if idx:
293-
p.break_()
294-
p.text(output_line)
295-
except AttributeError:
296-
# the default fallback option (no str_repr method)
297-
IPython.lib.pretty._repr_pprint(obj, p, cycle)
298-
299-
300-
try:
301-
# register our custom pretty printer in ipython shells
302-
import IPython
303-
304-
IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty)
305-
IPython.lib.pretty.for_type(Model, _default_repr_pretty)
306-
except (ModuleNotFoundError, AttributeError):
307-
# no ipython shell
308-
pass
309-
310-
311279
def _format_underscore(variable: str) -> str:
312280
"""Escapes all unescaped underscores in the variable name for LaTeX representation."""
313281
return re.sub(r"(?<!\\)_", r"\\_", variable)

tests/test_printing.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@
3030
)
3131
from pymc.math import dot
3232
from pymc.model import Deterministic, Model, Potential
33+
from pymc.printing import str_for_dist, str_for_model
3334
from pymc.pytensorf import floatX
3435

3536

3637
class BaseTestStrAndLatexRepr:
3738
def test__repr_latex_(self):
38-
for distribution, tex in zip(self.distributions, self.expected[("latex", True)]):
39-
assert distribution._repr_latex_() == tex
39+
for model_variable, tex in zip(self.distributions, self.expected[("latex", True)]):
40+
if model_variable in self.model.basic_RVs:
41+
assert str_for_dist(model_variable, formatting="latex", model=self.model) == tex
4042

41-
model_tex = self.model._repr_latex_()
43+
model_tex = str_for_model(self.model, formatting="latex")
4244

4345
# make sure each variable is in the model
4446
for tex in self.expected[("latex", True)]:
@@ -47,10 +49,11 @@ def test__repr_latex_(self):
4749

4850
def test_str_repr(self):
4951
for str_format in self.formats:
50-
for dist, text in zip(self.distributions, self.expected[str_format]):
51-
assert dist.str_repr(*str_format) == text
52+
for model_variable, text in zip(self.distributions, self.expected[str_format]):
53+
if model_variable in self.model.basic_RVs:
54+
assert str_for_dist(model_variable, *str_format, model=self.model) == text
5255

53-
model_text = self.model.str_repr(*str_format)
56+
model_text = str_for_model(self.model, *str_format)
5457
for text in self.expected[str_format]:
5558
if str_format[0] == "latex":
5659
for segment in text.strip("$").split(r"\sim"):
@@ -252,7 +255,7 @@ def test_model_latex_repr_three_levels_model():
252255
"censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5]
253256
)
254257

255-
latex_repr = censored_model.str_repr(formatting="latex")
258+
latex_repr = str_for_model(censored_model, formatting="latex")
256259
expected = [
257260
"$$",
258261
"\\begin{array}{rcl}",
@@ -270,7 +273,7 @@ def test_model_latex_repr_mixture_model():
270273
w = Dirichlet("w", [1, 1])
271274
mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)])
272275

273-
latex_repr = mix_model.str_repr(formatting="latex")
276+
latex_repr = str_for_model(mix_model, formatting="latex")
274277
expected = [
275278
"$$",
276279
"\\begin{array}{rcl}",
@@ -291,15 +294,15 @@ def test_model_repr_variables_without_monkey_patched_repr():
291294
model = Model()
292295
model.register_rv(x, "x")
293296

294-
str_repr = model.str_repr()
297+
str_repr = str_for_model(model)
295298
assert str_repr == "x ~ Normal(0, 1)"
296299

297300

298301
def test_truncated_repr():
299302
with Model() as model:
300303
x = Truncated("x", Gamma.dist(1, 1), lower=0, upper=20)
301304

302-
str_repr = model.str_repr(include_params=False)
305+
str_repr = str_for_model(model, include_params=False)
303306
assert str_repr == "x ~ TruncatedGamma"
304307

305308

@@ -315,7 +318,7 @@ def random(rng, mu, size):
315318
x = CustomDist("x", 0, dist=dist, class_name="CustomDistNormal")
316319
x = CustomDist("y", 0, random=random, class_name="CustomRandomNormal")
317320

318-
str_repr = model.str_repr(include_params=False)
321+
str_repr = str_for_model(model, include_params=False)
319322
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])
320323

321324

@@ -333,6 +336,6 @@ def test_latex_escaped_underscore(self):
333336
Ensures that all underscores in model variable names are properly escaped for LaTeX representation
334337
"""
335338
model = self.simple_model()
336-
model_str = model.str_repr(formatting="latex")
339+
model_str = str_for_model(model, formatting="latex")
337340
assert "\\_" in model_str
338341
assert "_" not in model_str.replace("\\_", "")

0 commit comments

Comments
 (0)