
Commit 4bab2db

feat(kafka): add logic to handle protobuf deserialization (#6841)
Fixes Glue, Confluent, and plain protobuf deserialization
1 parent: f53bc27

18 files changed · +431 −119 lines changed

aws_lambda_powertools/utilities/kafka/consumer_records.py

Lines changed: 21 additions & 6 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 from functools import cached_property
 from typing import TYPE_CHECKING, Any

@@ -13,6 +14,8 @@

 from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

+logger = logging.getLogger(__name__)
+

 class ConsumerRecordRecords(KafkaEventRecordBase):
     """
@@ -31,18 +34,24 @@ def key(self) -> Any:
        if not key:
            return None

+       logger.debug("Deserializing key field")
+
        # Determine schema type and schema string
        schema_type = None
-       schema_str = None
+       schema_value = None
        output_serializer = None

        if self.schema_config and self.schema_config.key_schema_type:
            schema_type = self.schema_config.key_schema_type
-           schema_str = self.schema_config.key_schema
+           schema_value = self.schema_config.key_schema
            output_serializer = self.schema_config.key_output_serializer

        # Always use get_deserializer if None it will default to DEFAULT
-       deserializer = get_deserializer(schema_type, schema_str)
+       deserializer = get_deserializer(
+           schema_type=schema_type,
+           schema_value=schema_value,
+           field_metadata=self.key_schema_metadata,
+       )
        deserialized_value = deserializer.deserialize(key)

        # Apply output serializer if specified
@@ -57,16 +66,22 @@ def value(self) -> Any:

        # Determine schema type and schema string
        schema_type = None
-       schema_str = None
+       schema_value = None
        output_serializer = None

+       logger.debug("Deserializing value field")
+
        if self.schema_config and self.schema_config.value_schema_type:
            schema_type = self.schema_config.value_schema_type
-           schema_str = self.schema_config.value_schema
+           schema_value = self.schema_config.value_schema
            output_serializer = self.schema_config.value_output_serializer

        # Always use get_deserializer if None it will default to DEFAULT
-       deserializer = get_deserializer(schema_type, schema_str)
+       deserializer = get_deserializer(
+           schema_type=schema_type,
+           schema_value=schema_value,
+           field_metadata=self.value_schema_metadata,
+       )
        deserialized_value = deserializer.deserialize(value)

        # Apply output serializer if specified
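
Note on the change above: key() and value() now pass the record's schema metadata into the deserializer factory, so the factory can validate the declared data format and vary its cache per schema id. A minimal sketch of the new call shape outside the event classes (the metadata dict layout with dataFormat/schemaId is an assumption based on how this commit reads it, and the values are illustrative):

    from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer

    avro_schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
    value_metadata = {"dataFormat": "AVRO", "schemaId": "123"}  # illustrative metadata

    deserializer = get_deserializer(
        schema_type="AVRO",
        schema_value=avro_schema,
        field_metadata=value_metadata,
    )
    # deserializer.deserialize(...) then expects a base64-encoded, Avro-encoded payload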

aws_lambda_powertools/utilities/kafka/deserializer/avro.py

Lines changed: 14 additions & 1 deletion
@@ -1,6 +1,8 @@
 from __future__ import annotations

 import io
+import logging
+from typing import Any

 from avro.io import BinaryDecoder, DatumReader
 from avro.schema import parse as parse_schema
@@ -9,8 +11,11 @@
 from aws_lambda_powertools.utilities.kafka.exceptions import (
     KafkaConsumerAvroSchemaParserError,
     KafkaConsumerDeserializationError,
+    KafkaConsumerDeserializationFormatMismatch,
 )

+logger = logging.getLogger(__name__)
+

 class AvroDeserializer(DeserializerBase):
     """
@@ -20,10 +25,11 @@ class AvroDeserializer(DeserializerBase):
     a provided Avro schema definition.
     """

-    def __init__(self, schema_str: str):
+    def __init__(self, schema_str: str, field_metadata: dict[str, Any] | None = None):
        try:
            self.parsed_schema = parse_schema(schema_str)
            self.reader = DatumReader(self.parsed_schema)
+           self.field_metatada = field_metadata
        except Exception as e:
            raise KafkaConsumerAvroSchemaParserError(
                f"Invalid Avro schema. Please ensure the provided avro schema is valid: {type(e).__name__}: {str(e)}",
@@ -60,6 +66,13 @@ def deserialize(self, data: bytes | str) -> object:
        ... except KafkaConsumerDeserializationError as e:
        ...     print(f"Failed to deserialize: {e}")
        """
+       data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None
+
+       if data_format and data_format != "AVRO":
+           raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is AVRO but you sent {data_format}")
+
+       logger.debug("Deserializing data with AVRO format")
+
        try:
            value = self._decode_input(data)
            bytes_reader = io.BytesIO(value)
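
The new guard rejects payloads whose metadata declares a different format before any Avro decoding happens. A short sketch (the metadata value is hypothetical):

    from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer
    from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationFormatMismatch

    schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
    deserializer = AvroDeserializer(schema_str=schema, field_metadata={"dataFormat": "JSON"})  # mismatched on purpose

    try:
        deserializer.deserialize(b"")  # payload is never decoded; the format check fails first
    except KafkaConsumerDeserializationFormatMismatch as exc:
        print(exc)  # Expected data is AVRO but you sent JSON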

aws_lambda_powertools/utilities/kafka/deserializer/default.py

Lines changed: 4 additions & 0 deletions
@@ -1,9 +1,12 @@
 from __future__ import annotations

 import base64
+import logging

 from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase

+logger = logging.getLogger(__name__)
+

 class DefaultDeserializer(DeserializerBase):
     """
@@ -43,4 +46,5 @@ def deserialize(self, data: bytes | str) -> str:
        >>> result = deserializer.deserialize(bytes_data)
        >>> print(result == bytes_data)  # Output: True
        """
+       logger.debug("Deserializing data with primitives types")
        return base64.b64decode(data).decode("utf-8")

aws_lambda_powertools/utilities/kafka/deserializer/deserializer.py

Lines changed: 15 additions & 9 deletions
@@ -13,21 +13,27 @@
 _deserializer_cache: dict[str, DeserializerBase] = {}


-def _get_cache_key(schema_type: str | object, schema_value: Any) -> str:
+def _get_cache_key(schema_type: str | object, schema_value: Any, field_metadata: dict[str, Any]) -> str:
+    schema_metadata = None
+
+    if field_metadata:
+        schema_metadata = field_metadata.get("schemaId")
+
     if schema_value is None:
-        return str(schema_type)
+        schema_hash = f"{str(schema_type)}_{schema_metadata}"

     if isinstance(schema_value, str):
+        hashable_value = f"{schema_value}_{schema_metadata}"
         # For string schemas like Avro, hash the content
-        schema_hash = hashlib.md5(schema_value.encode("utf-8"), usedforsecurity=False).hexdigest()
+        schema_hash = hashlib.md5(hashable_value.encode("utf-8"), usedforsecurity=False).hexdigest()
     else:
         # For objects like Protobuf, use the object id
-        schema_hash = str(id(schema_value))
+        schema_hash = f"{str(id(schema_value))}_{schema_metadata}"

     return f"{schema_type}_{schema_hash}"


-def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
+def get_deserializer(schema_type: str | object, schema_value: Any, field_metadata: Any) -> DeserializerBase:
     """
     Factory function to get the appropriate deserializer based on schema type.

@@ -75,7 +81,7 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
     """

     # Generate a cache key based on schema type and value
-    cache_key = _get_cache_key(schema_type, schema_value)
+    cache_key = _get_cache_key(schema_type, schema_value, field_metadata)

     # Check if we already have this deserializer in cache
     if cache_key in _deserializer_cache:
@@ -87,14 +93,14 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
         # Import here to avoid dependency if not used
         from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer

-        deserializer = AvroDeserializer(schema_value)
+        deserializer = AvroDeserializer(schema_str=schema_value, field_metadata=field_metadata)
     elif schema_type == "PROTOBUF":
         # Import here to avoid dependency if not used
         from aws_lambda_powertools.utilities.kafka.deserializer.protobuf import ProtobufDeserializer

-        deserializer = ProtobufDeserializer(schema_value)
+        deserializer = ProtobufDeserializer(message_class=schema_value, field_metadata=field_metadata)
     elif schema_type == "JSON":
-        deserializer = JsonDeserializer()
+        deserializer = JsonDeserializer(field_metadata=field_metadata)

     else:
         # Default to no-op deserializer
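
Because the cache key now folds in the record's schemaId, two records that share the same schema string but reference different registry schema versions resolve to different cached deserializers. A small sketch using the private helper purely for illustration (the schemaId values are made up):

    from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import _get_cache_key

    schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
    key_v1 = _get_cache_key("AVRO", schema, {"schemaId": "1"})
    key_v2 = _get_cache_key("AVRO", schema, {"schemaId": "2"})
    print(key_v1 != key_v2)  # True: the same schema under a different schemaId gets its own cache entry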

aws_lambda_powertools/utilities/kafka/deserializer/json.py

Lines changed: 19 additions & 1 deletion
@@ -2,9 +2,16 @@

 import base64
 import json
+import logging
+from typing import Any

 from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
-from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationError
+from aws_lambda_powertools.utilities.kafka.exceptions import (
+    KafkaConsumerDeserializationError,
+    KafkaConsumerDeserializationFormatMismatch,
+)
+
+logger = logging.getLogger(__name__)


 class JsonDeserializer(DeserializerBase):
@@ -15,6 +22,9 @@ class JsonDeserializer(DeserializerBase):
     into Python dictionaries.
     """

+    def __init__(self, field_metadata: dict[str, Any] | None = None):
+        self.field_metatada = field_metadata
+
     def deserialize(self, data: bytes | str) -> dict:
        """
        Deserialize JSON data to a Python dictionary.
@@ -45,6 +55,14 @@ def deserialize(self, data: bytes | str) -> dict:
        ... except KafkaConsumerDeserializationError as e:
        ...     print(f"Failed to deserialize: {e}")
        """
+
+       data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None
+
+       if data_format and data_format != "JSON":
+           raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is JSON but you sent {data_format}")
+
+       logger.debug("Deserializing data with JSON format")
+
        try:
            return json.loads(base64.b64decode(data).decode("utf-8"))
        except Exception as e:
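
As before, the JSON deserializer expects the payload to arrive base64-encoded; the only new behaviour is the optional metadata check. A minimal happy-path sketch:

    import base64
    import json

    from aws_lambda_powertools.utilities.kafka.deserializer.json import JsonDeserializer

    payload = base64.b64encode(json.dumps({"order_id": 42}).encode("utf-8"))
    deserializer = JsonDeserializer(field_metadata={"dataFormat": "JSON"})
    print(deserializer.deserialize(payload))  # {'order_id': 42}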

aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py

Lines changed: 49 additions & 45 deletions
@@ -1,15 +1,19 @@
 from __future__ import annotations

+import logging
 from typing import Any

-from google.protobuf.internal.decoder import _DecodeVarint  # type: ignore[attr-defined]
+from google.protobuf.internal.decoder import _DecodeSignedVarint  # type: ignore[attr-defined]
 from google.protobuf.json_format import MessageToDict

 from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
 from aws_lambda_powertools.utilities.kafka.exceptions import (
     KafkaConsumerDeserializationError,
+    KafkaConsumerDeserializationFormatMismatch,
 )

+logger = logging.getLogger(__name__)
+

 class ProtobufDeserializer(DeserializerBase):
     """
@@ -19,8 +23,9 @@ class ProtobufDeserializer(DeserializerBase):
     into Python dictionaries using the provided Protocol Buffer message class.
     """

-    def __init__(self, message_class: Any):
+    def __init__(self, message_class: Any, field_metadata: dict[str, Any] | None = None):
        self.message_class = message_class
+       self.field_metatada = field_metadata

     def deserialize(self, data: bytes | str) -> dict:
        """
@@ -61,57 +66,56 @@ def deserialize(self, data: bytes | str) -> dict:
        ... except KafkaConsumerDeserializationError as e:
        ...     print(f"Failed to deserialize: {e}")
        """
-       value = self._decode_input(data)
-       try:
-           message = self.message_class()
-           message.ParseFromString(value)
-           return MessageToDict(message, preserving_proto_field_name=True)
-       except Exception:
-           return self._deserialize_with_message_index(value, self.message_class())

-    def _deserialize_with_message_index(self, data: bytes, parser: Any) -> dict:
-       """
-       Deserialize protobuf message with Confluent message index handling.
+       data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None
+       schema_id = self.field_metatada.get("schemaId") if self.field_metatada else None

-       Parameters
-       ----------
-       data : bytes
-           data
-       parser : google.protobuf.message.Message
-           Protobuf message instance to parse the data into
+       if data_format and data_format != "PROTOBUF":
+           raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is PROTOBUF but you sent {data_format}")

-       Returns
-       -------
-       dict
-           Dictionary representation of the parsed protobuf message with original field names
+       logger.debug("Deserializing data with PROTOBUF format")

-       Raises
-       ------
-       KafkaConsumerDeserializationError
-           If deserialization fails
+       try:
+           value = self._decode_input(data)
+           message = self.message_class()
+           if schema_id is None:
+               logger.debug("Plain PROTOBUF data: using default deserializer")
+               # Plain protobuf - direct parser
+               message.ParseFromString(value)
+           elif len(schema_id) > 20:
+               logger.debug("PROTOBUF data integrated with Glue SchemaRegistry: using Glue deserializer")
+               # Glue schema registry integration - remove the first byte
+               message.ParseFromString(value[1:])
+           else:
+               logger.debug("PROTOBUF data integrated with Confluent SchemaRegistry: using Confluent deserializer")
+               # Confluent schema registry integration - remove message index list
+               message.ParseFromString(self._remove_message_index(value))

-       Notes
-       -----
-       This method handles the special case of Confluent Schema Registry's message index
-       format, where the message is prefixed with either a single 0 (for the first schema)
-       or a list of schema indexes. The actual protobuf message follows these indexes.
-       """
+           return MessageToDict(message, preserving_proto_field_name=True)
+       except Exception as e:
+           raise KafkaConsumerDeserializationError(
+               f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}",
+           ) from e

+    def _remove_message_index(self, data):
+       """
+       Identifies and removes Confluent Schema Registry MessageIndex from bytes.
+       Returns pure protobuf bytes.
+       """
        buffer = memoryview(data)
        pos = 0

-       try:
-           first_value, new_pos = _DecodeVarint(buffer, pos)
-           pos = new_pos
+       logger.debug("Removing message list bytes")

-           if first_value != 0:
-               for _ in range(first_value):
-                   _, new_pos = _DecodeVarint(buffer, pos)
-                   pos = new_pos
+       # Read first varint (index count or 0)
+       first_value, new_pos = _DecodeSignedVarint(buffer, pos)
+       pos = new_pos

-           parser.ParseFromString(data[pos:])
-           return MessageToDict(parser, preserving_proto_field_name=True)
-       except Exception as e:
-           raise KafkaConsumerDeserializationError(
-               f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}",
-           ) from e
+       # Skip index values if present
+       if first_value != 0:
+           for _ in range(first_value):
+               _, new_pos = _DecodeSignedVarint(buffer, pos)
+               pos = new_pos
+
+       # Return remaining bytes (pure protobuf)
+       return data[pos:]
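
The len(schema_id) > 20 branch reads as a length heuristic: Glue Schema Registry identifies schema versions with UUID strings (36 characters), while Confluent Schema Registry uses short numeric IDs, so long IDs take the Glue path (drop one leading byte) and short ones take the Confluent path (strip the varint-encoded message-index list). A sketch of the Confluent stripping for the common single-index case, where the index list is encoded as one 0x00 byte (message_class is unused by this helper, so None stands in, and the payload bytes are placeholders):

    from aws_lambda_powertools.utilities.kafka.deserializer.protobuf import ProtobufDeserializer

    serialized_message = b"\x08\x2a"  # stand-in for some_generated_message.SerializeToString()
    confluent_payload = b"\x00" + serialized_message  # varint 0 marks the "first schema" index list

    deserializer = ProtobufDeserializer(message_class=None)
    print(deserializer._remove_message_index(confluent_payload) == serialized_message)  # True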

aws_lambda_powertools/utilities/kafka/exceptions.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,12 @@ class KafkaConsumerAvroSchemaParserError(Exception):
     """


+class KafkaConsumerDeserializationFormatMismatch(Exception):
+    """
+    Error raised when deserialization format is incompatible
+    """
+
+
 class KafkaConsumerDeserializationError(Exception):
     """
     Error raised when message deserialization fails.
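
The new exception lets callers tell a declared-format mismatch apart from a payload that simply failed to parse. A brief sketch:

    from aws_lambda_powertools.utilities.kafka.exceptions import (
        KafkaConsumerDeserializationError,
        KafkaConsumerDeserializationFormatMismatch,
    )

    def safe_deserialize(deserializer, data):
        try:
            return deserializer.deserialize(data)
        except KafkaConsumerDeserializationFormatMismatch:
            # Schema metadata declared a different format; re-route or drop the record.
            raise
        except KafkaConsumerDeserializationError:
            # Format matched but the payload could not be parsed.
            raise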
