Skip to content

Commit

Permalink
refactor: improve http response abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
seedspirit committed Jan 31, 2025
1 parent 576ed8d commit 9519d29
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 61 deletions.
46 changes: 31 additions & 15 deletions src/ai/backend/common/pydantic_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,22 @@ def from_request(cls, request: web.Request) -> Self:
pass


class BaseResponseModel(BaseModel):
pass


@dataclass
class BaseResponse:
data: BaseModel
status_code: int
class ApiResponse:
_status_code: int
_data: Optional[BaseResponseModel]

@classmethod
def build(cls, status_code: int, response_model: BaseResponseModel) -> Self:
return cls(_status_code=status_code, _data=response_model)

@classmethod
def no_content(cls, status_code: int):
return cls(_status_code=status_code, _data=None)


_ParamType = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam
Expand Down Expand Up @@ -201,12 +213,15 @@ async def _pydantic_handler(request: web.Request, handler, signature) -> web.Res

response = await handler(**handler_params.get_all())

if not isinstance(response, BaseResponse):
if not isinstance(response, ApiResponse):
raise InvalidAPIParameters(
f"Only Response wrapped by BaseResponse Class can be handle: {type(response)}"
f"Only Response wrapped by ApiResponse Class can be handle: {type(response)}"
)

return web.json_response(response.data.model_dump(mode="json"), status=response.status_code)
return web.json_response(
response._data.model_dump(mode="json") if response._data else {},
status=response._status_code,
)


def pydantic_api_handler(handler):
Expand All @@ -217,25 +232,26 @@ def pydantic_api_handler(handler):
@pydantic_api_handler
async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model
user = body.parsed # 'parsed' property gets pydantic model you defined
return BaseResponse(status_code=200, data=YourResponseModel(user=user.id))
# Response model should inherit BaseResponseModel
return ApiResponse.build(status_code=200, response_model=YourResponseModel(user=user.id))
2. Query Parameters:
@pydantic_api_handler
async def handler(query: QueryParam[QueryPathModel]):
parsed_query = query.parsed
return BaseResponse(status_code=200, data=YourResponseModel(search=parsed_query.query))
return ApiResponse.build(status_code=200, response_model=YourResponseModel(search=parsed_query.query))
3. Headers:
@pydantic_api_handler
async def handler(headers: HeaderParam[HeaderModel]):
parsed_header = headers.parsed
return BaseResponse(status_code=200, data=YourResponseModel(data=parsed_header.token))
return ApiResponse.build(status_code=200, response_model=YourResponseModel(data=parsed_header.token))
4. Path Parameters:
@pydantic_api_handler
async def handler(path: PathModel = PathParam(PathModel)):
parsed_path = path.parsed
return BaseResponse(status_code=200, data=YourResponseModel(path=parsed_path))
return ApiResponse.build(status_code=200, response_model=YourResponseModel(path=parsed_path))
5. Middleware Parameters:
# Need to extend MiddlewareParam and implement 'from_request'
Expand All @@ -251,7 +267,7 @@ def from_request(cls, request: web.Request) -> Self:
@pydantic_api_handler
async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed'
return BaseResponse(status_code=200, data=YourResponseModel(author_name=auth.name))
return ApiResponse(status_code=200, response_model=YourResponseModel(author_name=auth.name))
6. Multiple Parameters:
@pydantic_api_handler
Expand All @@ -261,9 +277,9 @@ async def handler(
headers: HeaderParam[HeaderModel], # headers
auth: AuthMiddleware, # middleware parameter
):
return BaseResponse(
return ApiResponse(
status_code=200,
data=YourResponseModel(
response_model=YourResponseModel(
user=user.parsed.user_id,
query=query.parsed.page,
headers=headers.parsed.auth,
Expand All @@ -273,15 +289,15 @@ async def handler(
Note:
- All parameters must have type hints or wrapped by Annotated
- Response class must be BaseResponse. put your response model in BaseResponse.data
- Response class must be ApiResponse and your response model should inherit BaseResponseModel
- Request body is parsed must be json format
- MiddlewareParam classes must implement the from_request classmethod
"""

original_signature = inspect.signature(handler)

@functools.wraps(handler)
async def wrapped(request: web.Request, *args, **kwargs) -> web.Response:
async def wrapped(request: Any, *args, **kwargs) -> web.Response:
if isinstance(request, web.Request):
return await _pydantic_handler(request, handler, original_signature)

Expand Down
105 changes: 59 additions & 46 deletions tests/common/test_pydantic_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pydantic import BaseModel, Field

from ai.backend.common.pydantic_handlers import (
BaseResponse,
ApiResponse,
BaseResponseModel,
BodyParam,
HeaderParam,
MiddlewareParam,
Expand All @@ -15,16 +16,17 @@
)


class TestEmptyResponseModel(BaseModel):
class TestEmptyResponseModel(BaseResponseModel):
status: str
version: str


class TestEmptyHandlerClass:
@pydantic_api_handler
async def handle_empty(self) -> BaseResponse:
return BaseResponse(
status_code=200, data=TestEmptyResponseModel(status="success", version="1.0.0")
async def handle_empty(self) -> ApiResponse:
return ApiResponse.build(
status_code=200,
response_model=TestEmptyResponseModel(status="success", version="1.0.0"),
)


Expand Down Expand Up @@ -55,7 +57,7 @@ class TestSearchParamsModel(BaseModel):
limit: Optional[int] = Field(default=10)


class TestCombinedResponseModel(BaseModel):
class TestCombinedResponseModel(BaseResponseModel):
user_info: dict
search_info: dict
timestamp: str
Expand All @@ -65,13 +67,13 @@ class CombinedParamsHandlerClass:
@pydantic_api_handler
async def handle_combined(
self, user: BodyParam[TestUserRequestModel], search: QueryParam[TestSearchParamsModel]
) -> BaseResponse:
) -> ApiResponse:
parsed_user = user.parsed
parsed_search = search.parsed

return BaseResponse(
return ApiResponse.build(
status_code=200,
data=TestCombinedResponseModel(
response_model=TestCombinedResponseModel(
user_info={
"username": parsed_user.username,
"email": parsed_user.email,
Expand Down Expand Up @@ -113,15 +115,17 @@ async def test_combined_parameters_handler_in_class(aiohttp_client):
assert data["search_info"]["limit"] == 20


class TestMessageResponse(BaseModel):
class TestMessageResponse(BaseResponseModel):
message: str


@pytest.mark.asyncio
async def test_empty_parameter(aiohttp_client):
@pydantic_api_handler
async def handler() -> BaseResponse:
return BaseResponse(status_code=200, data=TestMessageResponse(message="test"))
async def handler() -> ApiResponse:
return ApiResponse.build(
status_code=200, response_model=TestMessageResponse(message="test")
)

app = web.Application()
app.router.add_route("GET", "/test", handler)
Expand All @@ -140,18 +144,19 @@ class TestPostUserModel(BaseModel):
age: int


class TestPostUserResponse(BaseModel):
class TestPostUserResponse(BaseResponseModel):
name: str
age: int


@pytest.mark.asyncio
async def test_body_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(user: BodyParam[TestPostUserModel]) -> BaseResponse:
async def handler(user: BodyParam[TestPostUserModel]) -> ApiResponse:
parsed_user = user.parsed
return BaseResponse(
status_code=200, data=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age)
return ApiResponse.build(
status_code=200,
response_model=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age),
)

app = web.Application()
Expand All @@ -173,19 +178,21 @@ class TestSearchQueryModel(BaseModel):
page: Optional[int] = Field(default=1)


class TestSearchQueryResponse(BaseModel):
class TestSearchQueryResponse(BaseResponseModel):
search: str
page: Optional[int] = Field(default=1)


@pytest.mark.asyncio
async def test_query_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(query: QueryParam[TestSearchQueryModel]) -> BaseResponse:
async def handler(query: QueryParam[TestSearchQueryModel]) -> ApiResponse:
parsed_query = query.parsed
return BaseResponse(
return ApiResponse.build(
status_code=200,
data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page),
response_model=TestSearchQueryResponse(
search=parsed_query.search, page=parsed_query.page
),
)

app = web.Application()
Expand All @@ -204,17 +211,18 @@ class TestAuthHeaderModel(BaseModel):
authorization: str


class TestAuthHeaderResponse(BaseModel):
class TestAuthHeaderResponse(BaseResponseModel):
authorization: str


@pytest.mark.asyncio
async def test_header_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> BaseResponse:
async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> ApiResponse:
parsed_headers = headers.parsed
return BaseResponse(
status_code=200, data=TestAuthHeaderResponse(authorization=parsed_headers.authorization)
return ApiResponse.build(
status_code=200,
response_model=TestAuthHeaderResponse(authorization=parsed_headers.authorization),
)

app = web.Application()
Expand All @@ -233,16 +241,18 @@ class TestUserPathModel(BaseModel):
user_id: str


class TestUserPathResponse(BaseModel):
class TestUserPathResponse(BaseResponseModel):
user_id: str


@pytest.mark.asyncio
async def test_path_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(path: PathParam[TestUserPathModel]) -> BaseResponse:
async def handler(path: PathParam[TestUserPathModel]) -> ApiResponse:
parsed_path = path.parsed
return BaseResponse(status_code=200, data=TestUserPathResponse(user_id=parsed_path.user_id))
return ApiResponse.build(
status_code=200, response_model=TestUserPathResponse(user_id=parsed_path.user_id)
)

app = web.Application()
app.router.add_get("/test/{user_id}", handler)
Expand All @@ -263,16 +273,16 @@ def from_request(cls, request: web.Request) -> Self:
return cls(is_authorized=request.get("is_authorized", False))


class TestAuthResponse(BaseModel):
class TestAuthResponse(BaseResponseModel):
is_authorized: bool = Field(default=False)


@pytest.mark.asyncio
async def test_middleware_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(auth: TestAuthInfo) -> BaseResponse:
return BaseResponse(
status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized)
async def handler(auth: TestAuthInfo) -> ApiResponse:
return ApiResponse.build(
status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized)
)

@web.middleware
Expand All @@ -295,9 +305,9 @@ async def auth_middleware(request, handler):
@pytest.mark.asyncio
async def test_middleware_parameter_invalid_type(aiohttp_client):
@pydantic_api_handler
async def handler(auth: TestAuthInfo) -> BaseResponse:
return BaseResponse(
status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized)
async def handler(auth: TestAuthInfo) -> ApiResponse:
return ApiResponse.build(
status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized)
)

@web.middleware
Expand Down Expand Up @@ -334,7 +344,7 @@ class TestSearchParamModel(BaseModel):
query: str


class TestCombinedResponse(BaseModel):
class TestCombinedResponse(BaseResponseModel):
user_name: str
query: str
is_authorized: bool
Expand All @@ -347,13 +357,13 @@ async def handler(
body: BodyParam[TestCreateUserModel],
auth: TestMiddlewareModel,
query: QueryParam[TestSearchParamModel],
) -> BaseResponse:
) -> ApiResponse:
parsed_body = body.parsed
parsed_query = query.parsed

return BaseResponse(
return ApiResponse.build(
status_code=200,
data=TestCombinedResponse(
response_model=TestCombinedResponse(
user_name=parsed_body.user_name,
query=parsed_query.query,
is_authorized=auth.is_authorized,
Expand Down Expand Up @@ -385,18 +395,19 @@ class TestRegisterUserModel(BaseModel):
age: int


class TestRegisterUserResponse(BaseModel):
class TestRegisterUserResponse(BaseResponseModel):
name: str
age: int


@pytest.mark.asyncio
async def test_invalid_body(aiohttp_client):
@pydantic_api_handler
async def handler(user: BodyParam[TestRegisterUserModel]) -> BaseResponse:
async def handler(user: BodyParam[TestRegisterUserModel]) -> ApiResponse:
test_user = user.parsed
return BaseResponse(
status_code=200, data=TestRegisterUserResponse(name=test_user.name, age=test_user.age)
return ApiResponse.build(
status_code=200,
response_model=TestRegisterUserResponse(name=test_user.name, age=test_user.age),
)

app = web.Application()
Expand All @@ -413,19 +424,21 @@ class TestProductSearchModel(BaseModel):
page: Optional[int] = Field(default=1)


class TestProductSearchResponse(BaseModel):
class TestProductSearchResponse(BaseResponseModel):
search: str
page: Optional[int] = Field(default=1)


@pytest.mark.asyncio
async def test_invalid_query_parameter(aiohttp_client):
@pydantic_api_handler
async def handler(query: QueryParam[TestProductSearchModel]) -> BaseResponse:
async def handler(query: QueryParam[TestProductSearchModel]) -> ApiResponse:
parsed_query = query.parsed
return BaseResponse(
return ApiResponse.build(
status_code=200,
data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page),
response_model=TestProductSearchResponse(
search=parsed_query.search, page=parsed_query.page
),
)

app = web.Application()
Expand Down

0 comments on commit 9519d29

Please sign in to comment.