|
19 | 19 |
|
20 | 20 | import numpy as np
|
21 | 21 |
|
| 22 | +from rich.progress import TextColumn |
| 23 | +from rich.table import Column |
| 24 | + |
22 | 25 | from pymc.stats.convergence import SamplerWarning
|
23 | 26 | from pymc.step_methods.compound import Competence
|
24 | 27 | from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
|
@@ -203,3 +206,27 @@ def competence(var, has_grad):
|
203 | 206 | if var.dtype in discrete_types or not has_grad:
|
204 | 207 | return Competence.INCOMPATIBLE
|
205 | 208 | return Competence.COMPATIBLE
|
| 209 | + |
| 210 | + @staticmethod |
| 211 | + def _progressbar_config(n_chains=1): |
| 212 | + columns = [ |
| 213 | + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), |
| 214 | + TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)), |
| 215 | + ] |
| 216 | + |
| 217 | + stats = { |
| 218 | + "divergences": [0] * n_chains, |
| 219 | + "n_steps": [0] * n_chains, |
| 220 | + } |
| 221 | + |
| 222 | + return columns, stats |
| 223 | + |
| 224 | + def _make_progressbar_update_functions(self): |
| 225 | + def update_stats(stats): |
| 226 | + divergences = self.divergences |
| 227 | + return {key: stats[key] for key in ("n_steps",)} | { |
| 228 | + "failing": divergences > 0, |
| 229 | + "divergences": divergences, |
| 230 | + } |
| 231 | + |
| 232 | + return (update_stats,) |
0 commit comments