Commit 2ea9751 (parent: 7d981ff)

fix(event_source): fix decode headers with signed bytes (#6878)

Fix decoding of headers with signed numbers.

7 files changed: +55 −14 lines

aws_lambda_powertools/shared/functions.py

Lines changed: 16 additions & 0 deletions
@@ -291,3 +291,19 @@ def sanitize_xray_segment_name(name: str) -> str:
 def get_tracer_id() -> str | None:
     xray_trace_id = os.getenv(constants.XRAY_TRACE_ID_ENV)
     return xray_trace_id.split(";")[0].replace("Root=", "") if xray_trace_id else None
+
+
+def decode_header_bytes(byte_list):
+    """
+    Decode a list of byte values that might be signed.
+    If any negative values exist, handle them as signed bytes.
+    Otherwise use the normal bytes construction.
+    """
+    has_negative = any(b < 0 for b in byte_list)
+
+    if not has_negative:
+        # Use normal bytes construction if all values are positive
+        return bytes(byte_list)
+    # Convert signed bytes to unsigned (0-255 range)
+    unsigned_bytes = [(b & 0xFF) for b in byte_list]
+    return bytes(unsigned_bytes)
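Python's bytes() constructor rejects values outside range(0, 256), so the previous bytes(v) calls raised ValueError for headers produced by JVM clients, whose byte type is signed. A minimal check of the new helper (header values mirror the test fixture added below):

from aws_lambda_powertools.shared.functions import decode_header_bytes

signed = [104, 101, 108, 108, 111, -61, -85]  # "hello" plus UTF-8 "ë" as signed bytes

# bytes(signed) raises ValueError: bytes must be in range(0, 256)
decoded = decode_header_bytes(signed)  # -61 & 0xFF == 195 (0xC3); -85 & 0xFF == 171 (0xAB)
assert decoded == b"hello\xc3\xab"
assert decoded.decode("utf-8") == "helloë"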

aws_lambda_powertools/utilities/data_classes/kafka_event.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
+from aws_lambda_powertools.shared.functions import decode_header_bytes
 from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict, DictWrapper
 
 if TYPE_CHECKING:
@@ -110,7 +111,7 @@ def headers(self) -> list[dict[str, list[int]]]:
     @cached_property
     def decoded_headers(self) -> dict[str, bytes]:
         """Decodes the headers as a single dictionary."""
-        return CaseInsensitiveDict((k, bytes(v)) for chunk in self.headers for k, v in chunk.items())
+        return CaseInsensitiveDict((k, decode_header_bytes(v)) for chunk in self.headers for k, v in chunk.items())
 
 
 class KafkaEventBase(DictWrapper):
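A minimal sketch of the data-class path after this change, using a trimmed-down, hypothetical MSK payload (real events carry more fields such as key, value, and timestamps):

from aws_lambda_powertools.utilities.data_classes import KafkaEvent

event = {
    "eventSource": "aws:kafka",
    "records": {
        "mytopic-0": [
            {
                "topic": "mytopic",
                "partition": 0,
                "offset": 15,
                "headers": [{"headerKey": [104, 105, -61, -85]}],  # signed bytes from a JVM client
            },
        ],
    },
}

record = next(KafkaEvent(event).records)
print(record.decoded_headers["headerkey"])  # b'hi\xc3\xab' — lookup is case-insensitive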

aws_lambda_powertools/utilities/kafka/consumer_records.py

Lines changed: 4 additions & 1 deletion
@@ -4,6 +4,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
+from aws_lambda_powertools.shared.functions import decode_header_bytes
 from aws_lambda_powertools.utilities.data_classes.common import CaseInsensitiveDict
 from aws_lambda_powertools.utilities.data_classes.kafka_event import KafkaEventBase, KafkaEventRecordBase
 from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer
@@ -115,7 +116,9 @@ def original_headers(self) -> list[dict[str, list[int]]]:
     @cached_property
     def headers(self) -> dict[str, bytes]:
         """Decodes the headers as a single dictionary."""
-        return CaseInsensitiveDict((k, bytes(v)) for chunk in self.original_headers for k, v in chunk.items())
+        return CaseInsensitiveDict(
+            (k, decode_header_bytes(v)) for chunk in self.original_headers for k, v in chunk.items()
+        )
 
 
 class ConsumerRecords(KafkaEventBase):
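Note the header shape both properties consume: the event stores headers as a list of single-key chunks, and the comprehension flattens them into one mapping while decoding each value. The flattening in isolation, with hypothetical header names and a plain dict for brevity:

from aws_lambda_powertools.shared.functions import decode_header_bytes

chunks = [{"content-type": [97, 118, 114, 111]}, {"trace-id": [-1, 0, 127]}]

flattened = {k: decode_header_bytes(v) for chunk in chunks for k, v in chunk.items()}
print(flattened)  # {'content-type': b'avro', 'trace-id': b'\xff\x00\x7f'}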

aws_lambda_powertools/utilities/parser/models/kafka.py

Lines changed: 4 additions & 6 deletions
@@ -3,7 +3,7 @@
 
 from pydantic import BaseModel, field_validator
 
-from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string
+from aws_lambda_powertools.shared.functions import base64_decode, bytes_to_string, decode_header_bytes
 
 SERVERS_DELIMITER = ","
 
@@ -28,9 +28,7 @@ class KafkaRecordModel(BaseModel):
     # key is optional; only decode if not None
     @field_validator("key", mode="before")
     def decode_key(cls, value):
-        if value is not None:
-            return base64_decode(value)
-        return value
+        return base64_decode(value) if value is not None else value
 
     @field_validator("value", mode="before")
     def data_base64_decode(cls, value):
@@ -41,7 +39,7 @@ def data_base64_decode(cls, value):
     def decode_headers_list(cls, value):
         for header in value:
             for key, values in header.items():
-                header[key] = decode_header_bytes(values)
+                header[key] = decode_header_bytes(values)
         return value
 
 
@@ -51,7 +49,7 @@ class KafkaBaseEventModel(BaseModel):
 
     @field_validator("bootstrapServers", mode="before")
     def split_servers(cls, value):
-        return None if not value else value.split(SERVERS_DELIMITER)
+        return value.split(SERVERS_DELIMITER) if value else None
 
 
 class KafkaSelfManagedEventModel(KafkaBaseEventModel):
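Because these validators run with mode="before", the raw signed integers are rewritten before Pydantic coerces the field to bytes. A self-contained sketch of the same pattern with a toy model (ToyRecord is a hypothetical stand-in for KafkaRecordModel):

from pydantic import BaseModel, field_validator

from aws_lambda_powertools.shared.functions import decode_header_bytes


class ToyRecord(BaseModel):  # hypothetical stand-in for KafkaRecordModel
    headers: list[dict[str, bytes]]

    @field_validator("headers", mode="before")
    def decode_headers_list(cls, value):
        # Runs before type coercion, so signed ints never reach bytes() validation
        for header in value:
            for key, values in header.items():
                header[key] = decode_header_bytes(values)
        return value


record = ToyRecord(headers=[{"headerKey": [104, 105, -61, -85]}])
print(record.headers)  # [{'headerKey': b'hi\xc3\xab'}]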

tests/events/kafkaEventMsk.json

Lines changed: 22 additions & 0 deletions
@@ -104,6 +104,28 @@
         "dataFormat": "AVRO",
         "schemaId": "1234"
       }
+    },
+    {
+      "topic":"mymessage-with-unsigned",
+      "partition":0,
+      "offset":15,
+      "timestamp":1545084650987,
+      "timestampType":"CREATE_TIME",
+      "key": null,
+      "value":"eyJrZXkiOiJ2YWx1ZSJ9",
+      "headers":[
+        {
+          "headerKey":[104, 101, 108, 108, 111, 45, 119, 111, 114, 108, 100, 45, -61, -85]
+        }
+      ],
+      "valueSchemaMetadata": {
+        "dataFormat": "AVRO",
+        "schemaId": "1234"
+      },
+      "keySchemaMetadata": {
+        "dataFormat": "AVRO",
+        "schemaId": "1234"
+      }
     }
   ]
 }
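For reference, the new record's headerKey decodes to "hello-world-ë": bytes 104 through 45 spell "hello-world-" in ASCII, and the trailing -61, -85 are the signed two's-complement view of 0xC3 0xAB, the UTF-8 encoding of "ë". The base64 value eyJrZXkiOiJ2YWx1ZSJ9 decodes to {"key":"value"}.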

tests/unit/data_classes/required_dependencies/test_kafka_event.py

Lines changed: 4 additions & 3 deletions
@@ -21,7 +21,7 @@ def test_kafka_msk_event():
     assert parsed_event.decoded_bootstrap_servers == bootstrap_servers_list
 
     records = list(parsed_event.records)
-    assert len(records) == 3
+    assert len(records) == 4
     record = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -40,9 +40,10 @@ def test_kafka_msk_event():
     assert record.value_schema_metadata.schema_id == raw_record["valueSchemaMetadata"]["schemaId"]
 
     assert parsed_event.record == records[0]
-    for i in range(1, 3):
+    for i in range(1, 4):
         record = records[i]
         assert record.key is None
+        assert record.decoded_headers is not None
 
 
 def test_kafka_self_managed_event():
@@ -90,5 +91,5 @@ def test_kafka_record_property_with_stopiteration_error():
     # WHEN calling record property thrice
     # THEN raise StopIteration
     with pytest.raises(StopIteration):
-        for _ in range(4):
+        for _ in range(5):
             assert parsed_event.record.topic is not None

tests/unit/parser/_pydantic/test_kafka.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@ def test_kafka_msk_event_with_envelope():
     )
     for i in range(3):
         assert parsed_event[i].key == "value"
-    assert len(parsed_event) == 3
+    assert len(parsed_event) == 4
 
 
 def test_kafka_self_managed_event_with_envelope():
@@ -70,7 +70,7 @@ def test_kafka_msk_event():
     assert parsed_event.eventSourceArn == raw_event["eventSourceArn"]
 
     records = list(parsed_event.records["mytopic-0"])
-    assert len(records) == 3
+    assert len(records) == 4
     record: KafkaRecordModel = records[0]
     raw_record = raw_event["records"]["mytopic-0"][0]
     assert record.topic == raw_record["topic"]
@@ -88,6 +88,6 @@ def test_kafka_msk_event():
     assert record.keySchemaMetadata.schemaId == "1234"
     assert record.valueSchemaMetadata.dataFormat == "AVRO"
     assert record.valueSchemaMetadata.schemaId == "1234"
-    for i in range(1, 3):
+    for i in range(1, 4):
         record: KafkaRecordModel = records[i]
         assert record.key is None
