From 905067897003b6ac1c5dd190deaf8f0ae9beb33e Mon Sep 17 00:00:00 2001 From: Joway Date: Sun, 30 Sep 2018 14:16:54 +0800 Subject: [PATCH] bugfix : router function repeat exec (#24) --- lemon/app.py | 84 +++++++--------------- lemon/config.py | 1 + lemon/context.py | 4 +- lemon/exception.py | 58 +++++++++------ lemon/middleware.py | 6 +- lemon/parsers.py | 14 ++-- lemon/response.py | 1 - lemon/router.py | 150 +++++++++++++++++---------------------- requirements.txt | 2 +- setup.py | 4 +- tests/test_cors.py | 3 + tests/test_exception.py | 20 +++--- tests/test_middleware.py | 2 +- tests/test_parser.py | 6 +- tests/test_router.py | 119 +++++++++++++++++++++---------- 15 files changed, 244 insertions(+), 230 deletions(-) diff --git a/lemon/app.py b/lemon/app.py index a0d6f4a..27da287 100644 --- a/lemon/app.py +++ b/lemon/app.py @@ -1,6 +1,4 @@ -import json import logging.config -import traceback import typing from asyncio import get_event_loop from functools import partial @@ -8,9 +6,8 @@ from lemon.asgi import ASGIRequest from lemon.config import settings -from lemon.const import MIME_TYPES from lemon.context import Context -from lemon.exception import MiddlewareParamsError +from lemon.exception import LemonMiddlewareParamsError from lemon.log import LOGGING_CONFIG_DEFAULTS, logger from lemon.middleware import exception_middleware, cors_middleware from lemon.server import serve @@ -27,37 +24,27 @@ async def exec_middleware(ctx: Context, middleware_list: list, pos: int = 0): :param ctx: Context instance :param middleware_list: middleware registered on app - :param pos: position of the middleware in list + :param pos: current pos in middleware_list """ if pos >= len(middleware_list): return middleware = middleware_list[pos] logger.debug( - 'The No.{0} middleware : {1} started'.format( - pos, - middleware.__name__, - ) + 'middleware : %s started', + middleware.__name__, ) - try: - middleware_params = signature(middleware).parameters - if len(middleware_params) == 1: - return await middleware(ctx=ctx) - elif len(middleware_params) == 2: - return await middleware( - ctx=ctx, - nxt=partial(exec_middleware, ctx, middleware_list, pos + 1), - ) - else: - raise MiddlewareParamsError - finally: - logger.debug( - 'The No.{0} middleware : {1} finished'.format( - pos, - middleware.__name__, - ) + middleware_params = signature(middleware).parameters + if len(middleware_params) == 1: + await middleware(ctx=ctx) + elif len(middleware_params) == 2: + return await middleware( + ctx=ctx, + nxt=partial(exec_middleware, ctx, middleware_list, pos + 1), ) + else: + raise LemonMiddlewareParamsError class Lemon: @@ -66,8 +53,7 @@ def __init__(self, config: dict = None, debug=False) -> None: :param config: app config :param debug: if debug == True , set log level to DEBUG , else is INFO """ - self.config = config - settings.set_config(config=config) + settings.set_config(config=config or {}) self.middleware_list: list = [] @@ -103,35 +89,19 @@ async def _call(receive: typing.Callable, send: typing.Callable): + self.middleware_list \ + self.post_process_middleware_list - try: - await exec_middleware( - ctx=ctx, middleware_list=middleware_chain - ) - except Exception as e: - traceback.print_exc() - await send({ - 'type': 'http.response.start', - 'status': 500, - 'headers': [ - ['content-type', MIME_TYPES.APPLICATION_JSON, ] - ], - }) - await send({ - 'type': 'http.response.body', - 'body': json.dumps({ - 'lemon': 'Internal Error', - }).encode(), - }) - else: - await send({ - 'type': 'http.response.start', - 'status': ctx.res.status, - 'headers': ctx.res.raw_headers, - }) - await send({ - 'type': 'http.response.body', - 'body': ctx.res.raw_body, - }) + await exec_middleware( + ctx=ctx, middleware_list=middleware_chain, + ) + + await send({ + 'type': 'http.response.start', + 'status': ctx.res.status, + 'headers': ctx.res.raw_headers, + }) + await send({ + 'type': 'http.response.body', + 'body': ctx.res.raw_body, + }) return _call diff --git a/lemon/config.py b/lemon/config.py index 1cc9c56..dd51422 100644 --- a/lemon/config.py +++ b/lemon/config.py @@ -35,6 +35,7 @@ def __getitem__(self, key: str): # SERVER 'LEMON_SERVER_HOST': '127.0.0.1', 'LEMON_SERVER_PORT': '9999', + 'LEMON_DEBUG': False, # ROUTER 'LEMON_ROUTER_SLASH_SENSITIVE': False, # CORS diff --git a/lemon/context.py b/lemon/context.py index 767de9b..87d49f8 100644 --- a/lemon/context.py +++ b/lemon/context.py @@ -21,7 +21,7 @@ def __setattr__(self, key, value) -> None: # alias if key == 'body': self.res.body = value - if key == 'status': + elif key == 'status': self.res.status = value else: self.__dict__[key] = value @@ -30,7 +30,7 @@ def __getattr__(self, item) -> typing.Any: # alias if item == 'body': return self.res.body - if item == 'status': + elif item == 'status': return self.res.status return self.__dict__[item] diff --git a/lemon/exception.py b/lemon/exception.py index 71cc397..96f6685 100644 --- a/lemon/exception.py +++ b/lemon/exception.py @@ -1,4 +1,3 @@ -# ========== GeneralException ========== import typing @@ -8,43 +7,62 @@ def __init__(self, status=None, body: typing.Union[str, dict] = None) -> None: self.body = body -# ========== RuntimeError - 500 ========== -class ServerError(GeneralException): - def __init__(self, *args, **kwargs) -> None: - super(ServerError, self).__init__(*args, **kwargs) - self.status = 500 +# ========== RequestException 4xx ========== +class RequestBadError(GeneralException): + def __init__(self): + super().__init__(status=400, body={ + 'error': 'bad request' + }) -class MiddlewareParamsError(ServerError): - pass +class RequestUnauthorizedError(GeneralException): + def __init__(self): + super().__init__(status=401, body={ + 'error': 'unauthorized' + }) -class RouterRegisterError(ServerError): - pass +class RequestForbiddenError(GeneralException): + def __init__(self): + super().__init__(status=403, body={ + 'error': 'not found' + }) -class RouterMatchError(ServerError): +class RequestNotFoundError(GeneralException): + def __init__(self): + super().__init__(status=404, body={ + 'error': 'not found' + }) + + +class RequestHeadersParserError(RequestBadError): pass -class ResponseFormatError(ServerError): +class RequestBodyParserError(RequestBadError): pass -class LemonConfigKeyError(ServerError): +# ========== RuntimeError - 5xx ========== +class ServerError(GeneralException): + def __init__(self): + super().__init__(status=500, body={ + 'error': 'internal error' + }) + + +class LemonMiddlewareParamsError(ServerError): pass -# ========== BadRequestError - 400 ========== -class BadRequestError(GeneralException): - def __init__(self, *args, **kwargs) -> None: - super(BadRequestError, self).__init__(*args, **kwargs) - self.status = 400 +class LemonRouterRegisterError(ServerError): + pass -class RequestHeadersParserError(BadRequestError): +class LemonRouterMatchError(ServerError): pass -class RequestBodyParserError(BadRequestError): +class LemonConfigKeyError(ServerError): pass diff --git a/lemon/middleware.py b/lemon/middleware.py index 75486a3..5c90537 100644 --- a/lemon/middleware.py +++ b/lemon/middleware.py @@ -14,12 +14,14 @@ async def exception_middleware(ctx: Context, nxt: typing.Callable) -> typing.Any except GeneralException as e: ctx.body = e.body ctx.status = e.status + if settings.LEMON_DEBUG: + traceback.print_exc() except Exception as e: - traceback.print_exc() ctx.status = 500 ctx.body = ctx.body or { - 'lemon': 'INTERNAL ERROR', + 'error': 'unknown error', } + traceback.print_exc() async def cors_middleware(ctx: Context, nxt: typing.Callable): diff --git a/lemon/parsers.py b/lemon/parsers.py index 4741fef..80db29b 100644 --- a/lemon/parsers.py +++ b/lemon/parsers.py @@ -7,8 +7,8 @@ from werkzeug.http import parse_options_header from werkzeug.urls import url_decode -import lemon.exception as exception from lemon.const import MIME_TYPES +from lemon.exception import RequestHeadersParserError, RequestBodyParserError def get_mimetype_and_options(headers: dict) -> typing.Tuple[str, dict]: @@ -20,27 +20,23 @@ def get_mimetype_and_options(headers: dict) -> typing.Tuple[str, dict]: def get_content_length(headers: dict) -> typing.Optional[int]: if headers is None: - raise exception.RequestHeadersParserError + raise RequestHeadersParserError content_length = headers.get('content-length') if content_length is None: return None try: return max(0, int(content_length)) except (ValueError, TypeError): - raise exception.RequestHeadersParserError + raise RequestHeadersParserError def json_parser(body: bytes, *args) -> ImmutableMultiDict: if not body: - raise exception.BadRequestError(body={ - 'error': 'Empty Body', - }) + raise RequestBodyParserError try: return ImmutableMultiDict(json.loads(body.decode('utf-8'))) except json.JSONDecodeError: - raise exception.BadRequestError(body={ - 'error': 'Invalid JSON', - }) + raise RequestBodyParserError def url_encoded_parser(body: bytes, *args) -> dict: diff --git a/lemon/response.py b/lemon/response.py index caa91aa..661617d 100644 --- a/lemon/response.py +++ b/lemon/response.py @@ -1,7 +1,6 @@ import json from lemon.const import MIME_TYPES, CHARSETS -from lemon.exception import ResponseFormatError from lemon.request import HttpHeaders diff --git a/lemon/router.py b/lemon/router.py index 40d6df0..7237717 100644 --- a/lemon/router.py +++ b/lemon/router.py @@ -1,12 +1,12 @@ import typing from abc import ABCMeta, abstractmethod -from inspect import signature import kua +from lemon.app import exec_middleware from lemon.config import settings from lemon.const import HTTP_METHODS -from lemon.exception import RouterRegisterError, RouterMatchError +from lemon.exception import LemonRouterRegisterError, RequestNotFoundError _HTTP_METHODS = [ HTTP_METHODS.GET, @@ -17,6 +17,12 @@ ] +def _clean_slash(path: str): + if path and path[-1] == '/': + path = path[:-1] + return path + + class AbstractRouter(metaclass=ABCMeta): @abstractmethod def use(self, methods: list, path: str, *middleware_list) -> None: @@ -33,6 +39,12 @@ def routes(self) -> typing.Callable: """ raise NotImplementedError + @abstractmethod + def match(self, ctx) -> typing.List: + """Return route + """ + raise NotImplementedError + class AbstractBaseRouter(AbstractRouter, metaclass=ABCMeta): def get(self, path: str, *middleware_list) -> None: @@ -75,8 +87,27 @@ def all(self, path: str, *middleware_list) -> None: HTTP_METHODS.DELETE, ], path, *middleware_list) + def routes(self) -> typing.Callable: + """Generate async router function(ctx, nxt) + """ + + async def _routes(ctx, nxt=None) -> None: + method = ctx.req.method + path = ctx.req.path + middleware_list = self.match(ctx=ctx) + + if len(middleware_list) == 0: + raise RequestNotFoundError -class SimpleRouter(AbstractBaseRouter): + await exec_middleware(ctx, middleware_list) + + if nxt: + await nxt() + + return _routes + + +class SimpleRouter(AbstractBaseRouter, metaclass=ABCMeta): def __init__(self, slash=settings.LEMON_ROUTER_SLASH_SENSITIVE) -> None: self.slash = slash self._routes: dict = { @@ -86,6 +117,18 @@ def __init__(self, slash=settings.LEMON_ROUTER_SLASH_SENSITIVE) -> None: HTTP_METHODS.DELETE: {}, } + def match(self, ctx) -> typing.List: + method = ctx.req.method + path = ctx.req.path + + if not self.slash: + path = _clean_slash(path) + + if path not in self._routes[method]: + raise RequestNotFoundError + + return self._routes[method][path] + def use(self, methods: list, path: str, *middleware_list) -> None: """Register routes :param methods: GET|PUT|POST|DELETE @@ -94,40 +137,10 @@ def use(self, methods: list, path: str, *middleware_list) -> None: """ for method in methods: if method not in _HTTP_METHODS: - raise RouterRegisterError( - 'Cannot support method : {0}'.format(method) - ) - if not self.slash and path[-1] == '/': - path = path[:-1] - self._routes[method][path] = middleware_list - - def routes(self) -> typing.Callable: - """Generate async router function(ctx, nxt) - """ - - async def _routes(ctx, nxt) -> None: - method = ctx.req.method - path = ctx.req.path - - if not self.slash and path[-1] == '/': - path = path[:-1] - - if path not in self._routes[method]: - ctx.status = 404 - ctx.body = { - 'lemon': 'NOT FOUND' - } - return - - middleware_list = self._routes[method][path] - for middleware in middleware_list: - middleware_params = signature(middleware).parameters - if len(middleware_params) == 1: - await middleware(ctx) - else: - await middleware(ctx, nxt) - - return _routes + raise LemonRouterRegisterError + if not self.slash: + path = _clean_slash(path) + self._routes[method][path] = list(middleware_list) class Router(AbstractBaseRouter): @@ -148,61 +161,32 @@ def use(self, methods: list, path: str, *middleware_list) -> None: """ for method in methods: if method not in _HTTP_METHODS: - raise RouterRegisterError( - 'Cannot support method : {0}'.format(method) - ) + raise LemonRouterRegisterError self._register_middleware_list(method, path, *middleware_list) - def routes(self) -> typing.Callable: - """Generate async router function(ctx, nxt) - """ + def match(self, ctx) -> typing.Any: + method = ctx.req.method + path = ctx.req.path - async def _routes(ctx, nxt) -> None: - method = ctx.req.method - path = ctx.req.path - route = self._match_middleware_list(method=method, path=path) - - if route is None: - ctx.status = 404 - ctx.body = { - 'lemon': 'NOT FOUND' - } - return + if not self.slash: + path = _clean_slash(path) + if method not in _HTTP_METHODS: + raise RequestNotFoundError + try: + route = self._routes[method].match(path) ctx.params = route.params - for middleware in route.anything: - middleware_params = signature(middleware).parameters - if len(middleware_params) == 1: - await middleware(ctx) - else: - await middleware(ctx, nxt) - - return _routes + return route.anything + except kua.RouteError: + raise RequestNotFoundError def _register_middleware_list( self, method: str, path: str, *middleware_list ) -> None: - if not self.slash and path[-1] == '/': - path = path[:-1] + if not self.slash: + path = _clean_slash(path) if method not in _HTTP_METHODS: - raise RouterMatchError( - 'Method {0} is not supported'.format(method) - ) - - return self._routes[method].add(path, middleware_list) - - def _match_middleware_list(self, method: str, path: str) -> typing.Any: - if not self.slash and path[-1] == '/': - path = path[:-1] + raise LemonRouterRegisterError - if method not in _HTTP_METHODS: - raise RouterMatchError( - 'Method {0} is not supported'.format(method) - ) - try: - return self._routes[method].match(path) - except kua.RouteError: - return None - except KeyError: - return None + return self._routes[method].add(path, list(middleware_list)) diff --git a/requirements.txt b/requirements.txt index e5e657d..b42259a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # production -uvicorn>=0.1.1 +uvicorn>=0.3.6 kua>=0.2 werkzeug>=0.14.1 diff --git a/setup.py b/setup.py index 3dff25c..9948aaa 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,8 @@ from setuptools import setup -PACKAGE_VERSION = '0.2.0' +PACKAGE_VERSION = '0.2.2' PACKAGE_REQUIRES = [ - 'uvicorn==0.1.1', + 'uvicorn==0.3.6', 'kua==0.2', 'werkzeug==0.14.1', ] diff --git a/tests/test_cors.py b/tests/test_cors.py index 8390bcd..f7e03ea 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -152,6 +152,7 @@ async def handle(ctx: Context): assert req.headers['access-control-allow-headers'] == 'allow_header' assert req.headers['access-control-allow-credentials'] == 'true' assert req.headers['access-control-max-age'] == '8640' + assert req.status == 204 req = await self.asgi_request( app, @@ -164,6 +165,7 @@ async def handle(ctx: Context): assert req.headers['access-control-allow-origin'] == 'http://a.com' assert req.headers['access-control-allow-credentials'] == 'true' assert req.headers['access-control-expose-headers'] == 'test_header' + assert req.status == 200 async def test_cors_not_allowed_request(self): async def handle(ctx: Context): @@ -190,3 +192,4 @@ async def handle(ctx: Context): ] ) assert 'access-control-allow-origin' not in req.headers + assert req.status == 200 diff --git a/tests/test_exception.py b/tests/test_exception.py index 16e2e0f..d4cf8dd 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -1,7 +1,7 @@ from lemon.exception import ( - MiddlewareParamsError, - RouterRegisterError, - RouterMatchError, + LemonMiddlewareParamsError, + LemonRouterRegisterError, + LemonRouterMatchError, RequestBodyParserError, GeneralException ) @@ -18,23 +18,23 @@ def test_exception(self): assert e.body == 'err' try: - raise MiddlewareParamsError + raise LemonMiddlewareParamsError except GeneralException as e: assert e.status == 500 try: - raise MiddlewareParamsError - except MiddlewareParamsError as e: + raise LemonMiddlewareParamsError + except LemonMiddlewareParamsError as e: assert e.status == 500 try: - raise RouterRegisterError - except RouterRegisterError as e: + raise LemonRouterRegisterError + except LemonRouterRegisterError as e: assert e.status == 500 try: - raise RouterMatchError - except RouterMatchError as e: + raise LemonRouterMatchError + except LemonRouterMatchError as e: assert e.status == 500 try: diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e4e1f49..6ee65c0 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -14,4 +14,4 @@ async def handle(ctx: Context): req = await self.get('/') data = req.json() assert req.status_code == 500 - assert data['lemon'] == 'INTERNAL ERROR' + assert data['error'] == 'unknown error' diff --git a/tests/test_parser.py b/tests/test_parser.py index e807219..9c889b2 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,4 +1,4 @@ -from lemon.exception import RequestHeadersParserError, BadRequestError +from lemon.exception import RequestHeadersParserError, RequestBadError from lemon.parsers import get_content_length, json_parser, url_encoded_parser from tests import BasicHttpTestCase @@ -23,13 +23,13 @@ def test_json_parser(self): try: json_parser(None) assert False - except BadRequestError: + except RequestBadError: pass try: json_parser(b'{') assert False - except BadRequestError: + except RequestBadError: pass def test_url_encoded_parser(self): diff --git a/tests/test_router.py b/tests/test_router.py index c512e08..2ed16d3 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -73,9 +73,9 @@ async def handler2(ctx): req = await self.get('/app/xxx/') data = req.json() assert req.status_code == 404 - assert data['lemon'] == 'NOT FOUND' + assert data['error'] == 'not found' - async def test_router_example(self): + async def test_router(self): async def middleware(ctx, nxt): ctx.body = {'m': 'mid'} await nxt() @@ -112,51 +112,92 @@ async def handler2(ctx): req = await self.get('/app/xxx/') data = req.json() assert req.status_code == 404 - assert data['lemon'] == 'NOT FOUND' + assert data['error'] == 'not found' - async def test_router_register(self): - async def middleware(ctx, nxt): - ctx.body = {'m': 2} - await nxt - - async def handler(ctx): - ctx.body = {'x': 1} + async def test_router_restful(self): + async def handler1(ctx): + assert ctx.params['id'] == 'idxxx' + assert ctx.params['username'] == 'namexxx' + ctx.body = { + 'msg': 'ok' + } router = Router() - router._register_middleware_list( - 'GET', '/res/action', middleware, handler - ) + router.get('/app/:id/:username', handler1) - route = router._match_middleware_list('GET', '/res') - assert route is None - - route = router._match_middleware_list('GET', '/res/action/') - assert route is not None - assert len(route.anything) == 2 + self.app.use(router.routes()) + req = await self.get('/app/idxxx/namexxx') + data = req.json() + assert req.status_code == 200 + assert data['msg'] == 'ok' + + async def test_router_exec_order(self): + global before_count + global after_count + global handler_count + global orders + + before_count = 0 + after_count = 0 + handler_count = 0 + orders = [] + + async def before(ctx, nxt): + ctx.body = {'count': 1} + global before_count + global orders + orders.append(0) + before_count += 1 + await nxt() - route = router._match_middleware_list('GET', '/res/action') - assert route is not None - assert len(route.anything) == 2 + async def handler1(ctx, nxt): + ctx.body['count'] += 1 + global handler_count + global orders + orders.append(1) + handler_count += 1 + await nxt() - async def test_rest_router_register(self): - async def handler(ctx): - ctx.body = {'x': 1} + async def handler2(ctx, nxt): + ctx.body['count'] += 1 + global handler_count + global orders + orders.append(2) + handler_count += 1 + await nxt() - router = Router() - router._register_middleware_list('GET', '/res/:id/action', handler) + async def handler3(ctx, nxt): + ctx.body['count'] += 1 + global handler_count + global orders + orders.append(3) + handler_count += 1 + await nxt() - route = router._match_middleware_list('GET', '/res/xxx/action') - assert route is not None - assert route.params['id'] == 'xxx' - assert len(route.anything) == 1 + async def after(ctx): + global after_count + global orders + orders.append(4) + after_count += 1 + ctx.body['count'] += 1 - route = router._match_middleware_list('GET', '/res/:id/action') - assert route is not None - assert route.params['id'] == ':id' - assert len(route.anything) == 1 + router = Router() + router.get('/', handler1, handler2, handler3) - route = router._match_middleware_list('GET', '/re/:id/action') - assert route is None + self.app.use(before) + self.app.use(router.routes()) + self.app.use(after) - route = router._match_middleware_list('GET', '/res/:id/actions') - assert route is None + req = await self.get('/') + data = req.json() + assert req.status_code == 200 + assert data['count'] == 5 + assert before_count == 1 + assert after_count == 1 + assert handler_count == 3 + + assert orders[0] == 0 + assert orders[1] == 1 + assert orders[2] == 2 + assert orders[3] == 3 + assert orders[4] == 4