Skip to content

Deprecate async def queries and disallow most workflow operations in read-only context #351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class GreetingWorkflow:
self._complete.set()

@workflow.query
async def current_greeting(self) -> str:
def current_greeting(self) -> str:
return self._current_greeting

```
Expand Down Expand Up @@ -566,7 +566,8 @@ Here are the decorators that can be applied:
* Return value is ignored
* `@workflow.query` - Defines a method as a query
* All the same constraints as `@workflow.signal` but should return a value
* Temporal queries should never mutate anything in the workflow
* Should not be `async`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just deprecate and send a warning

* Temporal queries should never mutate anything in the workflow or call any calls that would mutate the workflow

#### Running

Expand Down
97 changes: 69 additions & 28 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import traceback
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import (
Expand All @@ -21,6 +22,7 @@
Deque,
Dict,
Generator,
Iterator,
List,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
self._object: Any = None
self._is_replaying: bool = False
self._random = random.Random(det.randomness_seed)
self._read_only = False

# Patches we have been notified of and memoized patch responses
self._patches_notified: Set[str] = set()
Expand Down Expand Up @@ -421,36 +424,39 @@ async def run_query() -> None:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
try:
# Named query or dynamic
defn = self._queries.get(job.query_type) or self._queries.get(None)
if not defn:
known_queries = sorted([k for k in self._queries.keys() if k])
raise RuntimeError(
f"Query handler for '{job.query_type}' expected but not found, "
f"known queries: [{' '.join(known_queries)}]"
with self._as_read_only():
# Named query or dynamic
defn = self._queries.get(job.query_type) or self._queries.get(None)
if not defn:
known_queries = sorted([k for k in self._queries.keys() if k])
raise RuntimeError(
f"Query handler for '{job.query_type}' expected but not found, "
f"known queries: [{' '.join(known_queries)}]"
)

# Create input
args = self._process_handler_args(
job.query_type,
job.arguments,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)

# Create input
args = self._process_handler_args(
job.query_type,
job.arguments,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)
input = HandleQueryInput(
id=job.query_id,
query=job.query_type,
args=args,
headers=job.headers,
)
success = await self._inbound.handle_query(input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
input = HandleQueryInput(
id=job.query_id,
query=job.query_type,
args=args,
headers=job.headers,
)
success = await self._inbound.handle_query(input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
command.respond_to_query.succeeded.response.CopyFrom(
result_payloads[0]
)
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
except Exception as err:
try:
self._failure_converter.to_failure(
Expand Down Expand Up @@ -695,6 +701,7 @@ def workflow_continue_as_new(
search_attributes: Optional[temporalio.common.SearchAttributes],
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
) -> NoReturn:
self._assert_not_read_only("continue as new")
# Use definition if callable
name: Optional[str] = None
arg_types: Optional[List[Type]] = None
Expand Down Expand Up @@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter:
return self._payload_converter

def workflow_random(self) -> random.Random:
self._assert_not_read_only("random")
return self._random

def workflow_set_query_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
self._assert_not_read_only("set query handler")
if handler:
if inspect.iscoroutinefunction(handler):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=3,
)
defn = temporalio.workflow._QueryDefinition(
name=name, fn=handler, is_method=False
)
Expand All @@ -817,6 +832,7 @@ def workflow_set_query_handler(
def workflow_set_signal_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
self._assert_not_read_only("set signal handler")
if handler:
defn = temporalio.workflow._SignalDefinition(
name=name, fn=handler, is_method=False
Expand Down Expand Up @@ -855,6 +871,7 @@ def workflow_start_activity(
activity_id: Optional[str],
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
) -> temporalio.workflow.ActivityHandle[Any]:
self._assert_not_read_only("start activity")
# Get activity definition if it's callable
name: str
arg_types: Optional[List[Type]] = None
Expand Down Expand Up @@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes(
async def workflow_wait_condition(
self, fn: Callable[[], bool], *, timeout: Optional[float] = None
) -> None:
self._assert_not_read_only("wait condition")
fut = self.create_future()
self._conditions.append((fn, fut))
await asyncio.wait_for(fut, timeout)
Expand Down Expand Up @@ -1153,8 +1171,24 @@ async def run_child() -> Any:
# These are in alphabetical order.

def _add_command(self) -> temporalio.bridge.proto.workflow_commands.WorkflowCommand:
self._assert_not_read_only("add command")
return self._current_completion.successful.commands.add()

@contextmanager
def _as_read_only(self) -> Iterator[None]:
prev_val = self._read_only
self._read_only = True
try:
yield None
finally:
self._read_only = prev_val

def _assert_not_read_only(self, action_attempted: str) -> None:
if self._read_only:
raise temporalio.workflow.ReadOnlyContextError(
f"While in read-only function, action attempted: {action_attempted}"
)

async def _cancel_external_workflow(
self,
# Should not have seq set
Expand Down Expand Up @@ -1258,6 +1292,7 @@ def _register_task(
*,
name: Optional[str],
) -> None:
self._assert_not_read_only("create task")
# Name not supported on older Python versions
if sys.version_info >= (3, 8):
# Put the workflow info at the end of the task name
Expand Down Expand Up @@ -1423,6 +1458,7 @@ def call_soon(
*args: Any,
context: Optional[contextvars.Context] = None,
) -> asyncio.Handle:
self._assert_not_read_only("schedule task")
handle = asyncio.Handle(callback, args, self, context)
self._ready.append(handle)
return handle
Expand All @@ -1434,6 +1470,7 @@ def call_later(
*args: Any,
context: Optional[contextvars.Context] = None,
) -> asyncio.TimerHandle:
self._assert_not_read_only("schedule timer")
# Delay must be positive
if delay < 0:
raise RuntimeError("Attempting to schedule timer with negative delay")
Expand Down Expand Up @@ -1675,6 +1712,7 @@ def __init__(
instance._register_task(self, name=f"activity: {input.activity}")

def cancel(self, msg: Optional[Any] = None) -> bool:
self._instance._assert_not_read_only("cancel activity handle")
# We override this because if it's not yet started and not done, we need
# to send a cancel command because the async function won't run to trap
# the cancel (i.e. cancelled before started)
Expand Down Expand Up @@ -1821,6 +1859,7 @@ async def signal(
*,
args: Sequence[Any] = [],
) -> None:
self._instance._assert_not_read_only("signal child handle")
await self._instance._outbound.signal_child_workflow(
SignalChildWorkflowInput(
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
Expand Down Expand Up @@ -1935,6 +1974,7 @@ async def signal(
*,
args: Sequence[Any] = [],
) -> None:
self._instance._assert_not_read_only("signal external handle")
await self._instance._outbound.signal_external_workflow(
SignalExternalWorkflowInput(
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
Expand All @@ -1949,6 +1989,7 @@ async def signal(
)

async def cancel(self) -> None:
self._instance._assert_not_read_only("cancel external handle")
command = self._instance._add_command()
v = command.request_cancel_external_workflow_execution
v.workflow_execution.namespace = self._instance._info.namespace
Expand Down
34 changes: 29 additions & 5 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ def query(
):
"""Decorator for a workflow query method.

This is set on any async or non-async method that expects to handle a
query. If a function overrides one with this decorator, it too must be
decorated.
This is set on any non-async method that expects to handle a query. If a
function overrides one with this decorator, it too must be decorated.

Query methods can only have positional parameters. Best practice for
non-dynamic query methods is to only take a single object/dataclass
Expand All @@ -262,7 +261,15 @@ def query(
present.
"""

def with_name(name: Optional[str], fn: CallableType) -> CallableType:
def with_name(
name: Optional[str], fn: CallableType, *, bypass_async_check: bool = False
) -> CallableType:
if not bypass_async_check and inspect.iscoroutinefunction(fn):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=2,
)
defn = _QueryDefinition(name=name, fn=fn, is_method=True)
setattr(fn, "__temporal_query_definition", defn)
if defn.dynamic_vararg:
Expand All @@ -279,7 +286,13 @@ def with_name(name: Optional[str], fn: CallableType) -> CallableType:
return partial(with_name, name)
if fn is None:
raise RuntimeError("Cannot create query without function or name or dynamic")
return with_name(fn.__name__, fn)
if inspect.iscoroutinefunction(fn):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=2,
)
return with_name(fn.__name__, fn, bypass_async_check=True)


@dataclass(frozen=True)
Expand Down Expand Up @@ -3919,6 +3932,17 @@ def __init__(self, message: str) -> None:
self.message = message


class ReadOnlyContextError(temporalio.exceptions.TemporalError):
"""Error thrown when trying to do mutable workflow calls in a read-only
context like a query or update validator.
"""

def __init__(self, message: str) -> None:
"""Initialize a read-only context error."""
super().__init__(message)
self.message = message


class _NotInWorkflowEventLoopError(temporalio.exceptions.TemporalError):
def __init__(self, *args: object) -> None:
super().__init__("Not in workflow event loop")
Expand Down
2 changes: 1 addition & 1 deletion tests/testing/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def run(self) -> str:
return "all done"

@workflow.query
async def current_time(self) -> float:
def current_time(self) -> float:
return workflow.now().timestamp()

@workflow.signal
Expand Down
59 changes: 58 additions & 1 deletion tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,7 @@ async def signal(self) -> None:
self._signal_count += 1

@workflow.query
async def signal_count(self) -> int:
def signal_count(self) -> int:
return self._signal_count


Expand Down Expand Up @@ -3097,6 +3097,63 @@ async def test_workflow_dynamic(client: Client):
assert result == DynamicWorkflowValue("some-workflow - val1 - val2")


@workflow.defn
class QueriesDoingBadThingsWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.wait_condition(lambda: False)

@workflow.query
async def bad_query(self, bad_thing: str) -> str:
if bad_thing == "wait_condition":
await workflow.wait_condition(lambda: True)
elif bad_thing == "continue_as_new":
workflow.continue_as_new()
elif bad_thing == "upsert_search_attribute":
workflow.upsert_search_attributes({"foo": ["bar"]})
elif bad_thing == "start_activity":
workflow.start_activity(
"some-activity", start_to_close_timeout=timedelta(minutes=10)
)
elif bad_thing == "start_child_workflow":
await workflow.start_child_workflow("some-workflow")
elif bad_thing == "random":
workflow.random().random()
elif bad_thing == "set_query_handler":
workflow.set_query_handler("some-handler", lambda: "whatever")
elif bad_thing == "patch":
workflow.patched("some-patch")
elif bad_thing == "signal_external_handle":
await workflow.get_external_workflow_handle("some-id").signal("some-signal")
return "should never get here"


async def test_workflow_queries_doing_bad_things(client: Client):
async with new_worker(client, QueriesDoingBadThingsWorkflow) as worker:
handle = await client.start_workflow(
QueriesDoingBadThingsWorkflow.run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
)

async def assert_bad_query(bad_thing: str) -> None:
with pytest.raises(WorkflowQueryFailedError) as err:
_ = await handle.query(
QueriesDoingBadThingsWorkflow.bad_query, bad_thing
)
assert "While in read-only function, action attempted" in str(err)

await assert_bad_query("wait_condition")
await assert_bad_query("continue_as_new")
await assert_bad_query("upsert_search_attribute")
await assert_bad_query("start_activity")
await assert_bad_query("start_child_workflow")
await assert_bad_query("random")
await assert_bad_query("set_query_handler")
await assert_bad_query("patch")
await assert_bad_query("signal_external_handle")


# typing.Self only in 3.11+
if sys.version_info >= (3, 11):

Expand Down