diff --git a/.gitignore b/.gitignore index 990f6517fe9..9f244805f60 100644 --- a/.gitignore +++ b/.gitignore @@ -314,4 +314,7 @@ examples/**/sam/.aws-sam cdk.out # NOTE: different accounts will be used for E2E thus creating unnecessary git clutter -cdk.context.json \ No newline at end of file +cdk.context.json + +# vim +*.swp diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index ffcef8b5096..747f48686c4 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -932,7 +932,8 @@ def analyze_param( ModelField | None The type annotation and the Pydantic field representing the parameter """ - field_info, type_annotation = get_field_info_and_type_annotation(annotation, value, is_path_param) + field_info, type_annotation = \ + get_field_info_and_type_annotation(annotation, value, is_path_param, is_response_param) # If the value is a FieldInfo, we use it as the FieldInfo for the parameter if isinstance(value, FieldInfo): @@ -962,7 +963,9 @@ def analyze_param( return field -def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]: +def get_field_info_and_type_annotation( + annotation, value, is_path_param: bool, is_response_param: bool +) -> tuple[FieldInfo | None, Any]: """ Get the FieldInfo and type annotation from an annotation and value. """ @@ -976,6 +979,10 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - # If the annotation is a Response type, we recursively call this function with the inner type elif get_origin(annotation) is Response: field_info, type_annotation = get_field_info_response_type(annotation, value) + # If the annotation is a tuple with two elements, we use the first element as the type annotation, + # just like we did in the APIGateway._to_response + elif is_response_param and get_origin(annotation) is tuple and len(get_args(annotation)) == 2: + field_info, type_annotation = get_field_info_tuple_type(annotation, value) # If the annotation is not an Annotated type, we use it as the type annotation else: type_annotation = annotation @@ -983,12 +990,22 @@ def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) - return field_info, type_annotation +def get_field_info_tuple_type(annotation, value) -> tuple[FieldInfo | None, Any]: + (inner_type, _) = get_args(annotation) + + # If the inner type is an Annotated type, we need to extract the type annotation and the FieldInfo + if get_origin(inner_type) is Annotated: + return get_field_info_annotated_type(inner_type, value, False) + + return None, inner_type + + def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, Any]: # Example: get_args(Response[inner_type]) == (inner_type,) # noqa: ERA001 (inner_type,) = get_args(annotation) # Recursively resolve the inner type - return get_field_info_and_type_annotation(inner_type, value, False) + return get_field_info_and_type_annotation(inner_type, value, False, True) def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]: diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index a57156db130..3c83d3eda72 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -172,6 +172,42 @@ def handler() -> Response[Annotated[str, Body(title="Response title")]]: assert response.schema_.type == "string" +def test_openapi_with_tuple_returns(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler() -> tuple[str, int]: + return "Hello, world", 200 + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + assert response.schema_.title == "Return" + assert response.schema_.type == "string" + + +def test_openapi_with_tuple_annotated_returns(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler() -> tuple[Annotated[str, Body(title="Response title")], int]: + return "Hello, world", 200 + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + assert response.schema_.title == "Response title" + assert response.schema_.type == "string" + + def test_openapi_with_omitted_param(): app = APIGatewayRestResolver()