Skip to content

Commit 46226a8

Browse files
committed
Alternative fix attempt of progressbar with nested compound step samplers
1 parent 3ce284c commit 46226a8

File tree

5 files changed

+56
-59
lines changed

5 files changed

+56
-59
lines changed

pymc/step_methods/compound.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1):
189189
return columns, stats
190190

191191
@staticmethod
192-
def _make_update_stats_function():
193-
def update_stats(stats, step_stats, chain_idx):
194-
return stats
192+
def _make_update_stats_functions():
193+
def update_stats(step_stats, chain_idx):
194+
return step_stats
195195

196-
return update_stats
196+
return (update_stats,)
197197

198198
# Hack for creating the class correctly when unpickling.
199199
def __getnewargs_ex__(self):
@@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1):
332332

333333
return columns, stats
334334

335-
def _make_update_stats_function(self):
336-
update_fns = [method._make_update_stats_function() for method in self.methods]
337-
338-
def update_stats(stats, step_stats, chain_idx):
339-
for step_stat, update_fn in zip(step_stats, update_fns):
340-
stats = update_fn(stats, step_stat, chain_idx)
341-
342-
return stats
343-
344-
return update_stats
335+
def _make_update_stats_functions(self):
336+
update_functions = []
337+
for method in self.methods:
338+
update_functions.extend(method._make_update_stats_functions())
339+
return update_functions
345340

346341

347342
def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:

pymc/step_methods/hmc/nuts.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,19 +248,11 @@ def _progressbar_config(n_chains=1):
248248
return columns, stats
249249

250250
@staticmethod
251-
def _make_update_stats_function():
252-
def update_stats(stats, step_stats, chain_idx):
253-
if isinstance(step_stats, list):
254-
step_stats = step_stats[0]
251+
def _make_update_stats_functions():
252+
def update_stats(stats):
253+
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}
255254

256-
if not step_stats["tune"]:
257-
stats["divergences"][chain_idx] += step_stats["diverging"]
258-
259-
stats["step_size"][chain_idx] = step_stats["step_size"]
260-
stats["tree_size"][chain_idx] = step_stats["tree_size"]
261-
return stats
262-
263-
return update_stats
255+
return (update_stats,)
264256

265257

266258
# A proposal for the next position

pymc/step_methods/metropolis.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1):
346346
return columns, stats
347347

348348
@staticmethod
349-
def _make_update_stats_function():
350-
def update_stats(stats, step_stats, chain_idx):
351-
if isinstance(step_stats, list):
352-
step_stats = step_stats[0]
353-
354-
stats["tune"][chain_idx] = step_stats["tune"]
355-
stats["accept_rate"][chain_idx] = step_stats["accept"]
356-
stats["scaling"][chain_idx] = step_stats["scaling"]
357-
358-
return stats
359-
360-
return update_stats
349+
def _make_update_stats_functions():
350+
def update_stats(step_stats):
351+
return {
352+
"accept_rate" if key == "accept" else key: step_stats[key]
353+
for key in ("tune", "accept", "scaling")
354+
}
355+
356+
return (update_stats,)
361357

362358

363359
def tune(scale, acc_rate):

pymc/step_methods/slicer.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1):
212212
return columns, stats
213213

214214
@staticmethod
215-
def _make_update_stats_function():
216-
def update_stats(stats, step_stats, chain_idx):
217-
if isinstance(step_stats, list):
218-
step_stats = step_stats[0]
215+
def _make_update_stats_functions():
216+
def update_stats(step_stats):
217+
return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
219218

220-
stats["tune"][chain_idx] = step_stats["tune"]
221-
stats["nstep_out"][chain_idx] = step_stats["nstep_out"]
222-
stats["nstep_in"][chain_idx] = step_stats["nstep_in"]
223-
224-
return stats
225-
226-
return update_stats
219+
return (update_stats,)

pymc/util.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -806,9 +806,8 @@ def __init__(
806806
progressbar=progressbar,
807807
progressbar_theme=progressbar_theme,
808808
)
809-
810809
self.progress_stats = progress_stats
811-
self.update_stats = step_method._make_update_stats_function()
810+
self.update_stats_functions = step_method._make_update_stats_functions()
812811

813812
self._show_progress = show_progress
814813
self.divergences = 0
@@ -883,27 +882,49 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
883882
if not tuning and stats and stats[0].get("diverging"):
884883
self.divergences += 1
885884

886-
self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
887-
more_updates = (
888-
{stat: value[chain_idx] for stat, value in self.progress_stats.items()}
889-
if self.full_stats
890-
else {}
891-
)
885+
if self.full_stats:
886+
# TODO: Index by chain already?
887+
chain_progress_stats = [
888+
update_states_fn(step_stats)
889+
for update_states_fn, step_stats in zip(
890+
self.update_stats_functions, stats, strict=True
891+
)
892+
]
893+
all_step_stats = {}
894+
for step_stats in chain_progress_stats:
895+
for key, val in step_stats.items():
896+
if key in all_step_stats:
897+
continue
898+
count = (
899+
sum(step_key.startswith(f"{key}_") for step_key in all_step_stats) + 1
900+
)
901+
all_step_stats[f"{key}_{count}"] = val
902+
else:
903+
all_step_stats[key] = val
904+
905+
else:
906+
all_step_stats = {}
907+
908+
# more_updates = (
909+
# {stat: value[chain_idx] for stat, value in progress_stats.items()}
910+
# if self.full_stats
911+
# else {}
912+
# )
892913

893914
self._progress.update(
894915
self.tasks[chain_idx],
895916
completed=draw,
896917
draws=draw,
897918
sampling_speed=speed,
898919
speed_unit=unit,
899-
**more_updates,
920+
**all_step_stats,
900921
)
901922

902923
if is_last:
903924
self._progress.update(
904925
self.tasks[chain_idx],
905926
draws=draw + 1 if not self.combined_progress else draw,
906-
**more_updates,
927+
**all_step_stats,
907928
refresh=True,
908929
)
909930

0 commit comments

Comments
 (0)