Skip to content

Commit 4bf18b2

Browse files
committed
Abstract special behavior of NUTS divergences in ProgressBar
Every step sampler can now signal whether sampling is failing by setting the "failing" key in the update dict it returns to the progress bar.
1 parent 3b3279c commit 4bf18b2

File tree

5 files changed, with 85 additions (+85) and 39 deletions (-39).

pymc/progress_bar.py

Lines changed: 39 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -168,28 +168,28 @@ def call_column(column, task):
168168
return table
169169

170170

171-
class DivergenceBarColumn(BarColumn):
172-
"""Rich colorbar that changes color when a chain has detected a divergence."""
171+
class RecolorOnFailureBarColumn(BarColumn):
172+
"""Rich colorbar that changes color when a chain has detected a failure."""
173173

174-
def __init__(self, *args, diverging_color="red", **kwargs):
174+
def __init__(self, *args, failing_color="red", **kwargs):
175175
from matplotlib.colors import to_rgb
176176

177-
self.diverging_color = diverging_color
178-
self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)]
177+
self.failing_color = failing_color
178+
self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)]
179179

180180
super().__init__(*args, **kwargs)
181181

182-
self.non_diverging_style = self.complete_style
183-
self.non_diverging_finished_style = self.finished_style
182+
self.default_complete_style = self.complete_style
183+
self.default_finished_style = self.finished_style
184184

185185
def callbacks(self, task: "Task"):
186-
divergences = task.fields.get("divergences", 0)
187-
if isinstance(divergences, float | int) and divergences > 0:
188-
self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
189-
self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
186+
if task.fields["failing"]:
187+
self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
188+
self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
190189
else:
191-
self.complete_style = self.non_diverging_style
192-
self.finished_style = self.non_diverging_finished_style
190+
# Recovered from failing yay
191+
self.complete_style = self.default_complete_style
192+
self.finished_style = self.default_finished_style
193193

194194

195195
class ProgressBarManager:
@@ -284,7 +284,6 @@ def __init__(
284284
self.update_stats_functions = step_method._make_progressbar_update_functions()
285285

286286
self._show_progress = show_progress
287-
self.divergences = 0
288287
self.completed_draws = 0
289288
self.total_draws = draws + tune
290289
self.desc = "Sampling chain"
@@ -311,6 +310,7 @@ def _initialize_tasks(self):
311310
chain_idx=0,
312311
sampling_speed=0,
313312
speed_unit="draws/s",
313+
failing=False,
314314
**{stat: value[0] for stat, value in self.progress_stats.items()},
315315
)
316316
]
@@ -325,6 +325,7 @@ def _initialize_tasks(self):
325325
chain_idx=chain_idx,
326326
sampling_speed=0,
327327
speed_unit="draws/s",
328+
failing=False,
328329
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
329330
)
330331
for chain_idx in range(self.chains)
@@ -354,42 +355,43 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
354355
elapsed = self._progress.tasks[chain_idx].elapsed
355356
speed, unit = self.compute_draw_speed(elapsed, draw)
356357

357-
if not tuning and stats and stats[0].get("diverging"):
358-
self.divergences += 1
358+
failing = False
359+
all_step_stats = {}
359360

360-
if self.full_stats:
361-
# TODO: Index by chain already?
362-
chain_progress_stats = [
363-
update_states_fn(step_stats)
364-
for update_states_fn, step_stats in zip(
365-
self.update_stats_functions, stats, strict=True
366-
)
367-
]
368-
all_step_stats = {}
369-
for step_stats in chain_progress_stats:
370-
for key, val in step_stats.items():
371-
if key in all_step_stats:
372-
# TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
373-
continue
374-
else:
375-
all_step_stats[key] = val
376-
377-
else:
378-
all_step_stats = {}
361+
chain_progress_stats = [
362+
update_stats_fn(step_stats)
363+
for update_stats_fn, step_stats in zip(self.update_stats_functions, stats, strict=True)
364+
]
365+
for step_stats in chain_progress_stats:
366+
for key, val in step_stats.items():
367+
if key == "failing":
368+
failing |= val
369+
continue
370+
if not self.full_stats:
371+
# Only care about the "failing" flag
372+
continue
373+
374+
if key in all_step_stats:
375+
# TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
376+
continue
377+
else:
378+
all_step_stats[key] = val
379379

380380
self._progress.update(
381381
self.tasks[chain_idx],
382382
completed=draw,
383383
draws=draw,
384384
sampling_speed=speed,
385385
speed_unit=unit,
386+
failing=failing,
386387
**all_step_stats,
387388
)
388389

389390
if is_last:
390391
self._progress.update(
391392
self.tasks[chain_idx],
392393
draws=draw + 1 if not self.combined_progress else draw,
394+
failing=failing,
393395
**all_step_stats,
394396
refresh=True,
395397
)
@@ -410,9 +412,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
410412
]
411413

412414
return CustomProgress(
413-
DivergenceBarColumn(
415+
RecolorOnFailureBarColumn(
414416
table_column=Column("Progress", ratio=2),
415-
diverging_color="tab:red",
417+
failing_color="tab:red",
416418
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
417419
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
418420
),

pymc/step_methods/hmc/base_hmc.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,7 @@ def __init__(
184184

185185
self._step_rand = step_rand
186186
self._num_divs_sample = 0
187+
self.divergences = 0
187188

188189
@abstractmethod
189190
def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
@@ -266,11 +267,15 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
266267
divergence_info=info_store,
267268
)
268269

270+
diverging = bool(hmc_step.divergence_info)
271+
if not self.tune:
272+
self.divergences += diverging
269273
self.iter_count += 1
270274

271275
stats: dict[str, Any] = {
272276
"tune": self.tune,
273-
"diverging": bool(hmc_step.divergence_info),
277+
"diverging": diverging,
278+
"divergences": self.divergences,
274279
"perf_counter_diff": perf_end - perf_start,
275280
"process_time_diff": process_end - process_start,
276281
"perf_counter_start": perf_start,
@@ -288,6 +293,8 @@ def reset_tuning(self, start=None):
288293
self.reset(start=None)
289294

290295
def reset(self, start=None):
296+
self.iter_count = 0
297+
self.divergences = 0
291298
self.tune = True
292299
self.potential.reset()
293300

pymc/step_methods/hmc/hmc.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,9 @@
1919

2020
import numpy as np
2121

22+
from rich.progress import TextColumn
23+
from rich.table import Column
24+
2225
from pymc.stats.convergence import SamplerWarning
2326
from pymc.step_methods.compound import Competence
2427
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
@@ -55,6 +58,7 @@ class HamiltonianMC(BaseHMC):
5558
"accept": (np.float64, []),
5659
"diverging": (bool, []),
5760
"energy_error": (np.float64, []),
61+
"divergences": (np.int64, []),
5862
"energy": (np.float64, []),
5963
"path_length": (np.float64, []),
6064
"accepted": (bool, []),
@@ -202,3 +206,32 @@ def competence(var, has_grad):
202206
if var.dtype in discrete_types or not has_grad:
203207
return Competence.INCOMPATIBLE
204208
return Competence.COMPATIBLE
209+
210+
@staticmethod
211+
def _progressbar_config(n_chains=1):
212+
columns = [
213+
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
214+
TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)),
215+
]
216+
217+
stats = {
218+
"divergences": [0] * n_chains,
219+
"n_steps": [0] * n_chains,
220+
}
221+
222+
return columns, stats
223+
224+
@staticmethod
225+
def _make_progressbar_update_functions():
226+
def update_stats(stats):
227+
return {
228+
key: stats[key]
229+
for key in (
230+
"divergences",
231+
"n_steps",
232+
)
233+
} | {
234+
"failing": stats["divergences"] > 0,
235+
}
236+
237+
return (update_stats,)

pymc/step_methods/hmc/nuts.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,7 @@ class NUTS(BaseHMC):
115115
"step_size_bar": (np.float64, []),
116116
"tree_size": (np.float64, []),
117117
"diverging": (bool, []),
118+
"divergences": (int, []),
118119
"energy_error": (np.float64, []),
119120
"energy": (np.float64, []),
120121
"max_energy_error": (np.float64, []),
@@ -250,7 +251,9 @@ def _progressbar_config(n_chains=1):
250251
@staticmethod
251252
def _make_progressbar_update_functions():
252253
def update_stats(stats):
253-
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}
254+
return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | {
255+
"failing": stats["divergences"] > 0,
256+
}
254257

255258
return (update_stats,)
256259

tests/step_methods/hmc/test_nuts.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,7 @@ def test_sampler_stats(self):
148148
expected_stat_names = {
149149
"depth",
150150
"diverging",
151+
"divergences",
151152
"energy",
152153
"energy_error",
153154
"model_logp",

0 commit comments

Comments
 (0)