Skip to content

fix(event-handler): enable path parameters on Bedrock handler #3312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,8 @@ def _resolve(self) -> ResponseBuilder:
# Add matched Route reference into the Resolver context
self.append_context(_route=route, _path=path)

return self._call_route(route, match_results.groupdict()) # pass fn args
route_keys = self._convert_matches_into_route_keys(match_results)
return self._call_route(route, route_keys) # pass fn args

logger.debug(f"No match found for path {path} and method {method}")
return self._not_found(method)
Expand Down Expand Up @@ -1858,6 +1859,10 @@ def _remove_prefix(self, path: str) -> str:

return path

def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]:
"""Converts the regex match into a dict of route keys"""
return match.groupdict()

@staticmethod
def _path_starts_with(path: str, prefix: str):
"""Returns true if the `path` starts with a prefix plus a `/`"""
Expand Down
10 changes: 10 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import Match
from typing import Any, Dict

from typing_extensions import override
Expand Down Expand Up @@ -75,3 +76,12 @@ def __init__(self, debug: bool = False, enable_validation: bool = True):
enable_validation=enable_validation,
)
self._response_builder_class = BedrockResponseBuilder

@override
def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]:
# In Bedrock Agents, all the parameters come inside the "parameters" key, not on the apiPath
# So we have to search for route parameters in the parameters key
parameters: Dict[str, str] = {}
if match.groupdict() and self.current_event.parameters:
parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
return parameters
2 changes: 1 addition & 1 deletion aws_lambda_powertools/event_handler/openapi/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DEFAULT_API_VERSION = "1.0.0"
DEFAULT_OPENAPI_VERSION = "3.1.0"
DEFAULT_OPENAPI_VERSION = "3.0.0"
18 changes: 10 additions & 8 deletions aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Dependant,
File,
Form,
Header,
Param,
ParamTypes,
Query,
_File,
_Form,
_Header,
analyze_param,
create_response_field,
get_flat_dependant,
Expand Down Expand Up @@ -235,7 +235,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
return False
elif is_scalar_field(field=param_field):
return False
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
return False
else:
if not isinstance(param_field.field_info, Body):
Expand Down Expand Up @@ -326,10 +326,12 @@ def get_body_field_info(
if not required:
body_field_info_kwargs["default"] = None

if any(isinstance(f.field_info, File) for f in flat_dependant.body_params):
body_field_info: Type[Body] = File
elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params):
body_field_info = Form
if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
# MAINTENANCE: body_field_info: Type[Body] = _File
raise NotImplementedError("_File fields are not supported in request bodies")
elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params):
# MAINTENANCE: body_field_info: Type[Body] = _Form
raise NotImplementedError("_Form fields are not supported in request bodies")
else:
body_field_info = Body

Expand Down
6 changes: 3 additions & 3 deletions aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def __init__(
)


class Header(Param):
class _Header(Param):
"""
A class used internally to represent a header parameter in a path operation.
"""
Expand Down Expand Up @@ -471,7 +471,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"


class Form(Body):
class _Form(Body):
"""
A class used internally to represent a form parameter in a path operation.
"""
Expand Down Expand Up @@ -543,7 +543,7 @@ def __init__(
)


class File(Form):
class _File(_Form):
"""
A class used internally to represent a file parameter in a path operation.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,13 @@ def session_attributes(self) -> Dict[str, str]:
def prompt_session_attributes(self) -> Dict[str, str]:
return self["promptSessionAttributes"]

# For compatibility with BaseProxyEvent
# The following methods add compatibility with BaseProxyEvent
@property
def path(self) -> str:
return self["apiPath"]

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
# In Bedrock Agent events, query string parameters are passed as undifferentiated parameters,
# together with the other parameters. So we just return all parameters here.
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None
23 changes: 23 additions & 0 deletions tests/events/bedrockAgentEventWithPathParams.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"actionGroup": "ClaimManagementActionGroup",
"messageVersion": "1.0",
"sessionId": "12345678912345",
"sessionAttributes": {},
"promptSessionAttributes": {},
"inputText": "I want to claim my insurance",
"agent": {
"alias": "TSTALIASID",
"name": "test",
"version": "DRAFT",
"id": "8ZXY0W8P1H"
},
"parameters": [
{
"type": "string",
"name": "claim_id",
"value": "123"
}
],
"httpMethod": "GET",
"apiPath": "/claims/<claim_id>"
}
22 changes: 22 additions & 0 deletions tests/functional/event_handler/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,28 @@ def claims() -> Dict[str, Any]:
assert body == json.dumps({"output": claims_response})


def test_bedrock_agent_with_path_params():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()

@app.get("/claims/<claim_id>")
def claims(claim_id: str):
assert isinstance(app.current_event, BedrockAgentEvent)
assert app.lambda_context == {}
assert claim_id == "123"

# WHEN calling the event handler
result = app(load_event("bedrockAgentEventWithPathParams.json"), {})

# THEN process event correctly
# AND set the current_event type as BedrockAgentEvent
assert result["messageVersion"] == "1.0"
assert result["response"]["apiPath"] == "/claims/<claim_id>"
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
assert result["response"]["httpMethod"] == "GET"
assert result["response"]["httpStatusCode"] == 200


def test_bedrock_agent_event_with_response():
# GIVEN a Bedrock Agent event
app = BedrockAgentResolver()
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Header,
Param,
ParamTypes,
Query,
_create_model_field,
_Header,
)
from aws_lambda_powertools.shared.types import Annotated

Expand Down Expand Up @@ -375,7 +375,7 @@ def secret():


def test_create_header():
header = Header(convert_underscores=True)
header = _Header(convert_underscores=True)
assert header.convert_underscores is True


Expand All @@ -400,7 +400,7 @@ def test_create_model_field_with_empty_in():

# Tests that when we try to create a model field with convert_underscore, we convert the field name
def test_create_model_field_convert_underscore():
field_info = Header(alias=None, convert_underscores=True)
field_info = _Header(alias=None, convert_underscores=True)

result = _create_model_field(field_info, int, "user_id", False)
assert result.alias == "user-id"