Skip to content

Properly encode failure encoded attributes #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typing-extensions = "^4.2.0"
black = "^22.3.0"
cibuildwheel = "^2.11.0"
grpcio-tools = "^1.48.0"
isort = "^5.10.1"
isort = "^5.11.3"
mypy = "^0.971"
mypy-protobuf = "^3.3.0"
protoc-wheel-0 = "^21.1"
Expand Down
7 changes: 7 additions & 0 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,13 @@ async def _apply_to_failure_payloads(
failure: temporalio.api.failure.v1.Failure,
cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]],
) -> None:
if failure.HasField("encoded_attributes"):
# Wrap in payloads and merge back
payloads = temporalio.api.common.v1.Payloads(
payloads=[failure.encoded_attributes]
)
await cb(payloads)
failure.encoded_attributes.CopyFrom(payloads.payloads[0])
if failure.HasField(
"application_failure_info"
) and failure.application_failure_info.HasField("details"):
Expand Down
113 changes: 93 additions & 20 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,19 @@

import temporalio.api.common.v1
import temporalio.common
import temporalio.converter
from temporalio.api.common.v1 import Payload
from temporalio.api.common.v1 import Payload as AnotherNameForPayload
from temporalio.api.common.v1 import Payloads
from temporalio.api.failure.v1 import Failure
from temporalio.converter import (
BinaryProtoPayloadConverter,
DataConverter,
DefaultFailureConverterWithEncodedAttributes,
JSONPlainPayloadConverter,
PayloadCodec,
decode_search_attributes,
encode_search_attribute_values,
)
from temporalio.exceptions import ApplicationError, FailureError

# StrEnum is available in 3.11+
Expand Down Expand Up @@ -77,7 +87,7 @@ async def assert_payload(
expected_decoded_input=None,
type_hint=None,
):
payloads = await temporalio.converter.DataConverter().encode([input])
payloads = await DataConverter().encode([input])
# Check encoding and data
assert len(payloads) == 1
if isinstance(expected_encoding, str):
Expand All @@ -87,9 +97,7 @@ async def assert_payload(
expected_data = expected_data.encode()
assert payloads[0].data == expected_data
# Decode and check
actual_inputs = await temporalio.converter.DataConverter().decode(
payloads, [type_hint]
)
actual_inputs = await DataConverter().decode(payloads, [type_hint])
assert len(actual_inputs) == 1
if expected_decoded_input is None:
expected_decoded_input = input
Expand Down Expand Up @@ -158,7 +166,7 @@ async def assert_payload(
def test_binary_proto():
# We have to test this separately because by default it never encodes
# anything since JSON proto takes precedence
conv = temporalio.converter.BinaryProtoPayloadConverter()
conv = BinaryProtoPayloadConverter()
proto = temporalio.api.common.v1.WorkflowExecution(workflow_id="id1", run_id="id2")
payload = conv.to_payload(proto)
assert payload.metadata["encoding"] == b"binary/protobuf"
Expand All @@ -172,11 +180,11 @@ def test_binary_proto():

def test_encode_search_attribute_values():
with pytest.raises(TypeError, match="of type tuple not one of"):
temporalio.converter.encode_search_attribute_values([("bad type",)])
encode_search_attribute_values([("bad type",)])
with pytest.raises(ValueError, match="Timezone must be present"):
temporalio.converter.encode_search_attribute_values([datetime.utcnow()])
encode_search_attribute_values([datetime.utcnow()])
with pytest.raises(TypeError, match="must have the same type"):
temporalio.converter.encode_search_attribute_values(["foo", 123])
encode_search_attribute_values(["foo", 123])


def test_decode_search_attributes():
Expand All @@ -192,25 +200,23 @@ def payload(key, dtype, data, encoding=None):
return temporalio.api.common.v1.SearchAttributes(indexed_fields={key: check})

# Check basic keyword parsing works
kw_check = temporalio.converter.decode_search_attributes(
payload("kw", "Keyword", '"test-id"')
)
kw_check = decode_search_attributes(payload("kw", "Keyword", '"test-id"'))
assert kw_check["kw"][0] == "test-id"

# Ensure original DT functionality works
dt_check = temporalio.converter.decode_search_attributes(
dt_check = decode_search_attributes(
payload("dt", "Datetime", '"2020-01-01T00:00:00"')
)
assert dt_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0)

# Check timezone aware works as server is using ISO 8601
dttz_check = temporalio.converter.decode_search_attributes(
dttz_check = decode_search_attributes(
payload("dt", "Datetime", '"2020-01-01T00:00:00Z"')
)
assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)

# Check timezone aware, hour offset
dttz_check = temporalio.converter.decode_search_attributes(
dttz_check = decode_search_attributes(
payload("dt", "Datetime", '"2020-01-01T00:00:00+00:00"')
)
assert dttz_check["dt"][0] == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
Expand Down Expand Up @@ -245,7 +251,7 @@ class MyPydanticClass(pydantic.BaseModel):


def test_json_type_hints():
converter = temporalio.converter.JSONPlainPayloadConverter()
converter = JSONPlainPayloadConverter()

def ok(
hint: Type, value: Any, expected_result: Any = temporalio.common._arg_unset
Expand Down Expand Up @@ -415,10 +421,8 @@ async def test_exception_format():

# Convert to failure and back
failure = Failure()
await temporalio.converter.DataConverter.default.encode_failure(actual_err, failure)
failure_error = await temporalio.converter.DataConverter.default.decode_failure(
failure
)
await DataConverter.default.encode_failure(actual_err, failure)
failure_error = await DataConverter.default.decode_failure(failure)
# Confirm type is prepended
assert isinstance(failure_error, ApplicationError)
assert "RuntimeError: error2" == str(failure_error)
Expand All @@ -440,3 +444,72 @@ async def test_exception_format():
logging.getLogger(__name__).debug(
"Showing appended exception", exc_info=failure_error
)


# Just serializes in a "payloads" wrapper
class SimpleCodec(PayloadCodec):
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
wrapper = Payloads(payloads=payloads)
return [
Payload(
metadata={"simple-codec": b"true"}, data=wrapper.SerializeToString()
)
]

async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
payloads = list(payloads)
if len(payloads) != 1:
raise RuntimeError("Expected only a single payload")
elif payloads[0].metadata.get("simple-codec") != b"true":
raise RuntimeError("Not encoded with this codec")
wrapper = Payloads()
wrapper.ParseFromString(payloads[0].data)
return list(wrapper.payloads)


async def test_failure_encoded_attributes():
try:
raise ApplicationError("some message", "some detail")
except ApplicationError as err:
some_err = err

conv = DataConverter(
failure_converter_class=DefaultFailureConverterWithEncodedAttributes,
payload_codec=SimpleCodec(),
)

# Check failure
failure = Failure()
conv.failure_converter.to_failure(some_err, conv.payload_converter, failure)
assert failure.message == "Encoded failure"
assert failure.stack_trace == ""
assert conv.payload_converter.from_payloads(
failure.application_failure_info.details.payloads
) == ["some detail"]
encoded_attr = conv.payload_converter.from_payloads([failure.encoded_attributes])[0]
assert encoded_attr["message"] == "some message"
assert "test_converter" in encoded_attr["stack_trace"]

# Encode it and check encoded
orig_failure = Failure()
orig_failure.CopyFrom(failure)
await conv.payload_codec.encode_failure(failure)
assert "encoding" not in failure.encoded_attributes.metadata
assert "simple-codec" in failure.encoded_attributes.metadata
assert (
"encoding" not in failure.application_failure_info.details.payloads[0].metadata
)
assert (
"simple-codec" in failure.application_failure_info.details.payloads[0].metadata
)

# Decode and check
await conv.payload_codec.decode_failure(failure)
assert "encoding" in failure.encoded_attributes.metadata
assert "simple-codec" not in failure.encoded_attributes.metadata
assert "encoding" in failure.application_failure_info.details.payloads[0].metadata
assert (
"simple-codec"
not in failure.application_failure_info.details.payloads[0].metadata
)
assert failure == orig_failure