diff --git a/temporalio/contrib/openai_agents/_temporal_trace_provider.py b/temporalio/contrib/openai_agents/_temporal_trace_provider.py index 4637afbe1..1d9b09866 100644 --- a/temporalio/contrib/openai_agents/_temporal_trace_provider.py +++ b/temporalio/contrib/openai_agents/_temporal_trace_provider.py @@ -7,7 +7,10 @@ from agents.tracing import ( get_trace_provider, ) -from agents.tracing.provider import DefaultTraceProvider +from agents.tracing.provider import ( + DefaultTraceProvider, + SynchronousMultiTracingProcessor, +) from agents.tracing.spans import Span from temporalio import workflow @@ -72,10 +75,19 @@ def activity_span( ) -class _TemporalTracingProcessor(TracingProcessor): - def __init__(self, impl: TracingProcessor): +class _TemporalTracingProcessor(SynchronousMultiTracingProcessor): + def __init__( + self, impl: SynchronousMultiTracingProcessor, auto_close_in_workflows: bool + ): super().__init__() self._impl = impl + self._auto_close_in_workflows = auto_close_in_workflows + + def add_tracing_processor(self, tracing_processor: TracingProcessor): + self._impl.add_tracing_processor(tracing_processor) + + def set_processors(self, processors: list[TracingProcessor]): + self._impl.set_processors(processors) def on_trace_start(self, trace: Trace) -> None: if workflow.in_workflow() and workflow.unsafe.is_replaying(): @@ -83,11 +95,15 @@ def on_trace_start(self, trace: Trace) -> None: return self._impl.on_trace_start(trace) + if self._auto_close_in_workflows and workflow.in_workflow(): + self._impl.on_trace_end(trace) def on_trace_end(self, trace: Trace) -> None: if workflow.in_workflow() and workflow.unsafe.is_replaying(): # In replay mode, don't report return + if self._auto_close_in_workflows and workflow.in_workflow(): + return self._impl.on_trace_end(trace) @@ -97,11 +113,16 @@ def on_span_start(self, span: Span[Any]) -> None: return self._impl.on_span_start(span) + if self._auto_close_in_workflows and workflow.in_workflow(): + self._impl.on_span_end(span) def on_span_end(self, span: Span[Any]) -> None: if workflow.in_workflow() and workflow.unsafe.is_replaying(): # In replay mode, don't report return + if self._auto_close_in_workflows and workflow.in_workflow(): + return + self._impl.on_span_end(span) def shutdown(self) -> None: @@ -114,12 +135,13 @@ def force_flush(self) -> None: class TemporalTraceProvider(DefaultTraceProvider): """A trace provider that integrates with Temporal workflows.""" - def __init__(self): + def __init__(self, auto_close_in_workflows: bool = False): """Initialize the TemporalTraceProvider.""" super().__init__() self._original_provider = cast(DefaultTraceProvider, get_trace_provider()) - self._multi_processor = _TemporalTracingProcessor( # type: ignore[assignment] - self._original_provider._multi_processor + self._multi_processor = _TemporalTracingProcessor( + self._original_provider._multi_processor, + auto_close_in_workflows, ) def time_iso(self) -> str: diff --git a/temporalio/contrib/openai_agents/temporal_openai_agents.py b/temporalio/contrib/openai_agents/temporal_openai_agents.py index 871b7b577..9afb57cc5 100644 --- a/temporalio/contrib/openai_agents/temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/temporal_openai_agents.py @@ -21,6 +21,7 @@ @contextmanager def set_open_ai_agent_temporal_overrides( model_params: ModelActivityParameters, + auto_close_tracing_in_workflows: bool = False, ): """Configure Temporal-specific overrides for OpenAI agents. @@ -55,7 +56,9 @@ def set_open_ai_agent_temporal_overrides( previous_runner = get_default_agent_runner() previous_trace_provider = get_trace_provider() - provider = TemporalTraceProvider() + provider = TemporalTraceProvider( + auto_close_in_workflows=auto_close_tracing_in_workflows + ) try: set_default_agent_runner(TemporalOpenAIRunner(model_params)) diff --git a/temporalio/contrib/openai_agents/trace_interceptor.py b/temporalio/contrib/openai_agents/trace_interceptor.py index 32ed7b7e7..8a791b73c 100644 --- a/temporalio/contrib/openai_agents/trace_interceptor.py +++ b/temporalio/contrib/openai_agents/trace_interceptor.py @@ -60,16 +60,6 @@ def context_from_header( if span_info is None: yield else: - span = SpanImpl( - trace_id=str(span_info["traceId"]), - span_id=span_info["spanId"], - parent_id=None, - span_data=CustomSpanData( - name="Parent Temporal Span", - data={}, - ), - processor=cast(DefaultTraceProvider, get_trace_provider())._multi_processor, - ) workflow_type = ( activity.info().workflow_type if activity.in_activity() @@ -94,11 +84,11 @@ def context_from_header( span_info["traceName"], trace_id=span_info["traceId"], metadata=metadata, - ): - with custom_span(name=span_name, parent=span, data=data): + ) as t: + with custom_span(name=span_name, parent=t, data=data): yield else: - with custom_span(name=span_name, parent=span, data=data): + with custom_span(name=span_name, parent=None, data=data): yield