@@ -330,6 +330,18 @@ def _initialize_tasks(self):
330
330
for chain_idx in range (self .chains )
331
331
]
332
332
333
+ @staticmethod
334
+ def compute_draw_speed (elapsed , draws ):
335
+ speed = draws / max (elapsed , 1e-6 )
336
+
337
+ if speed > 1 or speed == 0 :
338
+ unit = "draws/s"
339
+ else :
340
+ unit = "s/draws"
341
+ speed = 1 / speed
342
+
343
+ return speed , unit
344
+
333
345
def update (self , chain_idx , is_last , draw , tuning , stats ):
334
346
if not self ._show_progress :
335
347
return
@@ -340,7 +352,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
340
352
chain_idx = 0
341
353
342
354
elapsed = self ._progress .tasks [chain_idx ].elapsed
343
- speed , unit = compute_draw_speed (elapsed , draw )
355
+ speed , unit = self . compute_draw_speed (elapsed , draw )
344
356
345
357
if not tuning and stats and stats [0 ].get ("diverging" ):
346
358
self .divergences += 1
@@ -365,12 +377,6 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
365
377
else :
366
378
all_step_stats = {}
367
379
368
- # more_updates = (
369
- # {stat: value[chain_idx] for stat, value in progress_stats.items()}
370
- # if self.full_stats
371
- # else {}
372
- # )
373
-
374
380
self ._progress .update (
375
381
self .tasks [chain_idx ],
376
382
completed = draw ,
@@ -415,15 +421,3 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
415
421
disable = not progressbar ,
416
422
include_headers = True ,
417
423
)
418
-
419
-
420
- def compute_draw_speed (elapsed , draws ):
421
- speed = draws / max (elapsed , 1e-6 )
422
-
423
- if speed > 1 or speed == 0 :
424
- unit = "draws/s"
425
- else :
426
- unit = "s/draws"
427
- speed = 1 / speed
428
-
429
- return speed , unit
0 commit comments