diff --git a/azure/durable_functions/models/DurableEntityContext.py b/azure/durable_functions/models/DurableEntityContext.py index c4c13879..ef91b13f 100644 --- a/azure/durable_functions/models/DurableEntityContext.py +++ b/azure/durable_functions/models/DurableEntityContext.py @@ -124,7 +124,9 @@ def set_state(self, state: Any) -> None: """ # TODO: enable serialization of custom types self._exists = True - self._state = json.dumps(state) + + # should only serialize the state at the end of a batch + self._state = state def get_state(self, initializer: Optional[Callable[[], Any]] = None) -> Any: """Get the current state of this entity. diff --git a/azure/durable_functions/models/DurableOrchestrationContext.py b/azure/durable_functions/models/DurableOrchestrationContext.py index f2d96186..26dac506 100644 --- a/azure/durable_functions/models/DurableOrchestrationContext.py +++ b/azure/durable_functions/models/DurableOrchestrationContext.py @@ -12,7 +12,8 @@ from .utils.entity_utils import EntityId from ..tasks import call_activity_task, task_all, task_any, call_activity_with_retry_task, \ wait_for_external_event_task, continue_as_new, new_uuid, call_http, create_timer_task, \ - call_sub_orchestrator_task, call_sub_orchestrator_with_retry_task, call_entity_task + call_sub_orchestrator_task, call_sub_orchestrator_with_retry_task, call_entity_task, \ + signal_entity_task from azure.functions._durable_functions import _deserialize_custom_object @@ -380,6 +381,26 @@ def call_entity(self, entityId: EntityId, """ return call_entity_task(self.histories, entityId, operationName, operationInput) + def signal_entity(self, entityId: EntityId, + operationName: str, operationInput: Optional[Any] = None): + """Send a signal operation to Durable Entity given some input. + + Parameters + ---------- + entityId: EntityId + The ID of the entity to call + operationName: str + The operation to execute + operationInput: Optional[Any] + The input for tne operation, defaults to None. + + Returns + ------- + Task + A Task of the entity signal + """ + return signal_entity_task(self, self.histories, entityId, operationName, operationInput) + @property def will_continue_as_new(self) -> bool: """Return true if continue_as_new was called.""" diff --git a/azure/durable_functions/models/actions/ActionType.py b/azure/durable_functions/models/actions/ActionType.py index 66ffbb63..406c6f86 100644 --- a/azure/durable_functions/models/actions/ActionType.py +++ b/azure/durable_functions/models/actions/ActionType.py @@ -13,3 +13,4 @@ class ActionType(IntEnum): WAIT_FOR_EXTERNAL_EVENT: int = 6 CALL_ENTITY = 7 CALL_HTTP: int = 8 + SIGNAL_ENTITY: int = 9 diff --git a/azure/durable_functions/models/actions/SignalEntityAction.py b/azure/durable_functions/models/actions/SignalEntityAction.py new file mode 100644 index 00000000..d6e9be54 --- /dev/null +++ b/azure/durable_functions/models/actions/SignalEntityAction.py @@ -0,0 +1,47 @@ +from typing import Any, Dict + +from .Action import Action +from .ActionType import ActionType +from ..utils.json_utils import add_attrib +from json import dumps +from azure.functions._durable_functions import _serialize_custom_object +from ..utils.entity_utils import EntityId + + +class SignalEntityAction(Action): + """Defines the structure of the Signal Entity object. + + Provides the information needed by the durable extension to be able to signal an entity + """ + + def __init__(self, entity_id: EntityId, operation: str, input_=None): + self.entity_id: EntityId = entity_id + + # Validating that EntityId exists before trying to parse its instanceId + if not self.entity_id: + raise ValueError("entity_id cannot be empty") + + self.instance_id: str = EntityId.get_scheduler_id(entity_id) + self.operation: str = operation + self.input_: str = dumps(input_, default=_serialize_custom_object) + + @property + def action_type(self) -> int: + """Get the type of action this class represents.""" + return ActionType.SIGNAL_ENTITY + + def to_json(self) -> Dict[str, Any]: + """Convert object into a json dictionary. + + Returns + ------- + Dict[str, Any] + The instance of the class converted into a json dictionary + """ + json_dict: Dict[str, Any] = {} + add_attrib(json_dict, self, "action_type", "actionType") + add_attrib(json_dict, self, 'instance_id', 'instanceId') + add_attrib(json_dict, self, 'operation', 'operation') + add_attrib(json_dict, self, 'input_', 'input') + + return json_dict diff --git a/azure/durable_functions/models/entities/EntityState.py b/azure/durable_functions/models/entities/EntityState.py index 59fea5ad..e0de6b6e 100644 --- a/azure/durable_functions/models/entities/EntityState.py +++ b/azure/durable_functions/models/entities/EntityState.py @@ -55,7 +55,7 @@ def to_json(self) -> Dict[str, Any]: serialized_results = list(map(lambda x: x.to_json(), self.results)) json_dict["entityExists"] = self.entity_exists - json_dict["entityState"] = self.state + json_dict["entityState"] = json.dumps(self.state) json_dict["results"] = serialized_results json_dict["signals"] = self.signals return json_dict diff --git a/azure/durable_functions/tasks/__init__.py b/azure/durable_functions/tasks/__init__.py index 82eb8be0..9c7f6e9b 100644 --- a/azure/durable_functions/tasks/__init__.py +++ b/azure/durable_functions/tasks/__init__.py @@ -12,6 +12,7 @@ from .call_http import call_http from .create_timer import create_timer_task from .call_entity import call_entity_task +from .signal_entity import signal_entity_task __all__ = [ 'call_activity_task', @@ -19,6 +20,7 @@ 'call_sub_orchestrator_task', 'call_sub_orchestrator_with_retry_task', 'call_entity_task', + 'signal_entity_task', 'call_http', 'continue_as_new', 'new_uuid', diff --git a/azure/durable_functions/tasks/call_entity.py b/azure/durable_functions/tasks/call_entity.py index 9d22989f..38ecb7cb 100644 --- a/azure/durable_functions/tasks/call_entity.py +++ b/azure/durable_functions/tasks/call_entity.py @@ -8,6 +8,7 @@ from ..models.utils.entity_utils import EntityId from ..models.entities.RequestMessage import RequestMessage from ..models.entities.ResponseMessage import ResponseMessage +import json def call_entity_task( @@ -64,7 +65,12 @@ def call_entity_task( if event_raised is not None: response = parse_history_event(event_raised) response = ResponseMessage.from_dict(response) - result = response.result + + # TODO: json.loads inside parse_history_event is not recursvie + # investigate if response.result is used elsewhere, + # which probably requires another deserialization + result = json.loads(response.result) + return Task( is_completed=True, is_faulted=False, diff --git a/azure/durable_functions/tasks/signal_entity.py b/azure/durable_functions/tasks/signal_entity.py new file mode 100644 index 00000000..2f17d0ce --- /dev/null +++ b/azure/durable_functions/tasks/signal_entity.py @@ -0,0 +1,46 @@ +from typing import List, Any, Optional +from ..models.actions.SignalEntityAction import SignalEntityAction +from ..models.history import HistoryEvent, HistoryEventType +from .task_utilities import set_processed, find_event +from ..models.utils.entity_utils import EntityId + + +def signal_entity_task( + context, + state: List[HistoryEvent], + entity_id: EntityId, + operation_name: str = "", + input_: Optional[Any] = None): + """Signal a entity operation. + + It the action hasn't been scheduled, it appends the action. + If the action has been scheduled, no ops. + + Parameters + ---------- + state: List[HistoryEvent] + The list of history events to search over to determine the + current state of the callEntity Task. + entity_id: EntityId + An identifier for the entity to call. + operation_name: str + The name of the operation the entity needs to execute. + input_: Any + The JSON-serializable input to pass to the activity function. + """ + new_action = SignalEntityAction(entity_id, operation_name, input_) + scheduler_id = EntityId.get_scheduler_id(entity_id=entity_id) + + hist_type = HistoryEventType.EVENT_SENT + extra_constraints = { + "InstanceId": scheduler_id, + "Name": "op" + } + + event_sent = find_event(state, hist_type, extra_constraints) + set_processed([event_sent]) + + if event_sent: + return + + context.actions.append([new_action]) diff --git a/tests/models/test_DurableOrchestrationClient.py b/tests/models/test_DurableOrchestrationClient.py index 1b97629a..188c52ea 100644 --- a/tests/models/test_DurableOrchestrationClient.py +++ b/tests/models/test_DurableOrchestrationClient.py @@ -492,7 +492,7 @@ async def test_wait_or_response_check_status_response(binding_string): @pytest.mark.asyncio -async def test_wait_or_response_check_status_response(binding_string): +async def test_wait_or_response_null_request(binding_string): status = dict(createdTime=TEST_CREATED_TIME, lastUpdatedTime=TEST_LAST_UPDATED_TIME, runtimeStatus="Running") diff --git a/tests/orchestrator/test_entity.py b/tests/orchestrator/test_entity.py index 9878e1ea..ca10723b 100644 --- a/tests/orchestrator/test_entity.py +++ b/tests/orchestrator/test_entity.py @@ -4,11 +4,13 @@ from azure.durable_functions.models.OrchestratorState import OrchestratorState from azure.durable_functions.models.actions.CallEntityAction \ import CallEntityAction +from azure.durable_functions.models.actions.SignalEntityAction \ + import SignalEntityAction from tests.test_utils.testClasses import SerializableClass import azure.durable_functions as df from typing import Any -def generator_function(context): +def generator_function_call_entity(context): outputs = [] entityId = df.EntityId("Counter", "myCounter") x = yield context.call_entity(entityId, "add", 3) @@ -16,15 +18,25 @@ def generator_function(context): outputs.append(x) return outputs +def generator_function_signal_entity(context): + outputs = [] + entityId = df.EntityId("Counter", "myCounter") + context.signal_entity(entityId, "add", 3) + x = yield context.call_entity(entityId, "get") + + outputs.append(x) + return outputs def base_expected_state(output=None) -> OrchestratorState: return OrchestratorState(is_done=False, actions=[], output=output) - def add_call_entity_action(state: OrchestratorState, id_: df.EntityId, op: str, input_: Any): action = CallEntityAction(entity_id=id_, operation=op, input_=input_) state.actions.append([action]) +def add_signal_entity_action(state: OrchestratorState, id_: df.EntityId, op: str, input_: Any): + action = SignalEntityAction(entity_id=id_, operation=op, input_=input_) + state.actions.append([action]) def add_call_entity_completed_events( context_builder: ContextBuilder, op: str, instance_id=str, input_=None): @@ -38,7 +50,7 @@ def test_call_entity_sent(): entityId = df.EntityId("Counter", "myCounter") result = get_orchestration_state_result( - context_builder, generator_function) + context_builder, generator_function_call_entity) expected_state = base_expected_state() add_call_entity_action(expected_state, entityId, "add", 3) @@ -46,14 +58,30 @@ def test_call_entity_sent(): #assert_valid_schema(result) assert_orchestration_state_equals(expected, result) + +def test_signal_entity_sent(): + context_builder = ContextBuilder('test_simple_function') + + entityId = df.EntityId("Counter", "myCounter") + result = get_orchestration_state_result( + context_builder, generator_function_signal_entity) + + expected_state = base_expected_state() + add_signal_entity_action(expected_state, entityId, "add", 3) + add_call_entity_action(expected_state, entityId, "get", None) + expected = expected_state.to_json() + + #assert_valid_schema(result) + assert_orchestration_state_equals(expected, result) + def test_call_entity_raised(): entityId = df.EntityId("Counter", "myCounter") context_builder = ContextBuilder('test_simple_function') - add_call_entity_completed_events(context_builder, "add", df.EntityId.get_scheduler_id(entityId),3) + add_call_entity_completed_events(context_builder, "add", df.EntityId.get_scheduler_id(entityId), 3) result = get_orchestration_state_result( - context_builder, generator_function) + context_builder, generator_function_call_entity) expected_state = base_expected_state( [3] diff --git a/tests/test_utils/ContextBuilder.py b/tests/test_utils/ContextBuilder.py index 8d7f646f..69b884cd 100644 --- a/tests/test_utils/ContextBuilder.py +++ b/tests/test_utils/ContextBuilder.py @@ -120,7 +120,7 @@ def add_event_raised_event(self, name:str, id_: int, input_=None, timestamp=None event = self.get_base_event(HistoryEventType.EVENT_RAISED, id_=id_, timestamp=timestamp) event.Name = name if is_entity: - event.Input = json.dumps({ "result": input_ }) + event.Input = json.dumps({ "result": json.dumps(input_) }) else: event.Input = input_ # event.timestamp = timestamp