from dataclasses import dataclass, field
from inspect import iscoroutinefunction
from pathlib import Path
-from typing import Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

import anyio
import logfire_api

from .evaluators._spec import EvaluatorSpec
from .evaluators.common import DEFAULT_EVALUATORS
from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
from .otel import SpanTree
from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from tenacity import AsyncRetrying

if sys.version_info < (3, 11):
    from exceptiongroup import ExceptionGroup  # pragma: lax no cover


_REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
_REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)

@@ -171,11 +176,6 @@ def __init__(
        self.evaluators = list(evaluators)


-# TODO: Consider making one or more of the following changes to this type:
-# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
-# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
-# * Rename to `Evaluation`
-# TODO: Allow `task` to be sync _or_ async
class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
    """A dataset of test [cases][pydantic_evals.Case].

@@ -263,6 +263,7 @@ async def evaluate(
        name: str | None = None,
        max_concurrency: int | None = None,
        progress: bool = True,
+        retry: AsyncRetrying | None = None,
    ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
        """Evaluates the test cases in the dataset using the given task.

@@ -292,24 +293,30 @@ async def evaluate(

                async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                    async with limiter:
-                        result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                        result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
                        if progress_bar and task_id is not None:  # pragma: no branch
                            progress_bar.update(task_id, advance=1)
                        return result

+                cases_and_failures = await task_group_gather(
+                    [
+                        lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                        for i, case in enumerate(self.cases, 1)
+                    ]
+                )
                report = EvaluationReport(
                    name=name,
-                    cases=await task_group_gather(
-                        [
-                            lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                            for i, case in enumerate(self.cases, 1)
-                        ]
-                    ),
+                    cases=[x for x in cases_and_failures if isinstance(x, ReportCase)],
+                    failures=[x for x in cases_and_failures if isinstance(x, ReportCaseFailure)],
                )
            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
            eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
+            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
+            eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
            # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            averages = report.averages()
+            if averages:
+                eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
        return report

    def evaluate_sync(
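
# --- Illustrative sketch (not part of the diff): how the new `retry` argument to
# `Dataset.evaluate` might be used. Assumes tenacity is installed; `shouty_task`, the
# example case, and `main` are placeholders invented for illustration.
import tenacity

from pydantic_evals import Case, Dataset


async def shouty_task(inputs: str) -> str:
    # Stand-in for a task that may fail transiently (e.g. a model or HTTP call).
    return inputs.upper()


async def main() -> None:
    dataset = Dataset(cases=[Case(name='greeting', inputs='hello', expected_output='HELLO')])
    report = await dataset.evaluate(
        shouty_task,
        retry=tenacity.AsyncRetrying(
            stop=tenacity.stop_after_attempt(3),
            wait=tenacity.wait_exponential(multiplier=0.5),
        ),
    )
    # With this PR, cases whose task still raises after the retries are exhausted end up in
    # `report.failures` instead of aborting the whole evaluation run.
    print(len(report.cases), len(report.failures))
# e.g. run with: asyncio.run(main())
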
@@ -817,38 +824,55 @@ def record_attribute(self, name: str, value: Any) -> None:


async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: AsyncRetrying | None = None,
) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
    """Run a task on a case and return the context for evaluators.

    Args:
        task: The task to run.
        case: The case to run the task on.
+        retry: Optional retry strategy to apply when running the task.

    Returns:
        An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

    Raises:
        Exception: Any exception raised by the task.
    """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-    # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
-    # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
-    token = _CURRENT_TASK_RUN.set(task_run)
-    try:
-        with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
-            with context_subtree() as span_tree:
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                t0 = time.perf_counter()
                if iscoroutinefunction(task):
-                    task_output = cast(OutputT, await task(case.inputs))
+                    task_output_ = cast(OutputT, await task(case.inputs))
                else:
-                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                fallback_duration = time.perf_counter() - t0
-    finally:
-        _CURRENT_TASK_RUN.reset(token)
+            duration_ = _get_span_duration(task_span, fallback_duration)
+            return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    async def _run_with_retries():
+        if retry:
+            async for attempt in retry:
+                with attempt:
+                    return await _run_once()
+        # Note: the following line is only reached when `retry` is None
+        return await _run_once()
+
+    task_run, task_output, duration, span_tree = await _run_with_retries()

    if isinstance(span_tree, SpanTree):  # pragma: no branch
        # TODO: Question: Should we make this metric-attributes functionality more user-configurable in some way before merging?
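
# --- Illustrative sketch (not part of the diff): the tenacity retry-loop pattern that
# `_run_with_retries` above relies on, shown standalone. `run_with_retries` and `flaky_call`
# are names made up for illustration.
from __future__ import annotations

from typing import Awaitable, Callable, TypeVar

import tenacity

T = TypeVar('T')


async def run_with_retries(flaky_call: Callable[[], Awaitable[T]], retry: tenacity.AsyncRetrying | None = None) -> T:
    if retry:
        async for attempt in retry:
            with attempt:
                # A successful body returns out of the loop; an exception is recorded by the
                # `attempt` context manager, tenacity waits, and the loop yields another attempt
                # until the stop condition is reached, at which point tenacity raises
                # RetryError (or the original exception when configured with reraise=True).
                return await flaky_call()
    # Only reached when no retry strategy was supplied.
    return await flaky_call()
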
@@ -863,6 +887,7 @@ async def _run_task(
                if not isinstance(v, (int, float)):
                    continue
                # TODO: Revisit this choice to strip the prefix..
+                # TODO: Use the span-tracking-of-metrics functionality to simplify this implementation
                if k.startswith('gen_ai.usage.details.'):
                    task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
                elif k.startswith('gen_ai.usage.'):
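
# --- Illustrative sketch (not part of the diff): how the prefix stripping above maps
# OpenTelemetry `gen_ai.usage.*` span attributes onto task-run metric names. The attribute
# values below are made-up examples.
attributes = {
    'gen_ai.usage.input_tokens': 120,
    'gen_ai.usage.output_tokens': 45,
    'gen_ai.usage.details.reasoning_tokens': 7,
    'gen_ai.operation.name': 'chat',  # non-numeric values are skipped
}
for k, v in attributes.items():
    if not isinstance(v, (int, float)):
        continue
    if k.startswith('gen_ai.usage.details.'):
        print(k.removeprefix('gen_ai.usage.details.'), v)  # reasoning_tokens 7
    elif k.startswith('gen_ai.usage.'):
        print(k.removeprefix('gen_ai.usage.'), v)  # input_tokens 120, output_tokens 45
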
@@ -874,7 +899,7 @@ async def _run_task(
        metadata=case.metadata,
        expected_output=case.expected_output,
        output=task_output,
-        duration=_get_span_duration(task_span, fallback_duration),
+        duration=duration,
        _span_tree=span_tree,
        attributes=task_run.attributes,
        metrics=task_run.metrics,
@@ -886,7 +911,8 @@ async def _run_task_and_evaluators(
    case: Case[InputsT, OutputT, MetadataT],
    report_case_name: str,
    dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase[InputsT, OutputT, MetadataT]:
+    retry: AsyncRetrying | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
    """Run a task on a case and evaluate the results.

    Args:
@@ -898,60 +924,75 @@ async def _run_task_and_evaluators(
    Returns:
        A ReportCase containing the evaluation results.
    """
-    with _logfire.span(
-        'case: {case_name}',
-        task_name=get_unwrapped_function_name(task),
-        case_name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-    ) as case_span:
-        t0 = time.time()
-        scoring_context = await _run_task(task, case)
-
-        case_span.set_attribute('output', scoring_context.output)
-        case_span.set_attribute('task_duration', scoring_context.duration)
-        case_span.set_attribute('metrics', scoring_context.metrics)
-        case_span.set_attribute('attributes', scoring_context.attributes)
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
-
-        context = case_span.context
-        if context is None:  # pragma: no cover
-            trace_id = ''
-            span_id = ''
-        else:
-            trace_id = f'{context.trace_id:032x}'
-            span_id = f'{context.span_id:016x}'
-        fallback_duration = time.time() - t0
-
-        return ReportCase[InputsT, OutputT, MetadataT](
-            name=report_case_name,
-            inputs=case.inputs,
-            metadata=case.metadata,
-            expected_output=case.expected_output,
-            output=scoring_context.output,
-            metrics=scoring_context.metrics,
-            attributes=scoring_context.attributes,
-            scores=scores,
-            labels=labels,
-            assertions=assertions,
-            task_duration=scoring_context.duration,
-            total_duration=_get_span_duration(case_span, fallback_duration),
-            trace_id=trace_id,
-            span_id=span_id,
-        )
+    trace_id = ''
+    span_id = ''
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no cover
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'
+
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
+                )
+                flattened = [out for outputs in evaluator_outputs_by_task for out in outputs]
+                evaluator_outputs += [o for o in flattened if not isinstance(o, EvaluatorFailure)]
+                evaluator_failures += [o for o in flattened if isinstance(o, EvaluatorFailure)]
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+
+            fallback_duration = time.time() - t0
+
+            return ReportCase[InputsT, OutputT, MetadataT](
+                name=report_case_name,
+                inputs=case.inputs,
+                metadata=case.metadata,
+                expected_output=case.expected_output,
+                output=scoring_context.output,
+                metrics=scoring_context.metrics,
+                attributes=scoring_context.attributes,
+                scores=scores,
+                labels=labels,
+                assertions=assertions,
+                task_duration=scoring_context.duration,
+                total_duration=_get_span_duration(case_span, fallback_duration),
+                trace_id=trace_id,
+                span_id=span_id,
+                evaluator_failures=evaluator_failures,
+            )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_msg=f'{type(exc).__name__}: {exc}',
+            trace_id=trace_id,
+            span_id=span_id,
+        )


_evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
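
# --- Illustrative sketch (not part of the diff): inspecting failures on the resulting report.
# The `failures`, `evaluator_failures`, and `error_msg` fields are the ones introduced in this
# PR; the `summarize_failures` helper name is made up for illustration.
from pydantic_evals.reporting import EvaluationReport


def summarize_failures(report: EvaluationReport) -> None:
    for failure in report.failures:
        # The task raised even after any configured retries, so no output was produced.
        print(f'{failure.name}: task failed with {failure.error_msg}')
    for case in report.cases:
        for evaluator_failure in case.evaluator_failures:
            # The task succeeded, but an evaluator raised while scoring this case.
            print(f'{case.name}: evaluator failure: {evaluator_failure}')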