Skip to content

Adding named params to openai activity configurations #917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions temporalio/contrib/openai_agents/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ from datetime import timedelta

from temporalio.client import Client
from temporalio.contrib.openai_agents.invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents.open_ai_data_converter import open_ai_data_converter
from temporalio.contrib.openai_agents.temporal_openai_agents import set_open_ai_agent_temporal_overrides
from temporalio.worker import Worker
Expand All @@ -105,9 +106,10 @@ from hello_world_workflow import HelloWorldAgent
async def worker_main():
# Configure the OpenAI Agents SDK to use Temporal activities for LLM API calls
# and for tool calls.
with set_open_ai_agent_temporal_overrides(
model_params = ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=10)
):
)
with set_open_ai_agent_temporal_overrides(model_params):
# Create a Temporal client connected to server at the given address
# Use the OpenAI data converter to ensure proper serialization/deserialization
client = await Client.connect(
Expand Down
16 changes: 12 additions & 4 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import replace
from typing import Union
from datetime import timedelta
from typing import Optional, Union

from agents import (
Agent,
Expand All @@ -13,7 +14,10 @@
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner

from temporalio import workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.workflow import ActivityCancellationType, VersioningIntent


class TemporalOpenAIRunner(AgentRunner):
Expand All @@ -23,10 +27,10 @@ class TemporalOpenAIRunner(AgentRunner):

"""

def __init__(self, **kwargs) -> None:
def __init__(self, model_params: ModelActivityParameters) -> None:
"""Initialize the Temporal OpenAI Runner."""
self._runner = DEFAULT_AGENT_RUNNER or AgentRunner()
self.kwargs = kwargs
self.model_params = model_params

async def run(
self,
Expand Down Expand Up @@ -56,7 +60,11 @@ async def run(
"Temporal workflows require a model name to be a string in the run config."
)
updated_run_config = replace(
run_config, model=_TemporalModelStub(run_config.model, **self.kwargs)
run_config,
model=_TemporalModelStub(
run_config.model,
model_params=self.model_params,
),
)

with workflow.unsafe.imports_passed_through():
Expand Down
27 changes: 23 additions & 4 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import Optional

from temporalio import workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.workflow import ActivityCancellationType, VersioningIntent

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,9 +46,14 @@
class _TemporalModelStub(Model):
"""A stub that allows invoking models as Temporal activities."""

def __init__(self, model_name: Optional[str], **kwargs) -> None:
def __init__(
self,
model_name: Optional[str],
*,
model_params: ModelActivityParameters,
) -> None:
self.model_name = model_name
self.kwargs = kwargs
self.model_params = model_params

async def get_response(
self,
Expand Down Expand Up @@ -141,11 +151,20 @@ def make_tool_info(tool: Tool) -> ToolInput:
previous_response_id=previous_response_id,
prompt=prompt,
)

return await workflow.execute_activity_method(
ModelActivity.invoke_model_activity,
activity_input,
summary=get_summary(input),
**self.kwargs,
summary=self.model_params.summary_override or get_summary(input),
task_queue=self.model_params.task_queue,
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
start_to_close_timeout=self.model_params.start_to_close_timeout,
heartbeat_timeout=self.model_params.heartbeat_timeout,
retry_policy=self.model_params.retry_policy,
cancellation_type=self.model_params.cancellation_type,
versioning_intent=self.model_params.versioning_intent,
priority=self.model_params.priority,
)

def stream_response(
Expand Down
48 changes: 48 additions & 0 deletions temporalio/contrib/openai_agents/model_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Parameters for configuring Temporal activity execution for model calls."""

from dataclasses import dataclass
from datetime import timedelta
from typing import Optional

from temporalio.common import Priority, RetryPolicy
from temporalio.workflow import ActivityCancellationType, VersioningIntent


@dataclass
class ModelActivityParameters:
"""Parameters for configuring Temporal activity execution for model calls.

This class encapsulates all the parameters that can be used to configure
how Temporal activities are executed when making model calls through the
OpenAI Agents integration.
"""

task_queue: Optional[str] = None
"""Specific task queue to use for model activities."""

schedule_to_close_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to completion."""

schedule_to_start_timeout: Optional[timedelta] = None
"""Maximum time from scheduling to starting."""

start_to_close_timeout: Optional[timedelta] = None
"""Maximum time for the activity to complete."""

heartbeat_timeout: Optional[timedelta] = None
"""Maximum time between heartbeats."""

retry_policy: Optional[RetryPolicy] = None
"""Policy for retrying failed activities."""

cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL
"""How the activity handles cancellation."""

versioning_intent: Optional[VersioningIntent] = None
"""Versioning intent for the activity."""

summary_override: Optional[str] = None
"""Summary for the activity execution."""

priority: Priority = Priority.default
"""Priority for the activity execution."""
40 changes: 19 additions & 21 deletions temporalio/contrib/openai_agents/temporal_openai_agents.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
"""Initialize Temporal OpenAI Agents overrides."""

from contextlib import contextmanager
from datetime import timedelta
from typing import Optional

from agents import set_trace_provider
from agents.run import AgentRunner, get_default_agent_runner, set_default_agent_runner
from agents.tracing import TraceProvider, get_trace_provider
from agents.run import get_default_agent_runner, set_default_agent_runner
from agents.tracing import get_trace_provider
from agents.tracing.provider import DefaultTraceProvider

from temporalio.common import Priority, RetryPolicy
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.workflow import ActivityCancellationType, VersioningIntent


@contextmanager
def set_open_ai_agent_temporal_overrides(**kwargs):
def set_open_ai_agent_temporal_overrides(
model_params: ModelActivityParameters,
):
"""Configure Temporal-specific overrides for OpenAI agents.

.. warning::
Expand All @@ -33,34 +39,26 @@ def set_open_ai_agent_temporal_overrides(**kwargs):
3. Restoring previous settings when the context exits

Args:
**kwargs: Additional arguments to pass to the TemporalOpenAIRunner constructor.
These arguments are forwarded to workflow.execute_activity_method when
executing model calls. Common options include:
- start_to_close_timeout: Maximum time for the activity to complete
- schedule_to_close_timeout: Maximum time from scheduling to completion
- retry_policy: Policy for retrying failed activities
- task_queue: Specific task queue to use for model activities

Example usage:
with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=30),
retry_policy=RetryPolicy(maximum_attempts=3)
):
# Initialize Temporal client and worker here
client = await Client.connect("localhost:7233")
worker = Worker(client, task_queue="my-task-queue")
await worker.run()
model_params: Configuration parameters for Temporal activity execution of model calls.

Returns:
A context manager that yields the configured TemporalTraceProvider.

"""
if (
not model_params.start_to_close_timeout
and not model_params.schedule_to_close_timeout
):
raise ValueError(
"Activity must have start_to_close_timeout or schedule_to_close_timeout"
)

previous_runner = get_default_agent_runner()
previous_trace_provider = get_trace_provider()
provider = TemporalTraceProvider()

try:
set_default_agent_runner(TemporalOpenAIRunner(**kwargs))
set_default_agent_runner(TemporalOpenAIRunner(model_params))
set_trace_provider(provider)
yield provider
finally:
Expand Down
45 changes: 31 additions & 14 deletions temporalio/contrib/openai_agents/temporal_tools.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
"""Support for using Temporal activities as OpenAI agents tools."""

from typing import Any, Callable
from datetime import timedelta
from typing import Any, Callable, Optional

from temporalio import activity, workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.exceptions import ApplicationError
from temporalio.workflow import unsafe
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe

with unsafe.imports_passed_through():
from agents import FunctionTool, RunContextWrapper, Tool
from agents.function_schema import function_schema


def activity_as_tool(fn: Callable, **kwargs) -> Tool:
def activity_as_tool(
fn: Callable,
*,
task_queue: Optional[str] = None,
schedule_to_close_timeout: Optional[timedelta] = None,
schedule_to_start_timeout: Optional[timedelta] = None,
start_to_close_timeout: Optional[timedelta] = None,
heartbeat_timeout: Optional[timedelta] = None,
retry_policy: Optional[RetryPolicy] = None,
cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL,
activity_id: Optional[str] = None,
versioning_intent: Optional[VersioningIntent] = None,
summary: Optional[str] = None,
priority: Priority = Priority.default,
) -> Tool:
"""Convert a single Temporal activity function to an OpenAI agent tool.

.. warning::
Expand All @@ -25,16 +41,7 @@ def activity_as_tool(fn: Callable, **kwargs) -> Tool:

Args:
fn: A Temporal activity function to convert to a tool.
**kwargs: Additional arguments to pass to workflow.execute_activity.
These arguments configure how the activity is executed. Common options include:
- start_to_close_timeout: Maximum time for the activity to complete
- schedule_to_close_timeout: Maximum time from scheduling to completion
- schedule_to_start_timeout: Maximum time from scheduling to starting
- heartbeat_timeout: Maximum time between heartbeats
- retry_policy: Policy for retrying failed activities
- task_queue: Specific task queue to use for this activity
- cancellation_type: How the activity handles cancellation
- workflow_id_reuse_policy: Policy for workflow ID reuse
For other arguments, refer to :py:mod:`workflow` :py:meth:`start_activity`

Returns:
An OpenAI agent tool that wraps the provided activity.
Expand Down Expand Up @@ -69,7 +76,17 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
await workflow.execute_activity(
fn,
input,
**kwargs,
task_queue=task_queue,
schedule_to_close_timeout=schedule_to_close_timeout,
schedule_to_start_timeout=schedule_to_start_timeout,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
priority=priority,
)
)
except Exception:
Expand Down
26 changes: 11 additions & 15 deletions tests/contrib/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from temporalio.contrib.openai_agents.invoke_model_activity import (
ModelActivity,
)
from temporalio.contrib.openai_agents.model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents.open_ai_data_converter import (
open_ai_data_converter,
)
Expand Down Expand Up @@ -144,9 +145,8 @@ async def test_hello_world_agent(client: Client):
new_config["data_converter"] = open_ai_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=10)
):
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
TestHelloModel( # type: ignore
Expand Down Expand Up @@ -242,9 +242,8 @@ async def test_tool_workflow(client: Client):
new_config["data_converter"] = open_ai_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=10)
):
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
TestWeatherModel( # type: ignore
Expand Down Expand Up @@ -464,9 +463,8 @@ async def test_research_workflow(client: Client):
global response_index
response_index = 0

with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=10)
):
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
TestResearchModel( # type: ignore
Expand Down Expand Up @@ -675,9 +673,8 @@ async def test_agents_as_tools_workflow(client: Client):
new_config["data_converter"] = open_ai_data_converter
client = Client(**new_config)

with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=10)
):
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
AgentAsToolsModel( # type: ignore
Expand Down Expand Up @@ -1033,9 +1030,8 @@ async def test_customer_service_workflow(client: Client):

questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]

with set_open_ai_agent_temporal_overrides(
start_to_close_timeout=timedelta(seconds=10)
):
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=10))
with set_open_ai_agent_temporal_overrides(model_params):
model_activity = ModelActivity(
TestProvider(
CustomerServiceModel( # type: ignore
Expand Down