diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9b4b000c7..efbb457f74 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,7 @@ jobs: tests/distributions/test_shape_utils.py tests/distributions/test_mixture.py tests/test_testing.py + tests/test_progress_bar.py - | tests/distributions/test_continuous.py diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index f0f0eec963..71f08da826 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -39,8 +39,9 @@ import pymc from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import PointFunc, extract_obs_data -from pymc.util import CustomProgress, default_progress_theme, get_default_varnames +from pymc.util import get_default_varnames if TYPE_CHECKING: from pymc.backends.base import MultiTrace diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index 3e0ad532f7..4b8808a3bd 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -35,6 +35,7 @@ # SOFTWARE. import abc +import warnings from collections.abc import Sequence from functools import singledispatch diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py new file mode 100644 index 0000000000..7299584307 --- /dev/null +++ b/pymc/progress_bar.py @@ -0,0 +1,425 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterable +from typing import TYPE_CHECKING, Literal + +from rich.box import SIMPLE_HEAD +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + Task, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich.style import Style +from rich.table import Column, Table +from rich.theme import Theme + +if TYPE_CHECKING: + from pymc.step_methods.compound import BlockedStep, CompoundStep + +ProgressBarType = Literal[ + "combined", + "split", + "combined+stats", + "stats+combined", + "split+stats", + "stats+split", +] +default_progress_theme = Theme( + { + "bar.complete": "#1764f4", + "bar.finished": "green", + "progress.remaining": "none", + "progress.elapsed": "none", + } +) + + +class CustomProgress(Progress): + """A child of Progress that allows to disable progress bars and its container. + + The implementation simply checks an `is_enabled` flag and generates the progress bar only if + it's `True`. + """ + + def __init__(self, *args, disable=False, include_headers=False, **kwargs): + self.is_enabled = not disable + self.include_headers = include_headers + + if self.is_enabled: + super().__init__(*args, **kwargs) + + def __enter__(self): + """Enter the context manager.""" + if self.is_enabled: + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager.""" + if self.is_enabled: + super().__exit__(exc_type, exc_val, exc_tb) + + def add_task(self, *args, **kwargs): + if self.is_enabled: + return super().add_task(*args, **kwargs) + return None + + def advance(self, task_id, advance=1) -> None: + if self.is_enabled: + super().advance(task_id, advance) + return None + + def update( + self, + task_id, + *, + total=None, + completed=None, + advance=None, + description=None, + visible=None, + refresh=False, + **fields, + ): + if self.is_enabled: + super().update( + task_id, + total=total, + completed=completed, + advance=advance, + description=description, + visible=visible, + refresh=refresh, + **fields, + ) + return None + + def make_tasks_table(self, tasks: Iterable[Task]) -> Table: + """Get a table to render the Progress display. + + Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. + + Parameters + ---------- + tasks: Iterable[Task] + An iterable of Task instances, one per row of the table. + + Returns + ------- + table: Table + A table instance. + """ + + def call_column(column, task): + # Subclass rich.BarColumn and add a callback method to dynamically update the display + if hasattr(column, "callbacks"): + column.callbacks(task) + + return column(task) + + table_columns = ( + ( + Column(no_wrap=True) + if isinstance(_column, str) + else _column.get_table_column().copy() + ) + for _column in self.columns + ) + if self.include_headers: + table = Table( + *table_columns, + padding=(0, 1), + expand=self.expand, + show_header=True, + show_edge=True, + box=SIMPLE_HEAD, + ) + else: + table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) + + for task in tasks: + if task.visible: + table.add_row( + *( + ( + column.format(task=task) + if isinstance(column, str) + else call_column(column, task) + ) + for column in self.columns + ) + ) + + return table + + +class RecolorOnFailureBarColumn(BarColumn): + """Rich colorbar that changes color when a chain has detected a failure.""" + + def __init__(self, *args, failing_color="red", **kwargs): + from matplotlib.colors import to_rgb + + self.failing_color = failing_color + self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)] + + super().__init__(*args, **kwargs) + + self.default_complete_style = self.complete_style + self.default_finished_style = self.finished_style + + def callbacks(self, task: "Task"): + if task.fields["failing"]: + self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb)) + self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb)) + else: + # Recovered from failing yay + self.complete_style = self.default_complete_style + self.finished_style = self.default_finished_style + + +class ProgressBarManager: + """Manage progress bars displayed during sampling.""" + + def __init__( + self, + step_method: "BlockedStep | CompoundStep", + chains: int, + draws: int, + tune: int, + progressbar: bool | ProgressBarType = True, + progressbar_theme: Theme | None = None, + ): + """ + Manage progress bars displayed during sampling. + + When sampling, Step classes are responsible for computing and exposing statistics that can be reported on + progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` + and :meth:`pymc.step_methods.BlockedStep._make_progressbar_update_functions`. `_progressbar_config` reports which + columns should be displayed on the progress bar, and `_make_progressbar_update_functions` computes the statistics + that will be displayed on the progress bar. + + Parameters + ---------- + step_method: BlockedStep or CompoundStep + The step method being used to sample + chains: int + Number of chains being sampled + draws: int + Number of draws per chain + tune: int + Number of tuning steps per chain + progressbar: bool or ProgressType, optional + How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask + for one of the following: + - "combined": A single progress bar that displays the total progress across all chains. Only timing + information is shown. + - "split": A separate progress bar for each chain. Only timing information is shown. + - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all + chains. Aggregate sample statistics are also displayed. + - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain + are also displayed. + + If True, the default is "split+stats" is used. + + progressbar_theme: Theme, optional + The theme to use for the progress bar. Defaults to the default theme. + """ + if progressbar_theme is None: + progressbar_theme = default_progress_theme + + match progressbar: + case True: + self.combined_progress = False + self.full_stats = True + show_progress = True + case False: + self.combined_progress = False + self.full_stats = True + show_progress = False + case "combined": + self.combined_progress = True + self.full_stats = False + show_progress = True + case "split": + self.combined_progress = False + self.full_stats = False + show_progress = True + case "combined+stats" | "stats+combined": + self.combined_progress = True + self.full_stats = True + show_progress = True + case "split+stats" | "stats+split": + self.combined_progress = False + self.full_stats = True + show_progress = True + case _: + raise ValueError( + "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " + "one of 'combined', 'split', 'split+stats', or 'combined+stats." + ) + + progress_columns, progress_stats = step_method._progressbar_config(chains) + + self._progress = self.create_progress_bar( + progress_columns, + progressbar=progressbar, + progressbar_theme=progressbar_theme, + ) + self.progress_stats = progress_stats + self.update_stats_functions = step_method._make_progressbar_update_functions() + + self._show_progress = show_progress + self.completed_draws = 0 + self.total_draws = draws + tune + self.desc = "Sampling chain" + self.chains = chains + + self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] + + def __enter__(self): + self._initialize_tasks() + + return self._progress.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return self._progress.__exit__(exc_type, exc_val, exc_tb) + + def _initialize_tasks(self): + if self.combined_progress: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws * self.chains - 1, + chain_idx=0, + sampling_speed=0, + speed_unit="draws/s", + failing=False, + **{stat: value[0] for stat, value in self.progress_stats.items()}, + ) + ] + + else: + self.tasks = [ + self._progress.add_task( + self.desc.format(self), + completed=0, + draws=0, + total=self.total_draws - 1, + chain_idx=chain_idx, + sampling_speed=0, + speed_unit="draws/s", + failing=False, + **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, + ) + for chain_idx in range(self.chains) + ] + + @staticmethod + def compute_draw_speed(elapsed, draws): + speed = draws / max(elapsed, 1e-6) + + if speed > 1 or speed == 0: + unit = "draws/s" + else: + unit = "s/draws" + speed = 1 / speed + + return speed, unit + + def update(self, chain_idx, is_last, draw, tuning, stats): + if not self._show_progress: + return + + self.completed_draws += 1 + if self.combined_progress: + draw = self.completed_draws + chain_idx = 0 + + elapsed = self._progress.tasks[chain_idx].elapsed + speed, unit = self.compute_draw_speed(elapsed, draw) + + failing = False + all_step_stats = {} + + chain_progress_stats = [ + update_stats_fn(step_stats) + for update_stats_fn, step_stats in zip(self.update_stats_functions, stats, strict=True) + ] + for step_stats in chain_progress_stats: + for key, val in step_stats.items(): + if key == "failing": + failing |= val + continue + if not self.full_stats: + # Only care about the "failing" flag + continue + + if key in all_step_stats: + # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now + continue + else: + all_step_stats[key] = val + + self._progress.update( + self.tasks[chain_idx], + completed=draw, + draws=draw, + sampling_speed=speed, + speed_unit=unit, + failing=failing, + **all_step_stats, + ) + + if is_last: + self._progress.update( + self.tasks[chain_idx], + draws=draw + 1 if not self.combined_progress else draw, + failing=failing, + **all_step_stats, + refresh=True, + ) + + def create_progress_bar(self, step_columns, progressbar, progressbar_theme): + columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] + + if self.full_stats: + columns += step_columns + + columns += [ + TextColumn( + "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", + table_column=Column("Sampling Speed", ratio=1), + ), + TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), + TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), + ] + + return CustomProgress( + RecolorOnFailureBarColumn( + table_column=Column("Progress", ratio=2), + failing_color="tab:red", + complete_style=Style.parse("rgb(31,119,180)"), # tab:blue + finished_style=Style.parse("rgb(31,119,180)"), # tab:blue + ), + *columns, + console=Console(theme=progressbar_theme), + disable=not progressbar, + include_headers=True, + ) diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 1be14f77f3..d65c6c0118 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -57,12 +57,11 @@ from pymc.distributions.shape_utils import change_dist_size from pymc.logprob.utils import rvs_in_graph from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.pytensorf import compile from pymc.util import ( - CustomProgress, RandomState, _get_seeds_per_chain, - default_progress_theme, get_default_varnames, point_wrapper, ) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index d3a02b91b6..542797caa8 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -54,6 +54,7 @@ from pymc.exceptions import SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext +from pymc.progress_bar import ProgressBarManager, ProgressBarType, default_progress_theme from pymc.sampling.parallel import Draw, _cpu_count from pymc.sampling.population import _sample_population from pymc.stats.convergence import ( @@ -65,12 +66,9 @@ from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared from pymc.step_methods.hmc import quadpotential from pymc.util import ( - ProgressBarManager, - ProgressBarType, RandomSeed, RandomState, _get_seeds_per_chain, - default_progress_theme, drop_warning_stat, get_random_generator, get_untransformed_name, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index af2106ce6f..6e229b9606 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -33,10 +33,9 @@ from pymc.backends.zarr import ZarrChain from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError +from pymc.progress_bar import ProgressBarManager, default_progress_theme from pymc.util import ( - ProgressBarManager, RandomGeneratorState, - default_progress_theme, get_state_from_generator, random_generator_from_state, ) diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 92de63d0c2..5bd1771704 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -30,6 +30,7 @@ from pymc.backends.zarr import ZarrChain from pymc.initial_point import PointType from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.stats.convergence import log_warning_stats from pymc.step_methods import CompoundStep from pymc.step_methods.arraystep import ( @@ -39,7 +40,6 @@ ) from pymc.step_methods.compound import StepMethodState from pymc.step_methods.metropolis import DEMetropolis -from pymc.util import CustomProgress __all__ = () diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index f3176f464b..5afd398281 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -39,10 +39,11 @@ from pymc.distributions.distribution import _support_point from pymc.logprob.abstract import _icdf, _logcdf, _logprob from pymc.model import Model, modelcontext +from pymc.progress_bar import CustomProgress from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH from pymc.stats.convergence import log_warnings, run_convergence_checks -from pymc.util import CustomProgress, RandomState, _get_seeds_per_chain +from pymc.util import RandomState, _get_seeds_per_chain def sample_smc( diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..a9cae903f0 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -189,11 +189,11 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return step_stats - return update_stats + return (update_stats,) # Hack for creating the class correctly when unpickling. def __getnewargs_ex__(self): @@ -332,16 +332,11 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): - update_fns = [method._make_update_stats_function() for method in self.methods] - - def update_stats(stats, step_stats, chain_idx): - for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) - - return stats - - return update_stats + def _make_progressbar_update_functions(self): + update_functions = [] + for method in self.methods: + update_functions.extend(method._make_progressbar_update_functions()) + return update_functions def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4b..297b095e23 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -184,6 +184,7 @@ def __init__( self._step_rand = step_rand self._num_divs_sample = 0 + self.divergences = 0 @abstractmethod def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData: @@ -266,11 +267,15 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: divergence_info=info_store, ) + diverging = bool(hmc_step.divergence_info) + if not self.tune: + self.divergences += diverging self.iter_count += 1 stats: dict[str, Any] = { "tune": self.tune, - "diverging": bool(hmc_step.divergence_info), + "diverging": diverging, + "divergences": self.divergences, "perf_counter_diff": perf_end - perf_start, "process_time_diff": process_end - process_start, "perf_counter_start": perf_start, @@ -288,6 +293,8 @@ def reset_tuning(self, start=None): self.reset(start=None) def reset(self, start=None): + self.iter_count = 0 + self.divergences = 0 self.tune = True self.potential.reset() diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 565c1fd78b..1697341bc8 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -19,6 +19,9 @@ import numpy as np +from rich.progress import TextColumn +from rich.table import Column + from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData @@ -55,6 +58,7 @@ class HamiltonianMC(BaseHMC): "accept": (np.float64, []), "diverging": (bool, []), "energy_error": (np.float64, []), + "divergences": (np.int64, []), "energy": (np.float64, []), "path_length": (np.float64, []), "accepted": (bool, []), @@ -202,3 +206,32 @@ def competence(var, has_grad): if var.dtype in discrete_types or not has_grad: return Competence.INCOMPATIBLE return Competence.COMPATIBLE + + @staticmethod + def _progressbar_config(n_chains=1): + columns = [ + TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), + TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)), + ] + + stats = { + "divergences": [0] * n_chains, + "n_steps": [0] * n_chains, + } + + return columns, stats + + @staticmethod + def _make_progressbar_update_functions(): + def update_stats(stats): + return { + key: stats[key] + for key in ( + "divergences", + "n_steps", + ) + } | { + "failing": stats["divergences"] > 0, + } + + return (update_stats,) diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c3592..0f19d3c087 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -115,6 +115,7 @@ class NUTS(BaseHMC): "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), "diverging": (bool, []), + "divergences": (int, []), "energy_error": (np.float64, []), "energy": (np.float64, []), "max_energy_error": (np.float64, []), @@ -248,19 +249,13 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(stats): + return {key: stats[key] for key in ("divergences", "step_size", "tree_size")} | { + "failing": stats["divergences"] > 0, + } - if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] - - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats - - return update_stats + return (update_stats,) # A proposal for the next position diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..2cd2e1369e 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -346,18 +346,14 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return { + "accept_rate" if key == "accept" else key: step_stats[key] + for key in ("tune", "accept", "scaling") + } - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] - - return stats - - return update_stats + return (update_stats,) def tune(scale, acc_rate): @@ -684,7 +680,6 @@ def competence(var): class CategoricalGibbsMetropolisState(StepMethodState): shuffle_dims: bool dimcats: list[tuple] - tune: bool class CategoricalGibbsMetropolis(ArrayStep): @@ -767,10 +762,6 @@ def __init__( else: raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'") - # Doesn't actually tune, but it's required to emit a sampler stat - # that indicates whether a draw was done in a tuning phase. - self.tune = True - if compile_kwargs is None: compile_kwargs = {} super().__init__(vars, [model.compile_logp(**compile_kwargs)], blocked=blocked, rng=rng) @@ -800,10 +791,8 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType if accepted: logp_curr = logp_prop - stats = { - "tune": self.tune, - } - return q, [stats] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -820,7 +809,8 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType for dim, k in dimcats: logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) - return q, [] + # This step doesn't have any tunable parameters + return q, [{"tune": False}] def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: raise NotImplementedError() diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..180ac1c882 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -212,15 +212,8 @@ def _progressbar_config(n_chains=1): return columns, stats @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_progressbar_update_functions(): + def update_stats(step_stats): + return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - - return stats - - return update_stats + return (update_stats,) diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py index 2fbbba6339..1385f33483 100644 --- a/pymc/tuning/starting.py +++ b/pymc/tuning/starting.py @@ -36,9 +36,8 @@ from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.util import ( - CustomProgress, - default_progress_theme, get_default_varnames, get_value_vars_from_user_vars, ) diff --git a/pymc/util.py b/pymc/util.py index ad9256adda..3f108b8b03 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -16,9 +16,9 @@ import re from collections import namedtuple -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Literal, NewType, cast +from typing import NewType, cast import arviz import cloudpickle @@ -28,47 +28,11 @@ from cachetools import LRUCache, cachedmethod from pytensor import Variable from pytensor.compile import SharedVariable -from rich.box import SIMPLE_HEAD -from rich.console import Console -from rich.progress import ( - BarColumn, - Progress, - Task, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) -from rich.style import Style -from rich.table import Column, Table -from rich.theme import Theme from pymc.exceptions import BlockModelAccessError -if TYPE_CHECKING: - from pymc.step_methods.compound import BlockedStep, CompoundStep - - -ProgressBarType = Literal[ - "combined", - "split", - "combined+stats", - "stats+combined", - "split+stats", - "stats+split", -] - - VarName = NewType("VarName", str) -default_progress_theme = Theme( - { - "bar.complete": "#1764f4", - "bar.finished": "green", - "progress.remaining": "none", - "progress.elapsed": "none", - } -) - class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" @@ -532,368 +496,6 @@ def makeiter(a): return [a] -class CustomProgress(Progress): - """A child of Progress that allows to disable progress bars and its container. - - The implementation simply checks an `is_enabled` flag and generates the progress bar only if - it's `True`. - """ - - def __init__(self, *args, disable=False, include_headers=False, **kwargs): - self.is_enabled = not disable - self.include_headers = include_headers - - if self.is_enabled: - super().__init__(*args, **kwargs) - - def __enter__(self): - """Enter the context manager.""" - if self.is_enabled: - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Exit the context manager.""" - if self.is_enabled: - super().__exit__(exc_type, exc_val, exc_tb) - - def add_task(self, *args, **kwargs): - if self.is_enabled: - return super().add_task(*args, **kwargs) - return None - - def advance(self, task_id, advance=1) -> None: - if self.is_enabled: - super().advance(task_id, advance) - return None - - def update( - self, - task_id, - *, - total=None, - completed=None, - advance=None, - description=None, - visible=None, - refresh=False, - **fields, - ): - if self.is_enabled: - super().update( - task_id, - total=total, - completed=completed, - advance=advance, - description=description, - visible=visible, - refresh=refresh, - **fields, - ) - return None - - def make_tasks_table(self, tasks: Iterable[Task]) -> Table: - """Get a table to render the Progress display. - - Unlike the parent method, this one returns a full table (not a grid), allowing for column headings. - - Parameters - ---------- - tasks: Iterable[Task] - An iterable of Task instances, one per row of the table. - - Returns - ------- - table: Table - A table instance. - """ - - def call_column(column, task): - # Subclass rich.BarColumn and add a callback method to dynamically update the display - if hasattr(column, "callbacks"): - column.callbacks(task) - - return column(task) - - table_columns = ( - ( - Column(no_wrap=True) - if isinstance(_column, str) - else _column.get_table_column().copy() - ) - for _column in self.columns - ) - if self.include_headers: - table = Table( - *table_columns, - padding=(0, 1), - expand=self.expand, - show_header=True, - show_edge=True, - box=SIMPLE_HEAD, - ) - else: - table = Table.grid(*table_columns, padding=(0, 1), expand=self.expand) - - for task in tasks: - if task.visible: - table.add_row( - *( - ( - column.format(task=task) - if isinstance(column, str) - else call_column(column, task) - ) - for column in self.columns - ) - ) - - return table - - -class DivergenceBarColumn(BarColumn): - """Rich colorbar that changes color when a chain has detected a divergence.""" - - def __init__(self, *args, diverging_color="red", **kwargs): - from matplotlib.colors import to_rgb - - self.diverging_color = diverging_color - self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)] - - super().__init__(*args, **kwargs) - - self.non_diverging_style = self.complete_style - self.non_diverging_finished_style = self.finished_style - - def callbacks(self, task: "Task"): - divergences = task.fields.get("divergences", 0) - if isinstance(divergences, float | int) and divergences > 0: - self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb)) - else: - self.complete_style = self.non_diverging_style - self.finished_style = self.non_diverging_finished_style - - -class ProgressBarManager: - """Manage progress bars displayed during sampling.""" - - def __init__( - self, - step_method: "BlockedStep | CompoundStep", - chains: int, - draws: int, - tune: int, - progressbar: bool | ProgressBarType = True, - progressbar_theme: Theme | None = None, - ): - """ - Manage progress bars displayed during sampling. - - When sampling, Step classes are responsible for computing and exposing statistics that can be reported on - progress bars. Each Step implements two class methods: :meth:`pymc.step_methods.BlockedStep._progressbar_config` - and :meth:`pymc.step_methods.BlockedStep._make_update_stats_function`. `_progressbar_config` reports which - columns should be displayed on the progress bar, and `_make_update_stats_function` computes the statistics - that will be displayed on the progress bar. - - Parameters - ---------- - step_method: BlockedStep or CompoundStep - The step method being used to sample - chains: int - Number of chains being sampled - draws: int - Number of draws per chain - tune: int - Number of tuning steps per chain - progressbar: bool or ProgressType, optional - How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask - for one of the following: - - "combined": A single progress bar that displays the total progress across all chains. Only timing - information is shown. - - "split": A separate progress bar for each chain. Only timing information is shown. - - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all - chains. Aggregate sample statistics are also displayed. - - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain - are also displayed. - - If True, the default is "split+stats" is used. - - progressbar_theme: Theme, optional - The theme to use for the progress bar. Defaults to the default theme. - """ - if progressbar_theme is None: - progressbar_theme = default_progress_theme - - match progressbar: - case True: - self.combined_progress = False - self.full_stats = True - show_progress = True - case False: - self.combined_progress = False - self.full_stats = True - show_progress = False - case "combined": - self.combined_progress = True - self.full_stats = False - show_progress = True - case "split": - self.combined_progress = False - self.full_stats = False - show_progress = True - case "combined+stats" | "stats+combined": - self.combined_progress = True - self.full_stats = True - show_progress = True - case "split+stats" | "stats+split": - self.combined_progress = False - self.full_stats = True - show_progress = True - case _: - raise ValueError( - "Invalid value for `progressbar`. Valid values are True (default), False (no progress bar), " - "one of 'combined', 'split', 'split+stats', or 'combined+stats." - ) - - progress_columns, progress_stats = step_method._progressbar_config(chains) - - self._progress = self.create_progress_bar( - progress_columns, - progressbar=progressbar, - progressbar_theme=progressbar_theme, - ) - - self.progress_stats = progress_stats - self.update_stats = step_method._make_update_stats_function() - - self._show_progress = show_progress - self.divergences = 0 - self.completed_draws = 0 - self.total_draws = draws + tune - self.desc = "Sampling chain" - self.chains = chains - - self._tasks: list[Task] | None = None # type: ignore[annotation-unchecked] - - def __enter__(self): - self._initialize_tasks() - - return self._progress.__enter__() - - def __exit__(self, exc_type, exc_val, exc_tb): - return self._progress.__exit__(exc_type, exc_val, exc_tb) - - def _initialize_tasks(self): - if self.combined_progress: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws * self.chains - 1, - chain_idx=0, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[0] for stat, value in self.progress_stats.items()}, - ) - ] - - else: - self.tasks = [ - self._progress.add_task( - self.desc.format(self), - completed=0, - draws=0, - total=self.total_draws - 1, - chain_idx=chain_idx, - sampling_speed=0, - speed_unit="draws/s", - **{stat: value[chain_idx] for stat, value in self.progress_stats.items()}, - ) - for chain_idx in range(self.chains) - ] - - def update(self, chain_idx, is_last, draw, tuning, stats): - if not self._show_progress: - return - - self.completed_draws += 1 - if self.combined_progress: - draw = self.completed_draws - chain_idx = 0 - - elapsed = self._progress.tasks[chain_idx].elapsed - speed, unit = compute_draw_speed(elapsed, draw) - - if not tuning and stats and stats[0].get("diverging"): - self.divergences += 1 - - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) - more_updates = ( - {stat: value[chain_idx] for stat, value in self.progress_stats.items()} - if self.full_stats - else {} - ) - - self._progress.update( - self.tasks[chain_idx], - completed=draw, - draws=draw, - sampling_speed=speed, - speed_unit=unit, - **more_updates, - ) - - if is_last: - self._progress.update( - self.tasks[chain_idx], - draws=draw + 1 if not self.combined_progress else draw, - **more_updates, - refresh=True, - ) - - def create_progress_bar(self, step_columns, progressbar, progressbar_theme): - columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))] - - if self.full_stats: - columns += step_columns - - columns += [ - TextColumn( - "{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}", - table_column=Column("Sampling Speed", ratio=1), - ), - TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)), - TimeRemainingColumn(table_column=Column("Remaining", ratio=1)), - ] - - return CustomProgress( - DivergenceBarColumn( - table_column=Column("Progress", ratio=2), - diverging_color="tab:red", - complete_style=Style.parse("rgb(31,119,180)"), # tab:blue - finished_style=Style.parse("rgb(31,119,180)"), # tab:blue - ), - *columns, - console=Console(theme=progressbar_theme), - disable=not progressbar, - include_headers=True, - ) - - -def compute_draw_speed(elapsed, draws): - speed = draws / max(elapsed, 1e-6) - - if speed > 1 or speed == 0: - unit = "draws/s" - else: - unit = "s/draws" - speed = 1 / speed - - return speed, unit - - RandomGeneratorState = namedtuple("RandomGeneratorState", ["bit_generator_state", "seed_seq_state"]) diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py index d9da7fb786..b83c1db4a3 100644 --- a/pymc/variational/inference.py +++ b/pymc/variational/inference.py @@ -23,7 +23,7 @@ import pymc as pm -from pymc.util import CustomProgress, default_progress_theme +from pymc.progress_bar import CustomProgress, default_progress_theme from pymc.variational import test_functions from pymc.variational.approximations import Empirical, FullRank, MeanField from pymc.variational.operators import KL, KSD diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py index 432418a33a..8d497f3011 100644 --- a/tests/step_methods/hmc/test_nuts.py +++ b/tests/step_methods/hmc/test_nuts.py @@ -148,6 +148,7 @@ def test_sampler_stats(self): expected_stat_names = { "depth", "diverging", + "divergences", "energy", "energy_error", "model_logp", diff --git a/tests/test_progress_bar.py b/tests/test_progress_bar.py new file mode 100644 index 0000000000..6687db1ae8 --- /dev/null +++ b/tests/test_progress_bar.py @@ -0,0 +1,46 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pymc as pm + + +def test_progressbar_nested_compound(): + # Regression test for https://github.com/pymc-devs/pymc/issues/7721 + + with pm.Model(): + a = pm.Poisson("a", mu=10) + b = pm.Binomial("b", n=a, p=0.8) + c = pm.Poisson("c", mu=11) + d = pm.Dirichlet("d", a=[c, b]) + + step = pm.CompoundStep( + [ + pm.CompoundStep([pm.Metropolis(a), pm.Metropolis(b), pm.Metropolis(c)]), + pm.NUTS([d]), + ] + ) + + kwargs = { + "draws": 10, + "tune": 10, + "chains": 2, + "compute_convergence_checks": False, + "step": step, + } + + # We don't parametrize to avoid recompiling the model functions + for cores in (1, 2): + pm.sample(**kwargs, cores=cores, progressbar=True) # default is split+stats + pm.sample(**kwargs, cores=cores, progressbar="combined") + pm.sample(**kwargs, cores=cores, progressbar="split") + pm.sample(**kwargs, cores=cores, progressbar=False)