@@ -168,28 +168,28 @@ def call_column(column, task):
168
168
return table
169
169
170
170
171
- class DivergenceBarColumn (BarColumn ):
172
- """Rich colorbar that changes color when a chain has detected a divergence ."""
171
+ class RecolorOnFailureBarColumn (BarColumn ):
172
+ """Rich colorbar that changes color when a chain has detected a failure ."""
173
173
174
- def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
174
+ def __init__ (self , * args , failing_color = "red" , ** kwargs ):
175
175
from matplotlib .colors import to_rgb
176
176
177
- self .diverging_color = diverging_color
178
- self .diverging_rgb = [int (x * 255 ) for x in to_rgb (self .diverging_color )]
177
+ self .failing_color = failing_color
178
+ self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
179
179
180
180
super ().__init__ (* args , ** kwargs )
181
181
182
- self .non_diverging_style = self .complete_style
183
- self .non_diverging_finished_style = self .finished_style
182
+ self .default_complete_style = self .complete_style
183
+ self .default_finished_style = self .finished_style
184
184
185
185
def callbacks (self , task : "Task" ):
186
- divergences = task .fields .get ("divergences" , 0 )
187
- if isinstance (divergences , float | int ) and divergences > 0 :
188
- self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
189
- self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
186
+ if task .fields ["failing" ]:
187
+ self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
188
+ self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
190
189
else :
191
- self .complete_style = self .non_diverging_style
192
- self .finished_style = self .non_diverging_finished_style
190
+ # Recovered from failing yay
191
+ self .complete_style = self .default_complete_style
192
+ self .finished_style = self .default_finished_style
193
193
194
194
195
195
class ProgressBarManager :
@@ -284,7 +284,6 @@ def __init__(
284
284
self .update_stats_functions = step_method ._make_progressbar_update_functions ()
285
285
286
286
self ._show_progress = show_progress
287
- self .divergences = 0
288
287
self .completed_draws = 0
289
288
self .total_draws = draws + tune
290
289
self .desc = "Sampling chain"
@@ -311,6 +310,7 @@ def _initialize_tasks(self):
311
310
chain_idx = 0 ,
312
311
sampling_speed = 0 ,
313
312
speed_unit = "draws/s" ,
313
+ failing = False ,
314
314
** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
315
315
)
316
316
]
@@ -325,6 +325,7 @@ def _initialize_tasks(self):
325
325
chain_idx = chain_idx ,
326
326
sampling_speed = 0 ,
327
327
speed_unit = "draws/s" ,
328
+ failing = False ,
328
329
** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
329
330
)
330
331
for chain_idx in range (self .chains )
@@ -354,42 +355,43 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
354
355
elapsed = self ._progress .tasks [chain_idx ].elapsed
355
356
speed , unit = self .compute_draw_speed (elapsed , draw )
356
357
357
- if not tuning and stats and stats [ 0 ]. get ( "diverging" ):
358
- self . divergences += 1
358
+ failing = False
359
+ all_step_stats = {}
359
360
360
- if self .full_stats :
361
- # TODO: Index by chain already?
362
- chain_progress_stats = [
363
- update_states_fn (step_stats )
364
- for update_states_fn , step_stats in zip (
365
- self .update_stats_functions , stats , strict = True
366
- )
367
- ]
368
- all_step_stats = {}
369
- for step_stats in chain_progress_stats :
370
- for key , val in step_stats .items ():
371
- if key in all_step_stats :
372
- # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
373
- continue
374
- else :
375
- all_step_stats [key ] = val
376
-
377
- else :
378
- all_step_stats = {}
361
+ chain_progress_stats = [
362
+ update_stats_fn (step_stats )
363
+ for update_stats_fn , step_stats in zip (self .update_stats_functions , stats , strict = True )
364
+ ]
365
+ for step_stats in chain_progress_stats :
366
+ for key , val in step_stats .items ():
367
+ if key == "failing" :
368
+ failing |= val
369
+ continue
370
+ if not self .full_stats :
371
+ # Only care about the "failing" flag
372
+ continue
373
+
374
+ if key in all_step_stats :
375
+ # TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
376
+ continue
377
+ else :
378
+ all_step_stats [key ] = val
379
379
380
380
self ._progress .update (
381
381
self .tasks [chain_idx ],
382
382
completed = draw ,
383
383
draws = draw ,
384
384
sampling_speed = speed ,
385
385
speed_unit = unit ,
386
+ failing = failing ,
386
387
** all_step_stats ,
387
388
)
388
389
389
390
if is_last :
390
391
self ._progress .update (
391
392
self .tasks [chain_idx ],
392
393
draws = draw + 1 if not self .combined_progress else draw ,
394
+ failing = failing ,
393
395
** all_step_stats ,
394
396
refresh = True ,
395
397
)
@@ -410,9 +412,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
410
412
]
411
413
412
414
return CustomProgress (
413
- DivergenceBarColumn (
415
+ RecolorOnFailureBarColumn (
414
416
table_column = Column ("Progress" , ratio = 2 ),
415
- diverging_color = "tab:red" ,
417
+ failing_color = "tab:red" ,
416
418
complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
417
419
finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
418
420
),
0 commit comments