Skip to content

Commit c341491

Browse files
Add progressbar config to HamiltonianMC
1 parent 64a092c commit c341491

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

pymc/step_methods/hmc/hmc.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
import numpy as np
2121

22+
from rich.progress import TextColumn
23+
from rich.table import Column
24+
2225
from pymc.stats.convergence import SamplerWarning
2326
from pymc.step_methods.compound import Competence
2427
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
@@ -203,3 +206,27 @@ def competence(var, has_grad):
203206
if var.dtype in discrete_types or not has_grad:
204207
return Competence.INCOMPATIBLE
205208
return Competence.COMPATIBLE
209+
210+
@staticmethod
211+
def _progressbar_config(n_chains=1):
212+
columns = [
213+
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
214+
TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)),
215+
]
216+
217+
stats = {
218+
"divergences": [0] * n_chains,
219+
"n_steps": [0] * n_chains,
220+
}
221+
222+
return columns, stats
223+
224+
def _make_progressbar_update_functions(self):
225+
def update_stats(stats):
226+
divergences = self.divergences
227+
return {key: stats[key] for key in ("n_steps",)} | {
228+
"failing": divergences > 0,
229+
"divergences": divergences,
230+
}
231+
232+
return (update_stats,)

0 commit comments

Comments
 (0)