Add async support for dspy.Evaluate #8504

Open · wants to merge 5 commits into main
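For context, here is a minimal usage sketch of the new async entry point, adapted from the test added in this PR. The `EchoProgram` module, `exact_match` metric, and dataset below are illustrative stand-ins, not part of the change itself.

```python
import asyncio

import dspy


class EchoProgram(dspy.Module):
    # Any dspy.Module exposing an async `acall` can be evaluated with Evaluate.acall.
    async def acall(self, question):
        await asyncio.sleep(0.01)  # stand-in for real async work (e.g., an LM call)
        return dspy.Prediction(answer=question)


def exact_match(gold, pred, traces=None):
    return gold.answer == pred.answer


devset = [
    dspy.Example(question=f"q{i}", answer=f"q{i}").with_inputs("question")
    for i in range(10)
]

evaluator = dspy.Evaluate(devset=devset, metric=exact_match, num_threads=4, display_progress=False)

# New in this PR: acall() runs the evaluation on the asyncio event loop
# (via ParallelExecutor.aexecute) instead of the thread-based execute() path.
result = asyncio.run(evaluator.acall(EchoProgram()))
print(result.score)         # 100.0
print(len(result.results))  # 10
```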
122 changes: 83 additions & 39 deletions dspy/evaluate/evaluate.py
@@ -6,6 +6,7 @@
if TYPE_CHECKING:
import pandas as pd


import tqdm

import dspy
@@ -51,6 +52,7 @@ class EvaluationResult(Prediction):
- score: A float value (e.g., 67.30) representing the overall performance
- results: a list of (example, prediction, score) tuples for each example in devset
"""

def __init__(self, score: float, results: list[tuple["dspy.Example", "dspy.Example", Any]]):
super().__init__(score=score, results=results)

@@ -126,71 +128,113 @@ def __call__(

Returns:
The evaluation results are returned as a dspy.EvaluationResult object containing the following attributes:

- score: A float percentage score (e.g., 67.30) representing overall performance

- results: a list of (example, prediction, score) tuples for each example in devset
"""
metric = metric if metric is not None else self.metric
devset = devset if devset is not None else self.devset
num_threads = num_threads if num_threads is not None else self.num_threads
display_progress = display_progress if display_progress is not None else self.display_progress
display_table = display_table if display_table is not None else self.display_table

metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
metric, devset, num_threads, display_progress, display_table
)
if callback_metadata:
logger.debug(f"Evaluate is called with callback metadata: {callback_metadata}")

tqdm.tqdm._instances.clear()
results = self._execute_with_multithreading(program, metric, devset, num_threads, display_progress)
return self._process_evaluate_result(devset, results, metric, display_table)

executor = ParallelExecutor(
num_threads=num_threads,
disable_progress_bar=not display_progress,
max_errors=(
self.max_errors
if self.max_errors is not None
else dspy.settings.max_errors
),
provide_traceback=self.provide_traceback,
compare_results=True,
@with_callbacks
async def acall(
self,
program: "dspy.Module",
metric: Callable | None = None,
devset: list["dspy.Example"] | None = None,
num_threads: int | None = None,
display_progress: bool | None = None,
display_table: bool | int | None = None,
callback_metadata: dict[str, Any] | None = None,
) -> EvaluationResult:
"""Async version of `Evaluate.__call__`."""
metric, devset, num_threads, display_progress, display_table = self._resolve_call_args(
metric, devset, num_threads, display_progress, display_table
)
if callback_metadata:
logger.debug(f"Evaluate.acall is called with callback metadata: {callback_metadata}")
tqdm.tqdm._instances.clear()
results = await self._execute_with_event_loop(program, metric, devset, num_threads, display_progress)
return self._process_evaluate_result(devset, results, metric, display_table)

def _resolve_call_args(self, metric, devset, num_threads, display_progress, display_table):
return (
metric or self.metric,
devset or self.devset,
num_threads or self.num_threads,
display_progress or self.display_progress,
display_table or self.display_table,
)

def process_item(example):
prediction = program(**example.inputs())
score = metric(example, prediction)

# Increment assert and suggest failures to program's attributes
if hasattr(program, "_assert_failures"):
program._assert_failures += dspy.settings.get("assert_failures")
if hasattr(program, "_suggest_failures"):
program._suggest_failures += dspy.settings.get("suggest_failures")

return prediction, score

results = executor.execute(process_item, devset)
def _process_evaluate_result(self, devset, results, metric, display_table):
assert len(devset) == len(results)

results = [((dspy.Prediction(), self.failure_score) if r is None else r) for r in results]
results = [(example, prediction, score) for example, (prediction, score) in zip(devset, results, strict=False)]
ncorrect, ntotal = sum(score for *_, score in results), len(devset)

logger.info(f"Average Metric: {ncorrect} / {ntotal} ({round(100 * ncorrect / ntotal, 1)}%)")

if display_table:
if importlib.util.find_spec("pandas") is not None:
# Rename the 'correct' column to the name of the metric object
metric_name = metric.__name__ if isinstance(metric, types.FunctionType) else metric.__class__.__name__
# Construct a pandas DataFrame from the results
result_df = self._construct_result_table(results, metric_name)

self._display_result_table(result_df, display_table, metric_name)
else:
logger.warning("Skipping table display since `pandas` is not installed.")

return EvaluationResult(
score=round(100 * ncorrect / ntotal, 2),
results=results,
)

def _execute_with_multithreading(
self,
program: "dspy.Module",
metric: Callable,
devset: list["dspy.Example"],
num_threads: int,
disable_progress_bar: bool,
):
executor = ParallelExecutor(
num_threads=num_threads,
disable_progress_bar=disable_progress_bar,
max_errors=(self.max_errors or dspy.settings.max_errors),
provide_traceback=self.provide_traceback,
compare_results=True,
)

def process_item(example):
prediction = program(**example.inputs())
score = metric(example, prediction)
return prediction, score

return executor.execute(process_item, devset)

async def _execute_with_event_loop(
self,
program: "dspy.Module",
metric: Callable,
devset: list["dspy.Example"],
num_threads: int,
disable_progress_bar: bool,
):
executor = ParallelExecutor(
num_threads=num_threads,
disable_progress_bar=disable_progress_bar,
max_errors=(self.max_errors or dspy.settings.max_errors),
provide_traceback=self.provide_traceback,
compare_results=True,
)

async def process_item(example):
prediction = await program.acall(**example.inputs())
score = metric(example, prediction)
return prediction, score

return await executor.aexecute(process_item, devset)

def _construct_result_table(
self, results: list[tuple["dspy.Example", "dspy.Example", Any]], metric_name: str
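Both the synchronous and async paths now converge on `_process_evaluate_result`, which builds the `EvaluationResult` described in the docstring above. Continuing the illustrative setup from the sketch at the top of this page, consuming that result might look like this (variable names are illustrative):

```python
# Continuing the sketch above: `result` comes from evaluator(...) or await evaluator.acall(...).
print(result.score)  # float percentage, e.g. 100.0 or 67.3

# `results` holds (example, prediction, score) tuples, one per devset item, in devset order.
for example, prediction, score in result.results:
    if not score:
        print("Miss:", example.inputs(), "->", prediction)
```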
100 changes: 88 additions & 12 deletions dspy/utils/parallelizer.py
@@ -1,3 +1,4 @@
import asyncio
import contextlib
import copy
import logging
@@ -42,29 +43,60 @@ def __init__(
self.error_lock = threading.Lock()
self.cancel_jobs = threading.Event()

self.error_lock_async = asyncio.Lock()
self.cancel_jobs_async = asyncio.Event()

def execute(self, function, data):
tqdm.tqdm._instances.clear()
wrapped = self._wrap_function(function)
wrapped = self._wrap_function(function, async_mode=False)
return self._execute_parallel(wrapped, data)

def _wrap_function(self, user_function):
def safe_func(item):
async def aexecute(self, function, data):
tqdm.tqdm._instances.clear()
wrapped = self._wrap_function(function, async_mode=True)
return await self._execute_parallel_async(wrapped, data)

def _handle_error(self, item, e):
with self.error_lock:
self.error_count += 1
if self.error_count >= self.max_errors:
self.cancel_jobs.set()
if self.provide_traceback:
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
else:
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")

async def _handle_error_async(self, item, e):
async with self.error_lock_async:
self.error_count += 1
if self.error_count >= self.max_errors:
self.cancel_jobs_async.set()
if self.provide_traceback:
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
else:
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")

def _wrap_function(self, user_function, async_mode=False):
async def _async_safe_func(item):
if self.cancel_jobs.is_set():
return None
try:
return await user_function(item)
except Exception as e:
await self._handle_error_async(item, e)
return None

def _sync_safe_func(item):
if self.cancel_jobs.is_set():
return None
try:
return user_function(item)
except Exception as e:
with self.error_lock:
self.error_count += 1
if self.error_count >= self.max_errors:
self.cancel_jobs.set()
if self.provide_traceback:
logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
else:
logger.error(f"Error for {item}: {e}. Set `provide_traceback=True` for traceback.")
self._handle_error(item, e)
return None

return safe_func
if async_mode:
return _async_safe_func
else:
return _sync_safe_func

def _execute_parallel(self, function, data):
results = [None] * len(data)
@@ -204,6 +236,50 @@ def all_done():

return results

async def _execute_parallel_async(self, function, data):
queue = asyncio.Queue()
results = [None] * len(data)
for i, example in enumerate(data):
await queue.put((i, example))

for _ in range(self.num_threads):
# Add a sentinel value to indicate that the worker should exit
await queue.put((-1, None))

# Create tqdm progress bar
pbar = tqdm.tqdm(total=len(data), dynamic_ncols=True)

async def worker():
while True:
if self.cancel_jobs_async.is_set():
break
index, example = await queue.get()
if index == -1:
break
function_outputs = await function(example)
results[index] = function_outputs

if self.compare_results:
vals = [r[-1] for r in results if r is not None]
self._update_progress(pbar, sum(vals), len(vals))
else:
self._update_progress(
pbar,
len([r for r in results if r is not None]),
len(data),
)

queue.task_done()

workers = [asyncio.create_task(worker()) for _ in range(self.num_threads)]
await asyncio.gather(*workers)
pbar.close()
if self.cancel_jobs_async.is_set():
logger.warning("Execution cancelled due to errors or interruption.")
raise Exception("Execution cancelled due to errors or interruption.")

return results

def _update_progress(self, pbar, nresults, ntotal):
if self.compare_results:
pct = round(100 * nresults / ntotal, 1) if ntotal else 0
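For reviewers who want to exercise the new async executor directly, here is a rough standalone sketch of `ParallelExecutor.aexecute`. The constructor arguments mirror how `Evaluate` builds the executor above; the `square` task and its parameter values are illustrative.

```python
import asyncio

from dspy.utils.parallelizer import ParallelExecutor


async def main():
    executor = ParallelExecutor(
        num_threads=4,             # also caps the number of async worker tasks
        disable_progress_bar=True,
        max_errors=10,
        provide_traceback=True,
        compare_results=False,
    )

    async def square(item):
        await asyncio.sleep(0.01)  # stand-in for awaiting an LM or I/O call
        return item * item

    # aexecute drains an asyncio.Queue with num_threads worker tasks and
    # returns results in the same order as the input data.
    return await executor.aexecute(square, list(range(8)))


print(asyncio.run(main()))  # [0, 1, 4, 9, 16, 25, 36, 49]
```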
32 changes: 32 additions & 0 deletions tests/evaluate/test_evaluate.py
@@ -1,3 +1,4 @@
import asyncio
import signal
import threading
from unittest.mock import patch
@@ -57,6 +58,7 @@ def test_evaluate_call():
@pytest.mark.extra
def test_construct_result_df():
import pandas as pd

devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
ev = Evaluate(
devset=devset,
@@ -253,6 +255,36 @@ def on_evaluate_end(
assert callback.end_call_outputs.score == 100.0
assert callback.end_call_count == 1


def test_evaluation_result_repr():
result = EvaluationResult(score=100.0, results=[(new_example("What is 1+1?", "2"), {"answer": "2"}, 100.0)])
assert repr(result) == "EvaluationResult(score=100.0, results=<list of 1 results>)"


@pytest.mark.asyncio
async def test_async_evaluate_call():
class MyProgram(dspy.Module):
async def acall(self, question):
# Simulate network delay
await asyncio.sleep(0.05)
# Just echo the input as the prediction
return dspy.Prediction(answer=question)

def dummy_metric(gold, pred, traces=None):
return gold.answer == pred.answer

devset = [new_example(f"placeholder_{i}", f"placeholder_{i}") for i in range(20)]

evaluator = Evaluate(devset=devset, metric=dummy_metric, num_threads=10, display_progress=False)

result = await evaluator.acall(MyProgram())
assert isinstance(result.score, float)
assert result.score == 100.0  # All 20 predictions should match their gold answers
assert len(result.results) == 20

# Check the results are in the correct order
for i, (example, prediction, score) in enumerate(result.results):
assert isinstance(prediction, dspy.Prediction)
assert score == 1
assert example.question == f"placeholder_{i}"
assert prediction.answer == f"placeholder_{i}"