From 31759ec6bd4389d05235ac21ffc84c8d683d4685 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 2 Nov 2023 15:24:55 +0100 Subject: [PATCH 01/13] feat(event_handler): add Bedrock Agent event handler --- .../event_handler/__init__.py | 2 + .../event_handler/api_gateway.py | 20 +++-- .../event_handler/bedrock_agent.py | 71 +++++++++++++++++ .../shared/headers_serializer.py | 9 +++ .../data_classes/bedrock_agent_event.py | 13 ++- .../event_handler/test_bedrock_agent.py | 79 +++++++++++++++++++ 6 files changed, 185 insertions(+), 9 deletions(-) create mode 100644 aws_lambda_powertools/event_handler/bedrock_agent.py create mode 100644 tests/functional/event_handler/test_bedrock_agent.py diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 7bdd9a97f72..ffbb2abe4ae 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -11,6 +11,7 @@ Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver +from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) @@ -22,6 +23,7 @@ "APIGatewayHttpResolver", "ALBResolver", "ApiGatewayResolver", + "BedrockAgentResolver", "CORSConfig", "LambdaFunctionUrlResolver", "Response", diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 0ddf287f264..05a9691205b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -45,6 +45,7 @@ ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2, + BedrockAgentEvent, LambdaFunctionUrlEvent, VPCLatticeEvent, VPCLatticeEventV2, @@ -85,6 +86,7 @@ class ProxyEventType(Enum): APIGatewayProxyEvent = "APIGatewayProxyEvent" APIGatewayProxyEventV2 = "APIGatewayProxyEventV2" ALBEvent = "ALBEvent" + BedrockAgentEvent = "BedrockAgentEvent" VPCLatticeEvent = "VPCLatticeEvent" VPCLatticeEventV2 = "VPCLatticeEventV2" LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent" @@ -1315,6 +1317,7 @@ def __init__( self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] + self.response_builder_class = ResponseBuilder # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1784,7 +1787,7 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule) return re.compile(base_regex.format(rule_regex)) - def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: + def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 """Convert the event dict to the corresponding data class""" if self._proxy_type == ProxyEventType.APIGatewayProxyEvent: logger.debug("Converting event to API Gateway REST API contract") @@ -1792,6 +1795,9 @@ def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2: logger.debug("Converting event to API Gateway HTTP API contract") return APIGatewayProxyEventV2(event) + if self._proxy_type == ProxyEventType.BedrockAgentEvent: + logger.debug("Converting event to Bedrock Agent contract") + return BedrockAgentEvent(event) if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent: logger.debug("Converting event to Lambda Function URL contract") return LambdaFunctionUrlEvent(event) @@ -1869,9 +1875,9 @@ def _not_found(self, method: str) -> ResponseBuilder: handler = self._lookup_exception_handler(NotFoundError) if handler: - return ResponseBuilder(handler(NotFoundError())) + return self.response_builder_class(handler(NotFoundError())) - return ResponseBuilder( + return self.response_builder_class( Response( status_code=HTTPStatus.NOT_FOUND.value, content_type=content_types.APPLICATION_JSON, @@ -1886,7 +1892,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # Reset Processed stack for Middleware (for debugging purposes) self._reset_processed_stack() - return ResponseBuilder( + return self.response_builder_class( self._to_response( route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments), ), @@ -1903,7 +1909,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # If the user has turned on debug mode, # we'll let the original exception propagate, so # they get more information about what went wrong. - return ResponseBuilder( + return self.response_builder_class( Response( status_code=500, content_type=content_types.TEXT_PLAIN, @@ -1942,12 +1948,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp handler = self._lookup_exception_handler(type(exp)) if handler: try: - return ResponseBuilder(handler(exp), route) + return self.response_builder_class(handler(exp), route) except ServiceError as service_error: exp = service_error if isinstance(exp, ServiceError): - return ResponseBuilder( + return self.response_builder_class( Response( status_code=exp.status_code, content_type=content_types.APPLICATION_JSON, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py new file mode 100644 index 00000000000..7874f4c31f6 --- /dev/null +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -0,0 +1,71 @@ +import logging +from typing import Any, Dict, Optional, cast + +from aws_lambda_powertools.event_handler import ApiGatewayResolver +from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent + +logger = logging.getLogger(__name__) + + +class BedrockResponseBuilder(ResponseBuilder): + def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: + """Build the full response dict to be returned by the lambda""" + self._route(event, cors) + + bedrock_event = cast(BedrockAgentEvent, event) + + return { + "messageVersion": "1.0", + "response": { + "actionGroup": bedrock_event.action_group, + "apiPath": bedrock_event.api_path, + "httpMethod": bedrock_event.http_method, + "httpStatusCode": self.response.status_code, + "responseBody": { + "application/json": { + "body": self.response.body, + }, + }, + }, + } + + +class BedrockAgentResolver(ApiGatewayResolver): + """Bedrock Agent Resolver + + See https://aws.amazon.com/bedrock/agents/ for more information. + + Examples + -------- + Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator + + ```python + from aws_lambda_powertools import Tracer + from aws_lambda_powertools.event_handler import BedrockAgentResolver + + tracer = Tracer() + app = BedrockAgentResolver() + + @app.get("/claims") + def simple_get(): + return "You have 3 claims" + + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context) + """ + + current_event: BedrockAgentEvent + + def __init__(self, debug: bool = False, enable_validation: bool = True): + super().__init__( + ProxyEventType.BedrockAgentEvent, + None, + debug, + None, + None, + enable_validation, + ) + self.response_builder_class = BedrockResponseBuilder diff --git a/aws_lambda_powertools/shared/headers_serializer.py b/aws_lambda_powertools/shared/headers_serializer.py index aa38157e26f..775134b57ef 100644 --- a/aws_lambda_powertools/shared/headers_serializer.py +++ b/aws_lambda_powertools/shared/headers_serializer.py @@ -123,3 +123,12 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo payload["headers"][key] = values[-1] return payload + + +class NoopSerializer(BaseHeadersSerializer): + """ + Noop serializer that doesn't do anything. This is useful for resolvers that don't need to set headers or cookies. + """ + + def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]: + return {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index b482b5b2b3e..2250d11e2e3 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional -from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer, NoopSerializer +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper class BedrockAgentInfo(DictWrapper): @@ -47,7 +48,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]: return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()} -class BedrockAgentEvent(DictWrapper): +class BedrockAgentEvent(BaseProxyEvent): """ Bedrock Agent input event @@ -97,3 +98,11 @@ def session_attributes(self) -> Dict[str, str]: @property def prompt_session_attributes(self) -> Dict[str, str]: return self["promptSessionAttributes"] + + # For compatibility with BaseProxyEvent + @property + def path(self) -> str: + return self["apiPath"] + + def header_serializer(self) -> BaseHeadersSerializer: + return NoopSerializer() diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py new file mode 100644 index 00000000000..c405915ed73 --- /dev/null +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -0,0 +1,79 @@ +import json +from typing import Any, Dict + +from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent +from tests.functional.utils import load_event + +claims_response = "You have 3 claims" + + +def test_bedrock_agent_event(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims") + def claims() -> Dict[str, Any]: + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return {"output": claims_response} + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 200 + + body = result["response"]["responseBody"]["application/json"]["body"] + assert body == {"output": claims_response} + + +def test_bedrock_agent_event_with_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + output = json.dumps({"output": claims_response}) + + @app.get("/claims") + def claims(): + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return Response(200, content_types.APPLICATION_JSON, output) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 200 + + body = result["response"]["responseBody"]["application/json"]["body"] + assert body == output + + +def test_bedrock_agent_event_with_no_matches(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/no_match") + def claims(): + raise RuntimeError() + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND return 404 because the event doesn't match any known rule + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 404 From 3b6f070a048a51c3e93df13577b094dce8c794f1 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 15:13:45 +0100 Subject: [PATCH 02/13] Update aws_lambda_powertools/event_handler/api_gateway.py Co-authored-by: Heitor Lessa Signed-off-by: Ruben Fonseca --- aws_lambda_powertools/event_handler/api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 05a9691205b..357adc2fc86 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1787,7 +1787,7 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule) return re.compile(base_regex.format(rule_regex)) - def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 + def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 # ignore many returns """Convert the event dict to the corresponding data class""" if self._proxy_type == ProxyEventType.APIGatewayProxyEvent: logger.debug("Converting event to API Gateway REST API contract") From 021f6c657228fe0b5f19d38ce6589d16405b1ba9 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 15:09:28 +0100 Subject: [PATCH 03/13] fix: tests --- .../event_handler/api_gateway.py | 16 ++++++++-------- .../event_handler/bedrock_agent.py | 5 ++++- .../event_handler/test_bedrock_agent.py | 6 +++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 357adc2fc86..dd45b1c79c9 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -210,7 +210,7 @@ def __init__( self, status_code: int, content_type: Optional[str] = None, - body: Union[str, bytes, None] = None, + body: Union[Any, None] = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, compress: Optional[bool] = None, @@ -1317,7 +1317,7 @@ def __init__( self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] - self.response_builder_class = ResponseBuilder + self._response_builder_class = ResponseBuilder # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1875,9 +1875,9 @@ def _not_found(self, method: str) -> ResponseBuilder: handler = self._lookup_exception_handler(NotFoundError) if handler: - return self.response_builder_class(handler(NotFoundError())) + return self._response_builder_class(handler(NotFoundError())) - return self.response_builder_class( + return self._response_builder_class( Response( status_code=HTTPStatus.NOT_FOUND.value, content_type=content_types.APPLICATION_JSON, @@ -1892,7 +1892,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # Reset Processed stack for Middleware (for debugging purposes) self._reset_processed_stack() - return self.response_builder_class( + return self._response_builder_class( self._to_response( route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments), ), @@ -1909,7 +1909,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # If the user has turned on debug mode, # we'll let the original exception propagate, so # they get more information about what went wrong. - return self.response_builder_class( + return self._response_builder_class( Response( status_code=500, content_type=content_types.TEXT_PLAIN, @@ -1948,12 +1948,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp handler = self._lookup_exception_handler(type(exp)) if handler: try: - return self.response_builder_class(handler(exp), route) + return self._response_builder_class(handler(exp), route) except ServiceError as service_error: exp = service_error if isinstance(exp, ServiceError): - return self.response_builder_class( + return self._response_builder_class( Response( status_code=exp.status_code, content_type=content_types.APPLICATION_JSON, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 7874f4c31f6..2706ca760a0 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,6 +1,8 @@ import logging from typing import Any, Dict, Optional, cast +from typing_extensions import override + from aws_lambda_powertools.event_handler import ApiGatewayResolver from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent @@ -10,6 +12,7 @@ class BedrockResponseBuilder(ResponseBuilder): + @override def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, cors) @@ -68,4 +71,4 @@ def __init__(self, debug: bool = False, enable_validation: bool = True): None, enable_validation, ) - self.response_builder_class = BedrockResponseBuilder + self._response_builder_class = BedrockResponseBuilder diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index c405915ed73..de946542367 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -30,13 +30,13 @@ def claims() -> Dict[str, Any]: assert result["response"]["httpStatusCode"] == 200 body = result["response"]["responseBody"]["application/json"]["body"] - assert body == {"output": claims_response} + assert body == json.dumps({"output": claims_response}) def test_bedrock_agent_event_with_response(): # GIVEN a Bedrock Agent event app = BedrockAgentResolver() - output = json.dumps({"output": claims_response}) + output = {"output": claims_response} @app.get("/claims") def claims(): @@ -56,7 +56,7 @@ def claims(): assert result["response"]["httpStatusCode"] == 200 body = result["response"]["responseBody"]["application/json"]["body"] - assert body == output + assert body == json.dumps(output) def test_bedrock_agent_event_with_no_matches(): From a7f4aa3fdccee466a169f14bd0bb6c6e81e9d8f3 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 15:36:56 +0100 Subject: [PATCH 04/13] fix: use generics --- .../event_handler/api_gateway.py | 18 +++++++++++------- .../event_handler/bedrock_agent.py | 13 +++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index dd45b1c79c9..1523f554aaa 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -15,6 +15,7 @@ Any, Callable, Dict, + Generic, List, Match, Optional, @@ -23,6 +24,7 @@ Set, Tuple, Type, + TypeVar, Union, cast, ) @@ -63,6 +65,8 @@ _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" +ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) + if TYPE_CHECKING: from aws_lambda_powertools.event_handler.openapi.compat import ( JsonSchemaValue, @@ -691,14 +695,14 @@ def _generate_operation_id(self) -> str: return operation_id -class ResponseBuilder: +class ResponseBuilder(Generic[ResponseEventT]): """Internally used Response builder""" def __init__(self, response: Response, route: Optional[Route] = None): self.response = response self.route = route - def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig): + def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" self.response.headers.update(cors.to_dict(event.get_header_value("Origin"))) @@ -711,7 +715,7 @@ def _add_cache_control(self, cache_control: str): def _has_compression_enabled( route_compression: bool, response_compression: Optional[bool], - event: BaseProxyEvent, + event: ResponseEventT, ) -> bool: """ Checks if compression is enabled. @@ -724,7 +728,7 @@ def _has_compression_enabled( A boolean indicating whether compression is enabled or not in the route setting. response_compression: bool, optional A boolean indicating whether compression is enabled or not in the response setting. - event: BaseProxyEvent + event: Generic[ResponseEventT] The event object containing the request details. Returns @@ -754,7 +758,7 @@ def _compress(self): gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) self.response.body = gzip.compress(self.response.body) + gzip.flush() - def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): + def _route(self, event: ResponseEventT, cors: Optional[CORSConfig]): """Optionally handle any of the route's configure response handling""" if self.route is None: return @@ -769,7 +773,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): ): self._compress() - def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: + def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, cors) @@ -1317,7 +1321,7 @@ def __init__( self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] - self._response_builder_class = ResponseBuilder + self._response_builder_class = ResponseBuilder[BaseProxyEvent] # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 2706ca760a0..38425ce2c2e 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,30 +1,27 @@ import logging -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional from typing_extensions import override from aws_lambda_powertools.event_handler import ApiGatewayResolver from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent -from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent logger = logging.getLogger(__name__) class BedrockResponseBuilder(ResponseBuilder): @override - def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: + def build(self, event: BedrockAgentEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, cors) - bedrock_event = cast(BedrockAgentEvent, event) - return { "messageVersion": "1.0", "response": { - "actionGroup": bedrock_event.action_group, - "apiPath": bedrock_event.api_path, - "httpMethod": bedrock_event.http_method, + "actionGroup": event.action_group, + "apiPath": event.api_path, + "httpMethod": event.http_method, "httpStatusCode": self.response.status_code, "responseBody": { "application/json": { From 3810de544cf4bde59199b74e49e5a41ea168f4b8 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 15:46:20 +0100 Subject: [PATCH 05/13] chore: add docs --- aws_lambda_powertools/event_handler/bedrock_agent.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 38425ce2c2e..d5172b9f656 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -11,6 +11,12 @@ class BedrockResponseBuilder(ResponseBuilder): + """ + Bedrock Response Builder. This builds the response dict to be returned by the lambda when using Bedrock Agents. + + Since the payload format is different from the standard API Gateway Proxy event, we override the build method. + """ + @override def build(self, event: BedrockAgentEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" From 7283182ae9de93221ff7a0f73b861471acec49ac Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:00:32 +0100 Subject: [PATCH 06/13] fix: addressed comments --- .../event_handler/api_gateway.py | 2 +- .../event_handler/bedrock_agent.py | 20 +++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1523f554aaa..1677b0989c4 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -214,7 +214,7 @@ def __init__( self, status_code: int, content_type: Optional[str] = None, - body: Union[Any, None] = None, + body: Any = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, compress: Optional[bool] = None, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index d5172b9f656..845f9819e76 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,10 +1,10 @@ import logging -from typing import Any, Dict, Optional +from typing import Any, Dict from typing_extensions import override from aws_lambda_powertools.event_handler import ApiGatewayResolver -from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder +from aws_lambda_powertools.event_handler.api_gateway import ProxyEventType, ResponseBuilder from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent logger = logging.getLogger(__name__) @@ -18,9 +18,9 @@ class BedrockResponseBuilder(ResponseBuilder): """ @override - def build(self, event: BedrockAgentEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: + def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" - self._route(event, cors) + self._route(event, None) return { "messageVersion": "1.0", @@ -67,11 +67,11 @@ def lambda_handler(event, context): def __init__(self, debug: bool = False, enable_validation: bool = True): super().__init__( - ProxyEventType.BedrockAgentEvent, - None, - debug, - None, - None, - enable_validation, + proxy_type=ProxyEventType.BedrockAgentEvent, + cors=None, + debug=debug, + serializer=None, + strip_prefixes=None, + enable_validation=enable_validation, ) self._response_builder_class = BedrockResponseBuilder From 487dc372542007cb017c11f2d4f02edcc1c0f680 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:23:30 +0100 Subject: [PATCH 07/13] fix: remove unused noop serializer --- aws_lambda_powertools/shared/headers_serializer.py | 9 --------- .../utilities/data_classes/bedrock_agent_event.py | 4 ---- 2 files changed, 13 deletions(-) diff --git a/aws_lambda_powertools/shared/headers_serializer.py b/aws_lambda_powertools/shared/headers_serializer.py index 775134b57ef..aa38157e26f 100644 --- a/aws_lambda_powertools/shared/headers_serializer.py +++ b/aws_lambda_powertools/shared/headers_serializer.py @@ -123,12 +123,3 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo payload["headers"][key] = values[-1] return payload - - -class NoopSerializer(BaseHeadersSerializer): - """ - Noop serializer that doesn't do anything. This is useful for resolvers that don't need to set headers or cookies. - """ - - def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]: - return {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 2250d11e2e3..1577ad62895 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -1,6 +1,5 @@ from typing import Dict, List, Optional -from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer, NoopSerializer from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper @@ -103,6 +102,3 @@ def prompt_session_attributes(self) -> Dict[str, str]: @property def path(self) -> str: return self["apiPath"] - - def header_serializer(self) -> BaseHeadersSerializer: - return NoopSerializer() From 65f4dfcbf55bce781e115dfa43e7f85049dc0742 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:24:29 +0100 Subject: [PATCH 08/13] Update aws_lambda_powertools/event_handler/bedrock_agent.py Co-authored-by: Heitor Lessa Signed-off-by: Ruben Fonseca --- aws_lambda_powertools/event_handler/bedrock_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 845f9819e76..ad5930ba94c 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -12,7 +12,7 @@ class BedrockResponseBuilder(ResponseBuilder): """ - Bedrock Response Builder. This builds the response dict to be returned by the lambda when using Bedrock Agents. + Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents. Since the payload format is different from the standard API Gateway Proxy event, we override the build method. """ From 73c0295f54e228c4efeb2ffdafe2da57f85a6c8b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:25:20 +0100 Subject: [PATCH 09/13] Update aws_lambda_powertools/event_handler/api_gateway.py Co-authored-by: Heitor Lessa Signed-off-by: Ruben Fonseca --- aws_lambda_powertools/event_handler/api_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1677b0989c4..8f9598ed667 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -728,7 +728,7 @@ def _has_compression_enabled( A boolean indicating whether compression is enabled or not in the route setting. response_compression: bool, optional A boolean indicating whether compression is enabled or not in the response setting. - event: Generic[ResponseEventT] + event: ResponseEventT The event object containing the request details. Returns From a9a220eb7cde538c178d0bd5142cda2f0907cd62 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:42:27 +0100 Subject: [PATCH 10/13] fix: add crash test --- .../event_handler/api_gateway.py | 1 + .../event_handler/bedrock_agent.py | 2 +- .../event_handler/test_bedrock_agent.py | 31 +++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 8f9598ed667..1e494fd1c0f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -241,6 +241,7 @@ def __init__( self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} self.cookies = cookies or [] self.compress = compress + self.content_type = content_type if content_type: self.headers.setdefault("Content-Type", content_type) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index ad5930ba94c..81805c398ff 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -30,7 +30,7 @@ def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]: "httpMethod": event.http_method, "httpStatusCode": self.response.status_code, "responseBody": { - "application/json": { + self.response.content_type: { "body": self.response.body, }, }, diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index de946542367..c85781b2442 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -77,3 +77,34 @@ def claims(): assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" assert result["response"]["httpMethod"] == "GET" assert result["response"]["httpStatusCode"] == 404 + + +def test_bedrock_agent_event_with_exception(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.exception_handler(RuntimeError) + def handle_runtime_error(ex: RuntimeError): + return Response( + status_code=500, + content_type=content_types.TEXT_PLAIN, + body="Something went wrong", + ) + + @app.get("/claims") + def claims(): + raise RuntimeError() + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process the exception correctly + # AND return 500 because of the internal server error + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 500 + + body = result["response"]["responseBody"]["text/plain"]["body"] + assert body == "Something went wrong" From ca6dd11991db614832ff69a736508d310a36b77f Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:44:55 +0100 Subject: [PATCH 11/13] chore: add validation error test --- .../event_handler/test_bedrock_agent.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index c85781b2442..ff919bacbb2 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -79,6 +79,29 @@ def claims(): assert result["response"]["httpStatusCode"] == 404 +def test_bedrock_agent_event_with_validation_error(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims") + def claims() -> Dict[str, Any]: + return "oh no, this is not a dict" # type: ignore + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 422 + + body = result["response"]["responseBody"]["application/json"]["body"] + assert "value is not a valid dict" in body + + def test_bedrock_agent_event_with_exception(): # GIVEN a Bedrock Agent event app = BedrockAgentResolver() From bd6087296010941d710b20549fe01439807d22c3 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 7 Nov 2023 17:46:27 +0100 Subject: [PATCH 12/13] fix: remove unused logger --- aws_lambda_powertools/event_handler/bedrock_agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 81805c398ff..258fc7dcaee 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,14 +1,14 @@ -import logging from typing import Any, Dict from typing_extensions import override from aws_lambda_powertools.event_handler import ApiGatewayResolver -from aws_lambda_powertools.event_handler.api_gateway import ProxyEventType, ResponseBuilder +from aws_lambda_powertools.event_handler.api_gateway import ( + ProxyEventType, + ResponseBuilder, +) from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent -logger = logging.getLogger(__name__) - class BedrockResponseBuilder(ResponseBuilder): """ From fd7c974e2dc24581f5e7cc19d2beab9b77595b95 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Wed, 8 Nov 2023 09:18:23 +0100 Subject: [PATCH 13/13] fix: tests --- tests/functional/event_handler/test_bedrock_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index ff919bacbb2..dcdca460d25 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -2,6 +2,7 @@ from typing import Any, Dict from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent from tests.functional.utils import load_event @@ -99,7 +100,10 @@ def claims() -> Dict[str, Any]: assert result["response"]["httpStatusCode"] == 422 body = result["response"]["responseBody"]["application/json"]["body"] - assert "value is not a valid dict" in body + if PYDANTIC_V2: + assert "should be a valid dictionary" in body + else: + assert "value is not a valid dict" in body def test_bedrock_agent_event_with_exception():