diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 515c6a491c..2ed279a53b 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -112,10 +112,22 @@ def from_request(cls, request: web.Request) -> Self: pass +class BaseResponseModel(BaseModel): + pass + + @dataclass -class BaseResponse: - data: BaseModel - status_code: int +class ApiResponse: + _status_code: int + _data: Optional[BaseResponseModel] + + @classmethod + def build(cls, status_code: int, response_model: BaseResponseModel) -> Self: + return cls(_status_code=status_code, _data=response_model) + + @classmethod + def no_content(cls, status_code: int): + return cls(_status_code=status_code, _data=None) _ParamType = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam @@ -201,12 +213,15 @@ async def _pydantic_handler(request: web.Request, handler, signature) -> web.Res response = await handler(**handler_params.get_all()) - if not isinstance(response, BaseResponse): + if not isinstance(response, ApiResponse): raise InvalidAPIParameters( - f"Only Response wrapped by BaseResponse Class can be handle: {type(response)}" + f"Only Response wrapped by ApiResponse Class can be handle: {type(response)}" ) - return web.json_response(response.data.model_dump(mode="json"), status=response.status_code) + return web.json_response( + response._data.model_dump(mode="json") if response._data else {}, + status=response._status_code, + ) def pydantic_api_handler(handler): @@ -217,25 +232,26 @@ def pydantic_api_handler(handler): @pydantic_api_handler async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model user = body.parsed # 'parsed' property gets pydantic model you defined - return BaseResponse(status_code=200, data=YourResponseModel(user=user.id)) + # Response model should inherit BaseResponseModel + return ApiResponse.build(status_code=200, response_model=YourResponseModel(user=user.id)) 2. Query Parameters: @pydantic_api_handler async def handler(query: QueryParam[QueryPathModel]): parsed_query = query.parsed - return BaseResponse(status_code=200, data=YourResponseModel(search=parsed_query.query)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(search=parsed_query.query)) 3. Headers: @pydantic_api_handler async def handler(headers: HeaderParam[HeaderModel]): parsed_header = headers.parsed - return BaseResponse(status_code=200, data=YourResponseModel(data=parsed_header.token)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(data=parsed_header.token)) 4. Path Parameters: @pydantic_api_handler async def handler(path: PathModel = PathParam(PathModel)): parsed_path = path.parsed - return BaseResponse(status_code=200, data=YourResponseModel(path=parsed_path)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(path=parsed_path)) 5. Middleware Parameters: # Need to extend MiddlewareParam and implement 'from_request' @@ -251,7 +267,7 @@ def from_request(cls, request: web.Request) -> Self: @pydantic_api_handler async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' - return BaseResponse(status_code=200, data=YourResponseModel(author_name=auth.name)) + return ApiResponse(status_code=200, response_model=YourResponseModel(author_name=auth.name)) 6. Multiple Parameters: @pydantic_api_handler @@ -261,9 +277,9 @@ async def handler( headers: HeaderParam[HeaderModel], # headers auth: AuthMiddleware, # middleware parameter ): - return BaseResponse( + return ApiResponse( status_code=200, - data=YourResponseModel( + response_model=YourResponseModel( user=user.parsed.user_id, query=query.parsed.page, headers=headers.parsed.auth, @@ -273,7 +289,7 @@ async def handler( Note: - All parameters must have type hints or wrapped by Annotated - - Response class must be BaseResponse. put your response model in BaseResponse.data + - Response class must be ApiResponse and your response model should inherit BaseResponseModel - Request body is parsed must be json format - MiddlewareParam classes must implement the from_request classmethod """ @@ -281,7 +297,7 @@ async def handler( original_signature = inspect.signature(handler) @functools.wraps(handler) - async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: + async def wrapped(request: Any, *args, **kwargs) -> web.Response: if isinstance(request, web.Request): return await _pydantic_handler(request, handler, original_signature) diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py index b2280278fc..12c61e2dcb 100644 --- a/tests/common/test_pydantic_handlers.py +++ b/tests/common/test_pydantic_handlers.py @@ -5,7 +5,8 @@ from pydantic import BaseModel, Field from ai.backend.common.pydantic_handlers import ( - BaseResponse, + ApiResponse, + BaseResponseModel, BodyParam, HeaderParam, MiddlewareParam, @@ -15,16 +16,17 @@ ) -class TestEmptyResponseModel(BaseModel): +class TestEmptyResponseModel(BaseResponseModel): status: str version: str class TestEmptyHandlerClass: @pydantic_api_handler - async def handle_empty(self) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestEmptyResponseModel(status="success", version="1.0.0") + async def handle_empty(self) -> ApiResponse: + return ApiResponse.build( + status_code=200, + response_model=TestEmptyResponseModel(status="success", version="1.0.0"), ) @@ -55,7 +57,7 @@ class TestSearchParamsModel(BaseModel): limit: Optional[int] = Field(default=10) -class TestCombinedResponseModel(BaseModel): +class TestCombinedResponseModel(BaseResponseModel): user_info: dict search_info: dict timestamp: str @@ -65,13 +67,13 @@ class CombinedParamsHandlerClass: @pydantic_api_handler async def handle_combined( self, user: BodyParam[TestUserRequestModel], search: QueryParam[TestSearchParamsModel] - ) -> BaseResponse: + ) -> ApiResponse: parsed_user = user.parsed parsed_search = search.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestCombinedResponseModel( + response_model=TestCombinedResponseModel( user_info={ "username": parsed_user.username, "email": parsed_user.email, @@ -113,15 +115,17 @@ async def test_combined_parameters_handler_in_class(aiohttp_client): assert data["search_info"]["limit"] == 20 -class TestMessageResponse(BaseModel): +class TestMessageResponse(BaseResponseModel): message: str @pytest.mark.asyncio async def test_empty_parameter(aiohttp_client): @pydantic_api_handler - async def handler() -> BaseResponse: - return BaseResponse(status_code=200, data=TestMessageResponse(message="test")) + async def handler() -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestMessageResponse(message="test") + ) app = web.Application() app.router.add_route("GET", "/test", handler) @@ -140,7 +144,7 @@ class TestPostUserModel(BaseModel): age: int -class TestPostUserResponse(BaseModel): +class TestPostUserResponse(BaseResponseModel): name: str age: int @@ -148,10 +152,11 @@ class TestPostUserResponse(BaseModel): @pytest.mark.asyncio async def test_body_parameter(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[TestPostUserModel]) -> BaseResponse: + async def handler(user: BodyParam[TestPostUserModel]) -> ApiResponse: parsed_user = user.parsed - return BaseResponse( - status_code=200, data=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age) + return ApiResponse.build( + status_code=200, + response_model=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age), ) app = web.Application() @@ -173,7 +178,7 @@ class TestSearchQueryModel(BaseModel): page: Optional[int] = Field(default=1) -class TestSearchQueryResponse(BaseModel): +class TestSearchQueryResponse(BaseResponseModel): search: str page: Optional[int] = Field(default=1) @@ -181,11 +186,13 @@ class TestSearchQueryResponse(BaseModel): @pytest.mark.asyncio async def test_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[TestSearchQueryModel]) -> BaseResponse: + async def handler(query: QueryParam[TestSearchQueryModel]) -> ApiResponse: parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page), + response_model=TestSearchQueryResponse( + search=parsed_query.search, page=parsed_query.page + ), ) app = web.Application() @@ -204,17 +211,18 @@ class TestAuthHeaderModel(BaseModel): authorization: str -class TestAuthHeaderResponse(BaseModel): +class TestAuthHeaderResponse(BaseResponseModel): authorization: str @pytest.mark.asyncio async def test_header_parameter(aiohttp_client): @pydantic_api_handler - async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> BaseResponse: + async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> ApiResponse: parsed_headers = headers.parsed - return BaseResponse( - status_code=200, data=TestAuthHeaderResponse(authorization=parsed_headers.authorization) + return ApiResponse.build( + status_code=200, + response_model=TestAuthHeaderResponse(authorization=parsed_headers.authorization), ) app = web.Application() @@ -233,16 +241,18 @@ class TestUserPathModel(BaseModel): user_id: str -class TestUserPathResponse(BaseModel): +class TestUserPathResponse(BaseResponseModel): user_id: str @pytest.mark.asyncio async def test_path_parameter(aiohttp_client): @pydantic_api_handler - async def handler(path: PathParam[TestUserPathModel]) -> BaseResponse: + async def handler(path: PathParam[TestUserPathModel]) -> ApiResponse: parsed_path = path.parsed - return BaseResponse(status_code=200, data=TestUserPathResponse(user_id=parsed_path.user_id)) + return ApiResponse.build( + status_code=200, response_model=TestUserPathResponse(user_id=parsed_path.user_id) + ) app = web.Application() app.router.add_get("/test/{user_id}", handler) @@ -263,16 +273,16 @@ def from_request(cls, request: web.Request) -> Self: return cls(is_authorized=request.get("is_authorized", False)) -class TestAuthResponse(BaseModel): +class TestAuthResponse(BaseResponseModel): is_authorized: bool = Field(default=False) @pytest.mark.asyncio async def test_middleware_parameter(aiohttp_client): @pydantic_api_handler - async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + async def handler(auth: TestAuthInfo) -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) @web.middleware @@ -295,9 +305,9 @@ async def auth_middleware(request, handler): @pytest.mark.asyncio async def test_middleware_parameter_invalid_type(aiohttp_client): @pydantic_api_handler - async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + async def handler(auth: TestAuthInfo) -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) @web.middleware @@ -334,7 +344,7 @@ class TestSearchParamModel(BaseModel): query: str -class TestCombinedResponse(BaseModel): +class TestCombinedResponse(BaseResponseModel): user_name: str query: str is_authorized: bool @@ -347,13 +357,13 @@ async def handler( body: BodyParam[TestCreateUserModel], auth: TestMiddlewareModel, query: QueryParam[TestSearchParamModel], - ) -> BaseResponse: + ) -> ApiResponse: parsed_body = body.parsed parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestCombinedResponse( + response_model=TestCombinedResponse( user_name=parsed_body.user_name, query=parsed_query.query, is_authorized=auth.is_authorized, @@ -385,7 +395,7 @@ class TestRegisterUserModel(BaseModel): age: int -class TestRegisterUserResponse(BaseModel): +class TestRegisterUserResponse(BaseResponseModel): name: str age: int @@ -393,10 +403,11 @@ class TestRegisterUserResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_body(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[TestRegisterUserModel]) -> BaseResponse: + async def handler(user: BodyParam[TestRegisterUserModel]) -> ApiResponse: test_user = user.parsed - return BaseResponse( - status_code=200, data=TestRegisterUserResponse(name=test_user.name, age=test_user.age) + return ApiResponse.build( + status_code=200, + response_model=TestRegisterUserResponse(name=test_user.name, age=test_user.age), ) app = web.Application() @@ -413,7 +424,7 @@ class TestProductSearchModel(BaseModel): page: Optional[int] = Field(default=1) -class TestProductSearchResponse(BaseModel): +class TestProductSearchResponse(BaseResponseModel): search: str page: Optional[int] = Field(default=1) @@ -421,11 +432,13 @@ class TestProductSearchResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[TestProductSearchModel]) -> BaseResponse: + async def handler(query: QueryParam[TestProductSearchModel]) -> ApiResponse: parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page), + response_model=TestProductSearchResponse( + search=parsed_query.search, page=parsed_query.page + ), ) app = web.Application()