
Commit 6bb2a3f

Add retry support and graceful error handling for evals
1 parent 08140ba commit 6bb2a3f

8 files changed: +491 -198 lines changed


examples/pydantic_ai_examples/evals/example_03_unit_testing.py

Lines changed: 3 additions & 1 deletion
@@ -29,7 +29,9 @@ def evaluate_dataset():
     report = dataset.evaluate_sync(infer_time_range)
     print(report)

-    assertion_pass_rate = report.averages().assertions
+    averages = report.averages()
+    assert averages is not None
+    assertion_pass_rate = averages.assertions
     assert assertion_pass_rate is not None, 'There should be at least one assertion'
     assert assertion_pass_rate > 0.9, (
         f'The assertion pass rate was {assertion_pass_rate:.1%}; it should be above 90%.'
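With this change, `report.averages()` can return `None` (for example when no case produces results), and case failures are collected on the report rather than raised. A minimal test sketch under those assumptions, reusing the example's `dataset` and `infer_time_range` (`test_dataset_quality` is an illustrative name):

    def test_dataset_quality():
        report = dataset.evaluate_sync(infer_time_range)

        # Task/evaluator crashes no longer abort the run; they are collected on the report.
        assert not report.failures, f'Unexpected case failures: {report.failures}'

        averages = report.averages()
        assert averages is not None, 'Expected at least one successfully evaluated case'
        assert averages.assertions is not None and averages.assertions > 0.9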

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 126 additions & 85 deletions
@@ -20,7 +20,7 @@
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

 import anyio
 import logfire_api
@@ -41,9 +41,13 @@
 from .evaluators._spec import EvaluatorSpec
 from .evaluators.common import DEFAULT_EVALUATORS
 from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
 from .otel import SpanTree
 from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from tenacity import AsyncRetrying

 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -84,6 +88,7 @@


 _REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
 _REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)


@@ -171,11 +176,6 @@ def __init__(
         self.evaluators = list(evaluators)


-# TODO: Consider making one or more of the following changes to this type:
-# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
-# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
-# * Rename to `Evaluation`
-# TODO: Allow `task` to be sync _or_ async
 class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
     """A dataset of test [cases][pydantic_evals.Case].

@@ -263,6 +263,7 @@ async def evaluate(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
+        retry: AsyncRetrying | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -292,24 +293,30 @@

             async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                 async with limiter:
-                    result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                    result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
                     if progress_bar and task_id is not None:  # pragma: no branch
                         progress_bar.update(task_id, advance=1)
                     return result

+            cases_and_failures = await task_group_gather(
+                [
+                    lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                    for i, case in enumerate(self.cases, 1)
+                ]
+            )
             report = EvaluationReport(
                 name=name,
-                cases=await task_group_gather(
-                    [
-                        lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                        for i, case in enumerate(self.cases, 1)
-                    ]
-                ),
+                cases=[x for x in cases_and_failures if isinstance(x, ReportCase)],
+                failures=[x for x in cases_and_failures if isinstance(x, ReportCaseFailure)],
             )
             # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
             eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
+            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
+            eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
             # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            averages = report.averages()
+            if averages:
+                eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
         return report

     def evaluate_sync(
@@ -817,38 +824,55 @@ def record_attribute(self, name: str, value: Any) -> None:


 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: AsyncRetrying | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.

     Args:
         task: The task to run.
         case: The case to run the task on.
+        retry: The retry strategy to use.

     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.

     Raises:
         Exception: Any exception raised by the task.
     """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')

-    # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
-    # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
-    token = _CURRENT_TASK_RUN.set(task_run)
-    try:
-        with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
-            with context_subtree() as span_tree:
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                 t0 = time.perf_counter()
                 if iscoroutinefunction(task):
-                    task_output = cast(OutputT, await task(case.inputs))
+                    task_output_ = cast(OutputT, await task(case.inputs))
                 else:
-                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
-    finally:
-        _CURRENT_TASK_RUN.reset(token)
+            duration_ = _get_span_duration(task_span, fallback_duration)
+            return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    async def _run_with_retries():
+        if retry:
+            async for attempt in retry:
+                with attempt:
+                    return await _run_once()
+        # Note: the following line will be unreachable if retry is not None
+        return await _run_once()
+
+    task_run, task_output, duration, span_tree = await _run_with_retries()

     if isinstance(span_tree, SpanTree):  # pragma: no branch
         # TODO: Question: Should we make this metric-attributes functionality more user-configurable in some way before merging?
@@ -863,6 +887,7 @@ async def _run_task(
             if not isinstance(v, (int, float)):
                 continue
             # TODO: Revisit this choice to strip the prefix..
+            # TODO: Use the span-tracking-of-metrics functionality to simplify this implementation
             if k.startswith('gen_ai.usage.details.'):
                 task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
             elif k.startswith('gen_ai.usage.'):
@@ -874,7 +899,7 @@ async def _run_task(
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=task_output,
-        duration=_get_span_duration(task_span, fallback_duration),
+        duration=duration,
         _span_tree=span_tree,
         attributes=task_run.attributes,
         metrics=task_run.metrics,
@@ -886,7 +911,8 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase[InputsT, OutputT, MetadataT]:
+    retry: AsyncRetrying | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

     Args:
@@ -898,60 +924,75 @@ async def _run_task_and_evaluators(
     Returns:
         A ReportCase containing the evaluation results.
     """
-    with _logfire.span(
-        'case: {case_name}',
-        task_name=get_unwrapped_function_name(task),
-        case_name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-    ) as case_span:
-        t0 = time.time()
-        scoring_context = await _run_task(task, case)
-
-        case_span.set_attribute('output', scoring_context.output)
-        case_span.set_attribute('task_duration', scoring_context.duration)
-        case_span.set_attribute('metrics', scoring_context.metrics)
-        case_span.set_attribute('attributes', scoring_context.attributes)
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
-
-        context = case_span.context
-        if context is None:  # pragma: no cover
-            trace_id = ''
-            span_id = ''
-        else:
-            trace_id = f'{context.trace_id:032x}'
-            span_id = f'{context.span_id:016x}'
-        fallback_duration = time.time() - t0
-
-    return ReportCase[InputsT, OutputT, MetadataT](
-        name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-        output=scoring_context.output,
-        metrics=scoring_context.metrics,
-        attributes=scoring_context.attributes,
-        scores=scores,
-        labels=labels,
-        assertions=assertions,
-        task_duration=scoring_context.duration,
-        total_duration=_get_span_duration(case_span, fallback_duration),
-        trace_id=trace_id,
-        span_id=span_id,
-    )
+    trace_id = ''
+    span_id = ''
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no cover
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'
+
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
+                )
+                flattened = [out for outputs in evaluator_outputs_by_task for out in outputs]
+                evaluator_outputs += [o for o in flattened if not isinstance(o, EvaluatorFailure)]
+                evaluator_failures += [o for o in flattened if isinstance(o, EvaluatorFailure)]
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+
+            fallback_duration = time.time() - t0
+
+        return ReportCase[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            output=scoring_context.output,
+            metrics=scoring_context.metrics,
+            attributes=scoring_context.attributes,
+            scores=scores,
+            labels=labels,
+            assertions=assertions,
+            task_duration=scoring_context.duration,
+            total_duration=_get_span_duration(case_span, fallback_duration),
+            trace_id=trace_id,
+            span_id=span_id,
+            evaluator_failures=evaluator_failures,
+        )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_msg=f'{type(exc).__name__}: {exc}',
+            trace_id=trace_id,
+            span_id=span_id,
+        )


 _evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
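The new `retry` parameter is typed as tenacity's `AsyncRetrying`, so callers build the retry policy with tenacity and pass it to `Dataset.evaluate`. A usage sketch under that assumption; `flaky_task` and the one-case dataset below are placeholders, not part of this commit:

    import asyncio

    from tenacity import AsyncRetrying, stop_after_attempt, wait_exponential

    from pydantic_evals import Case, Dataset


    async def flaky_task(question: str) -> str:
        # Placeholder task; imagine a model call that occasionally raises a transient error.
        return question.upper()


    async def main():
        dataset = Dataset(cases=[Case(name='smoke', inputs='hello', expected_output='HELLO')])
        retry_policy = AsyncRetrying(
            stop=stop_after_attempt(3),             # at most 3 attempts per case
            wait=wait_exponential(multiplier=0.5),  # back off between attempts
        )
        report = await dataset.evaluate(flaky_task, retry=retry_policy)
        print(f'{len(report.cases)} cases succeeded, {len(report.failures)} failed')


    asyncio.run(main())

If every attempt for a case fails, the exception is caught in `_run_task_and_evaluators` and the case is recorded in `report.failures` rather than aborting the whole evaluation run.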

pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py

Lines changed: 31 additions & 17 deletions
@@ -10,7 +10,14 @@
 from typing_extensions import TypeVar

 from .context import EvaluatorContext
-from .evaluator import EvaluationReason, EvaluationResult, EvaluationScalar, Evaluator, EvaluatorOutput
+from .evaluator import (
+    EvaluationReason,
+    EvaluationResult,
+    EvaluationScalar,
+    Evaluator,
+    EvaluatorFailure,
+    EvaluatorOutput,
+)

 InputsT = TypeVar('InputsT', default=Any, contravariant=True)
 OutputT = TypeVar('OutputT', default=Any, contravariant=True)
@@ -19,7 +26,7 @@

 async def run_evaluator(
     evaluator: Evaluator[InputsT, OutputT, MetadataT], ctx: EvaluatorContext[InputsT, OutputT, MetadataT]
-) -> list[EvaluationResult]:
+) -> list[EvaluationResult] | list[EvaluatorFailure]:
     """Run an evaluator and return the results.

     This function runs an evaluator on the given context and processes the results into
@@ -35,22 +42,29 @@ async def run_evaluator(
     Raises:
         ValueError: If the evaluator returns a value of an invalid type.
     """
-    raw_results = await evaluator.evaluate_async(ctx)
-
     try:
-        results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)
-    except ValidationError as e:
-        raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e
-
-    results = _convert_to_mapping(results, scalar_name=evaluator.get_default_evaluation_name())
-
-    details: list[EvaluationResult] = []
-    for name, result in results.items():
-        if not isinstance(result, EvaluationReason):
-            result = EvaluationReason(value=result)
-        details.append(EvaluationResult(name=name, value=result.value, reason=result.reason, source=evaluator))
-
-    return details
+        raw_results = await evaluator.evaluate_async(ctx)
+
+        try:
+            results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)
+        except ValidationError as e:
+            raise ValueError(f'{evaluator!r}.evaluate returned a value of an invalid type: {raw_results!r}.') from e
+
+        results = _convert_to_mapping(results, scalar_name=evaluator.get_default_evaluation_name())
+
+        details: list[EvaluationResult] = []
+        for name, result in results.items():
+            if not isinstance(result, EvaluationReason):
+                result = EvaluationReason(value=result)
+            details.append(EvaluationResult(name=name, value=result.value, reason=result.reason, source=evaluator))
+
+        return details
+    except Exception as e:
+        return [
+            EvaluatorFailure(
+                name=evaluator.get_default_evaluation_name(), error_msg=f'{type(e).__name__}: {e}', source=evaluator
+            )
+        ]


 _EVALUATOR_OUTPUT_ADAPTER = TypeAdapter[EvaluatorOutput](EvaluatorOutput)
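With `run_evaluator` now catching exceptions, an evaluator that raises is converted into an `EvaluatorFailure` (surfaced as `evaluator_failures` on the report case per the dataset.py changes above) instead of killing the case. A small sketch; `FlakyJudge` is an illustrative evaluator, not part of this commit:

    from dataclasses import dataclass

    from pydantic_evals.evaluators import Evaluator, EvaluatorContext


    @dataclass
    class FlakyJudge(Evaluator):
        def evaluate(self, ctx: EvaluatorContext) -> bool:
            # Any exception raised here is caught by run_evaluator and recorded,
            # e.g. with error_msg='RuntimeError: judge model unavailable'.
            raise RuntimeError('judge model unavailable')

Other evaluators attached to the same case still run, and their results are reported normally alongside the recorded failure.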
