diff --git a/poetry.lock b/poetry.lock index f4c62c35b..ac912cf8d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -373,11 +373,11 @@ python-versions = "*" [[package]] name = "isort" -version = "5.10.1" +version = "5.11.4" description = "A Python utility / library to sort Python imports." category = "dev" optional = false -python-versions = ">=3.6.1,<4.0" +python-versions = ">=3.7.0" [package.extras] colors = ["colorama (>=0.4.3,<0.5.0)"] @@ -1093,7 +1093,7 @@ opentelemetry = ["opentelemetry-api", "opentelemetry-sdk"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "46b2898a5a38a4f36b712d15d27bfdcadb4f87bae5127962807ba4f5aa5934bc" +content-hash = "5e9568fae21b43278f1acf8349199ecb7408bcb77b55b47b54c1c52e80f00d17" [metadata.files] appdirs = [ @@ -1410,8 +1410,8 @@ iniconfig = [ {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] isort = [ - {file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"}, - {file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"}, + {file = "isort-5.11.4-py3-none-any.whl", hash = "sha256:c033fd0edb91000a7f09527fe5c75321878f98322a77ddcc81adbd83724afb7b"}, + {file = "isort-5.11.4.tar.gz", hash = "sha256:6db30c5ded9815d813932c04c2f85a360bcdd35fed496f4d8f35495ef0a261b6"}, ] "jaraco.classes" = [ {file = "jaraco.classes-3.2.2-py3-none-any.whl", hash = "sha256:e6ef6fd3fcf4579a7a019d87d1e56a883f4e4c35cfe925f86731abc58804e647"}, diff --git a/pyproject.toml b/pyproject.toml index 9d58e7d23..a6da093ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ typing-extensions = "^4.2.0" black = "^22.3.0" cibuildwheel = "^2.11.0" grpcio-tools = "^1.48.0" -isort = "^5.10.1" +isort = "^5.11.3" mypy = "^0.971" mypy-protobuf = "^3.3.0" protoc-wheel-0 = "^21.1" diff --git a/temporalio/converter.py b/temporalio/converter.py index cac28ada0..86c7fe11a 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -573,6 +573,13 @@ async def _apply_to_failure_payloads( failure: temporalio.api.failure.v1.Failure, cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], ) -> None: + if failure.HasField("encoded_attributes"): + # Wrap in payloads and merge back + payloads = temporalio.api.common.v1.Payloads( + payloads=[failure.encoded_attributes] + ) + await cb(payloads) + failure.encoded_attributes.CopyFrom(payloads.payloads[0]) if failure.HasField( "application_failure_info" ) and failure.application_failure_info.HasField("details"): diff --git a/tests/test_converter.py b/tests/test_converter.py index 5ed5e99de..6cb359d24 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -33,9 +33,19 @@ import temporalio.api.common.v1 import temporalio.common -import temporalio.converter +from temporalio.api.common.v1 import Payload from temporalio.api.common.v1 import Payload as AnotherNameForPayload +from temporalio.api.common.v1 import Payloads from temporalio.api.failure.v1 import Failure +from temporalio.converter import ( + BinaryProtoPayloadConverter, + DataConverter, + DefaultFailureConverterWithEncodedAttributes, + JSONPlainPayloadConverter, + PayloadCodec, + decode_search_attributes, + encode_search_attribute_values, +) from temporalio.exceptions import ApplicationError, FailureError # StrEnum is available in 3.11+ @@ -77,7 +87,7 @@ async def assert_payload( expected_decoded_input=None, type_hint=None, ): - payloads = await temporalio.converter.DataConverter().encode([input]) + payloads = await DataConverter().encode([input]) # Check encoding and data assert len(payloads) == 1 if isinstance(expected_encoding, str): @@ -87,9 +97,7 @@ async def assert_payload( expected_data = expected_data.encode() assert payloads[0].data == expected_data # Decode and check - actual_inputs = await temporalio.converter.DataConverter().decode( - payloads, [type_hint] - ) + actual_inputs = await DataConverter().decode(payloads, [type_hint]) assert len(actual_inputs) == 1 if expected_decoded_input is None: expected_decoded_input = input @@ -158,7 +166,7 @@ async def assert_payload( def test_binary_proto(): # We have to test this separately because by default it never encodes # anything since JSON proto takes precedence - conv = temporalio.converter.BinaryProtoPayloadConverter() + conv = BinaryProtoPayloadConverter() proto = temporalio.api.common.v1.WorkflowExecution(workflow_id="id1", run_id="id2") payload = conv.to_payload(proto) assert payload.metadata["encoding"] == b"binary/protobuf" @@ -172,11 +180,11 @@ def test_binary_proto(): def test_encode_search_attribute_values(): with pytest.raises(TypeError, match="of type tuple not one of"): - temporalio.converter.encode_search_attribute_values([("bad type",)]) + encode_search_attribute_values([("bad type",)]) with pytest.raises(ValueError, match="Timezone must be present"): - temporalio.converter.encode_search_attribute_values([datetime.utcnow()]) + encode_search_attribute_values([datetime.utcnow()]) with pytest.raises(TypeError, match="must have the same type"): - temporalio.converter.encode_search_attribute_values(["foo", 123]) + encode_search_attribute_values(["foo", 123]) def test_decode_search_attributes(): @@ -192,25 +200,23 @@ def payload(key, dtype, data, encoding=None): return temporalio.api.common.v1.SearchAttributes(indexed_fields={key: check}) # Check basic keyword parsing works - kw_check = temporalio.converter.decode_search_attributes( - payload("kw", "Keyword", '"test-id"') - ) + kw_check = decode_search_attributes(payload("kw", "Keyword", '"test-id"')) assert kw_check["kw"][0] == "test-id" # Ensure original DT functionality works - dt_check = temporalio.converter.decode_search_attributes( + dt_check = decode_search_attributes( payload("dt", "Datetime", '"2020-01-01T00:00:00"') ) assert dt_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0) # Check timezone aware works as server is using ISO 8601 - dttz_check = temporalio.converter.decode_search_attributes( + dttz_check = decode_search_attributes( payload("dt", "Datetime", '"2020-01-01T00:00:00Z"') ) assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc) # Check timezone aware, hour offset - dttz_check = temporalio.converter.decode_search_attributes( + dttz_check = decode_search_attributes( payload("dt", "Datetime", '"2020-01-01T00:00:00+00:00"') ) assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc) @@ -245,7 +251,7 @@ class MyPydanticClass(pydantic.BaseModel): def test_json_type_hints(): - converter = temporalio.converter.JSONPlainPayloadConverter() + converter = JSONPlainPayloadConverter() def ok( hint: Type, value: Any, expected_result: Any = temporalio.common._arg_unset @@ -415,10 +421,8 @@ async def test_exception_format(): # Convert to failure and back failure = Failure() - await temporalio.converter.DataConverter.default.encode_failure(actual_err, failure) - failure_error = await temporalio.converter.DataConverter.default.decode_failure( - failure - ) + await DataConverter.default.encode_failure(actual_err, failure) + failure_error = await DataConverter.default.decode_failure(failure) # Confirm type is prepended assert isinstance(failure_error, ApplicationError) assert "RuntimeError: error2" == str(failure_error) @@ -440,3 +444,72 @@ async def test_exception_format(): logging.getLogger(__name__).debug( "Showing appended exception", exc_info=failure_error ) + + +# Just serializes in a "payloads" wrapper +class SimpleCodec(PayloadCodec): + async def encode(self, payloads: Sequence[Payload]) -> List[Payload]: + wrapper = Payloads(payloads=payloads) + return [ + Payload( + metadata={"simple-codec": b"true"}, data=wrapper.SerializeToString() + ) + ] + + async def decode(self, payloads: Sequence[Payload]) -> List[Payload]: + payloads = list(payloads) + if len(payloads) != 1: + raise RuntimeError("Expected only a single payload") + elif payloads[0].metadata.get("simple-codec") != b"true": + raise RuntimeError("Not encoded with this codec") + wrapper = Payloads() + wrapper.ParseFromString(payloads[0].data) + return list(wrapper.payloads) + + +async def test_failure_encoded_attributes(): + try: + raise ApplicationError("some message", "some detail") + except ApplicationError as err: + some_err = err + + conv = DataConverter( + failure_converter_class=DefaultFailureConverterWithEncodedAttributes, + payload_codec=SimpleCodec(), + ) + + # Check failure + failure = Failure() + conv.failure_converter.to_failure(some_err, conv.payload_converter, failure) + assert failure.message == "Encoded failure" + assert failure.stack_trace == "" + assert conv.payload_converter.from_payloads( + failure.application_failure_info.details.payloads + ) == ["some detail"] + encoded_attr = conv.payload_converter.from_payloads([failure.encoded_attributes])[0] + assert encoded_attr["message"] == "some message" + assert "test_converter" in encoded_attr["stack_trace"] + + # Encode it and check encoded + orig_failure = Failure() + orig_failure.CopyFrom(failure) + await conv.payload_codec.encode_failure(failure) + assert "encoding" not in failure.encoded_attributes.metadata + assert "simple-codec" in failure.encoded_attributes.metadata + assert ( + "encoding" not in failure.application_failure_info.details.payloads[0].metadata + ) + assert ( + "simple-codec" in failure.application_failure_info.details.payloads[0].metadata + ) + + # Decode and check + await conv.payload_codec.decode_failure(failure) + assert "encoding" in failure.encoded_attributes.metadata + assert "simple-codec" not in failure.encoded_attributes.metadata + assert "encoding" in failure.application_failure_info.details.payloads[0].metadata + assert ( + "simple-codec" + not in failure.application_failure_info.details.payloads[0].metadata + ) + assert failure == orig_failure