Skip to content

Provide client in activity context #740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1257,6 +1257,7 @@ calls in the `temporalio.activity` package make use of it. Specifically:

* `in_activity()` - Whether an activity context is present
* `info()` - Returns the immutable info of the currently running activity
* `client()` - Returns the Temporal client used by this worker. Only available in `async def` activities.
* `heartbeat(*details)` - Record a heartbeat
* `is_cancelled()` - Whether a cancellation has been requested on this activity
* `wait_for_cancelled()` - `async` call to wait for cancellation request
Expand Down
37 changes: 34 additions & 3 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
Expand All @@ -42,6 +43,9 @@

from .types import CallableType

if TYPE_CHECKING:
from temporalio.client import Client


@overload
def defn(fn: CallableType) -> CallableType: ...
Expand Down Expand Up @@ -179,6 +183,7 @@ class _Context:
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
client: Optional[Client]
cancellation_details: _ActivityCancellationDetailsHolder
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
Expand Down Expand Up @@ -271,13 +276,37 @@ def wait_sync(self, timeout: Optional[float] = None) -> None:
self.thread_event.wait(timeout)


def client() -> Client:
"""Return a Temporal Client for use in the current activity.

The client is only available in `async def` activities.

In tests it is not available automatically, but you can pass a client when creating a
:py:class:`temporalio.testing.ActivityEnvironment`.

Returns:
:py:class:`temporalio.client.Client` for use in the current activity.

Raises:
RuntimeError: When the client is not available.
"""
client = _Context.current().client
if not client:
raise RuntimeError(
"No client available. The client is only available in `async def` "
"activities; not in `def` activities. In tests you can pass a "
"client when creating ActivityEnvironment."
)
return client


def in_activity() -> bool:
"""Whether the current code is inside an activity.

Returns:
True if in an activity, False otherwise.
"""
return not _current_context.get(None) is None
return _current_context.get(None) is not None


def info() -> Info:
Expand Down Expand Up @@ -574,8 +603,10 @@ def _apply_to_callable(
fn=fn,
# iscoroutinefunction does not return true for async __call__
# TODO(cretz): Why can't MyPy handle this?
is_async=inspect.iscoroutinefunction(fn)
or inspect.iscoroutinefunction(fn.__call__), # type: ignore
is_async=(
inspect.iscoroutinefunction(fn)
or inspect.iscoroutinefunction(fn.__call__) # type: ignore
),
no_thread_cancel_exception=no_thread_cancel_exception,
),
)
Expand Down
20 changes: 14 additions & 6 deletions temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import temporalio.converter
import temporalio.exceptions
import temporalio.worker._activity
from temporalio.client import Client

_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
Expand Down Expand Up @@ -63,7 +64,7 @@ class ActivityEnvironment:
take effect. Default is noop.
"""

def __init__(self) -> None:
def __init__(self, client: Optional[Client] = None) -> None:
"""Create an ActivityEnvironment for running activity code."""
self.info = _default_info
self.on_heartbeat: Callable[..., None] = lambda *args: None
Expand All @@ -74,6 +75,7 @@ def __init__(self) -> None:
self._cancelled = False
self._worker_shutdown = False
self._activities: Set[_Activity] = set()
self._client = client
self._cancellation_details = (
temporalio.activity._ActivityCancellationDetailsHolder()
)
Expand Down Expand Up @@ -128,18 +130,21 @@ def run(
The callable's result.
"""
# Create an activity and run it
return _Activity(self, fn).run(*args, **kwargs)
return _Activity(self, fn, self._client).run(*args, **kwargs)


class _Activity:
def __init__(
self,
env: ActivityEnvironment,
fn: Callable,
client: Optional[Client],
) -> None:
self.env = env
self.fn = fn
self.is_async = inspect.iscoroutinefunction(fn)
self.is_async = inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(
fn.__call__ # type: ignore
)
self.cancel_thread_raiser: Optional[
temporalio.worker._activity._ThreadExceptionRaiser
] = None
Expand All @@ -163,11 +168,14 @@ def __init__(
thread_event=threading.Event(),
async_event=asyncio.Event() if self.is_async else None,
),
shield_thread_cancel_exception=None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=env.payload_converter,
runtime_metric_meter=env.metric_meter,
client=client if self.is_async else None,
cancellation_details=env._cancellation_details,
)
self.task: Optional[asyncio.Task] = None
Expand Down
22 changes: 14 additions & 8 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
data_converter: temporalio.converter.DataConverter,
interceptors: Sequence[Interceptor],
metric_meter: temporalio.common.MetricMeter,
client: temporalio.client.Client,
encode_headers: bool,
) -> None:
self._bridge_worker = bridge_worker
Expand All @@ -86,6 +87,7 @@ def __init__(
None
)
self._seen_sync_activity = False
self._client = client

# Validate and build activity dict
self._activities: Dict[str, temporalio.activity._Definition] = {}
Expand Down Expand Up @@ -569,11 +571,14 @@ async def _execute_activity(
heartbeat=None,
cancelled_event=running_activity.cancelled_event,
worker_shutdown_event=self._worker_shutdown_event,
shield_thread_cancel_exception=None
if not running_activity.cancel_thread_raiser
else running_activity.cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None
if not running_activity.cancel_thread_raiser
else running_activity.cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=self._data_converter.payload_converter,
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
client=self._client if not running_activity.sync else None,
cancellation_details=running_activity.cancellation_details,
)
)
Expand Down Expand Up @@ -679,7 +684,7 @@ def _raise_in_thread_if_pending_unlocked(self) -> None:


class _ActivityInboundImpl(ActivityInboundInterceptor):
def __init__(
def __init__( # type: ignore[reportMissingSuperCall]
self, worker: _ActivityWorker, running_activity: _RunningActivity
) -> None:
# We are intentionally not calling the base class's __init__ here
Expand Down Expand Up @@ -786,7 +791,7 @@ async def heartbeat_with_context(*details: Any) -> None:


class _ActivityOutboundImpl(ActivityOutboundInterceptor):
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None:
def __init__(self, worker: _ActivityWorker, info: temporalio.activity.Info) -> None: # type: ignore[reportMissingSuperCall]
# We are intentionally not calling the base class's __init__ here
self._worker = worker
self._info = info
Expand Down Expand Up @@ -838,11 +843,12 @@ def _execute_sync_activity(
worker_shutdown_event=temporalio.activity._CompositeEvent(
thread_event=worker_shutdown_event, async_event=None
),
shield_thread_cancel_exception=None
if not cancel_thread_raiser
else cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None if not cancel_thread_raiser else cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=payload_converter_class_or_instance,
runtime_metric_meter=runtime_metric_meter,
client=None,
cancellation_details=cancellation_details,
)
)
Expand Down
16 changes: 9 additions & 7 deletions temporalio/worker/_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,10 @@ def __init__(
data_converter=client_config["data_converter"],
interceptors=interceptors,
metric_meter=self._runtime.metric_meter,
encode_headers=client_config["header_codec_behavior"]
== HeaderCodecBehavior.CODEC,
client=client,
encode_headers=(
client_config["header_codec_behavior"] == HeaderCodecBehavior.CODEC
),
)
self._nexus_worker: Optional[_NexusWorker] = None
if nexus_service_handlers:
Expand Down Expand Up @@ -577,12 +579,12 @@ def config(self) -> WorkerConfig:
@property
def task_queue(self) -> str:
"""Task queue this worker is on."""
return self._config["task_queue"]
return self._config["task_queue"] # type: ignore[reportTypedDictNotRequiredAccess]

@property
def client(self) -> temporalio.client.Client:
"""Client currently set on the worker."""
return self._config["client"]
return self._config["client"] # type: ignore[reportTypedDictNotRequiredAccess]

@client.setter
def client(self, value: temporalio.client.Client) -> None:
Expand Down Expand Up @@ -679,9 +681,9 @@ async def raise_on_shutdown():
)
if exception:
logger.error("Worker failed, shutting down", exc_info=exception)
if self._config["on_fatal_error"]:
if self._config["on_fatal_error"]: # type: ignore[reportTypedDictNotRequiredAccess]
try:
await self._config["on_fatal_error"](exception)
await self._config["on_fatal_error"](exception) # type: ignore[reportTypedDictNotRequiredAccess]
except:
logger.warning("Fatal error handler failed")

Expand All @@ -692,7 +694,7 @@ async def raise_on_shutdown():

# Cancel the shutdown task (safe if already done)
tasks[None].cancel()
graceful_timeout = self._config["graceful_shutdown_timeout"]
graceful_timeout = self._config["graceful_shutdown_timeout"] # type: ignore[reportTypedDictNotRequiredAccess]
logger.info(
f"Beginning worker shutdown, will wait {graceful_timeout} before cancelling activities"
)
Expand Down
45 changes: 45 additions & 0 deletions tests/testing/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import threading
import time
from contextvars import copy_context
from unittest.mock import Mock

import pytest

from temporalio import activity
from temporalio.client import Client
from temporalio.exceptions import CancelledError
from temporalio.testing import ActivityEnvironment

Expand Down Expand Up @@ -122,3 +126,44 @@ async def assert_equals(a: str, b: str) -> None:

assert type(expected_err) == type(actual_err)
assert str(expected_err) == str(actual_err)


async def test_error_on_access_client_in_activity_environment_without_client():
saw_error: bool = False

async def my_activity() -> None:
with pytest.raises(RuntimeError, match="No client available"):
activity.client()
nonlocal saw_error
saw_error = True

env = ActivityEnvironment()
await env.run(my_activity)
assert saw_error


async def test_access_client_in_activity_environment_with_client():
got_client: bool = False

async def my_activity() -> None:
nonlocal got_client
if activity.client():
got_client = True

env = ActivityEnvironment(client=Mock(spec=Client))
await env.run(my_activity)
assert got_client


async def test_error_on_access_client_in_sync_activity_in_environment_with_client():
saw_error: bool = False

def my_activity() -> None:
with pytest.raises(RuntimeError, match="No client available"):
activity.client()
nonlocal saw_error
saw_error = True

env = ActivityEnvironment(client=Mock(spec=Client))
env.run(my_activity)
assert saw_error
45 changes: 44 additions & 1 deletion tests/worker/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,49 @@ async def get_name(name: str) -> str:
assert result.result == "Name: my custom activity name!"


async def test_client_available_in_async_activities(
client: Client, worker: ExternalWorker
):
with pytest.raises(RuntimeError, match="Not in activity context"):
activity.client()

captured_client: Optional[Client] = None

@activity.defn
async def capture_client() -> None:
nonlocal captured_client
captured_client = activity.client()

await _execute_workflow_with_activity(client, worker, capture_client)
assert captured_client is client


async def test_client_not_available_in_sync_activities(
client: Client, worker: ExternalWorker
):
saw_error = False

@activity.defn
def some_activity() -> None:
with pytest.raises(
RuntimeError, match="The client is only available in `async def`"
):
activity.client()
nonlocal saw_error
saw_error = True

await _execute_workflow_with_activity(
client,
worker,
some_activity,
worker_config={
"activity_executor": concurrent.futures.ThreadPoolExecutor(1),
"max_concurrent_activities": 1,
},
)
assert saw_error


async def test_activity_info(
client: Client, worker: ExternalWorker, env: WorkflowEnvironment
):
Expand Down Expand Up @@ -612,7 +655,7 @@ async def some_activity(param1: SomeClass2, param2: str) -> str:
result.result
== "param1: <class 'tests.worker.test_activity.SomeClass2'>, param2: <class 'str'>"
)
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123))
assert activity_param1 == SomeClass2(foo="str1", bar=SomeClass1(foo=123)) # type: ignore[reportUnboundVariable] # noqa


async def test_activity_heartbeat_details(
Expand Down
Loading