diff --git a/.codeclimate.yml b/.codeclimate.yml index 506005acbe..08c10e26c2 100644 --- a/.codeclimate.yml +++ b/.codeclimate.yml @@ -10,3 +10,15 @@ exclude_patterns: - "examples/" - "hack/" - "scripts/" + - "tests/" +checks: + argument-count: + enabled: false + file-lines: + config: + threshold: 1000 + method-count: + config: + threshold: 40 + complex-logic: + enabled: false diff --git a/.github/workflows/pr-windows.yml b/.github/workflows/pr-windows.yml index a47a500628..e3a32e5d87 100644 --- a/.github/workflows/pr-windows.yml +++ b/.github/workflows/pr-windows.yml @@ -1,34 +1,34 @@ -# name: Run Unit Tests on Windows -# on: -# pull_request: -# branches: -# - main +name: Run Unit Tests on Windows +on: + pull_request: + branches: + - main -# jobs: -# testsOnWindows: -# name: ut-${{ matrix.config.tox-env }} -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# config: -# - { python-version: 3.7, tox-env: py37-no-ext } -# - { python-version: 3.8, tox-env: py38-no-ext } -# - { python-version: 3.9, tox-env: py39-no-ext } -# - { python-version: pypy-3.7, tox-env: pypy37-no-ext } +jobs: + testsOnWindows: + name: ut-${{ matrix.config.tox-env }} + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + config: + - { python-version: 3.7, tox-env: py37-no-ext } + - { python-version: 3.8, tox-env: py38-no-ext } + - { python-version: 3.9, tox-env: py39-no-ext } + - { python-version: pypy-3.7, tox-env: pypy37-no-ext } -# steps: -# - name: Checkout Repository -# uses: actions/checkout@v2 + steps: + - name: Checkout Repository + uses: actions/checkout@v2 -# - name: Run Unit Tests -# uses: ahopkins/custom-actions@pip-extra-args -# with: -# python-version: ${{ matrix.config.python-version }} -# test-infra-tool: tox -# test-infra-version: latest -# action: tests -# test-additional-args: "-e=${{ matrix.config.tox-env }}" -# experimental-ignore-error: "true" -# command-timeout: "600000" -# pip-extra-args: "--user" + - name: Run Unit Tests + uses: ahopkins/custom-actions@pip-extra-args + with: + python-version: ${{ matrix.config.python-version }} + test-infra-tool: tox + test-infra-version: latest + action: tests + test-additional-args: "-e=${{ matrix.config.tox-env }}" + experimental-ignore-error: "true" + command-timeout: "600000" + pip-extra-args: "--user" diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5e99452b14..a9940da15f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,22 @@ +.. note:: + + From v21.9, CHANGELOG files are maintained in ``./docs/sanic/releases`` + +Version 21.6.1 +-------------- + +Bugfixes +******** + + * `#2178 `_ + Update sanic-routing to allow for better splitting of complex URI templates + * `#2183 `_ + Proper handling of chunked request bodies to resolve phantom 503 in logs + * `#2181 `_ + Resolve regression in exception logging + * `#2201 `_ + Cleanup request info in pipelined requests + Version 21.6.0 -------------- diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index c87f2355bc..74dee22f6e 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -19,7 +19,7 @@ a virtual environment already set up, then run: .. code-block:: bash - pip3 install -e . ".[dev]" + pip install -e ".[dev]" Dependency Changes ------------------ diff --git a/README.rst b/README.rst index c3623bb87d..c6616f16c5 100644 --- a/README.rst +++ b/README.rst @@ -77,17 +77,7 @@ The goal of the project is to provide a simple way to get up and running a highl Sponsor ------- -|Try CodeStream| - -.. |Try CodeStream| image:: https://alt-images.codestream.com/codestream_logo_sanicorg.png - :target: https://codestream.com/?utm_source=github&utm_campaign=sanicorg&utm_medium=banner - :alt: Try CodeStream - -Manage pull requests and conduct code reviews in your IDE with full source-tree context. Comment on any line, not just the diffs. Use jump-to-definition, your favorite keybindings, and code intelligence with more of your workflow. - -`Learn More `_ - -Thank you to our sponsor. Check out `open collective `_ to learn more about helping to fund Sanic. +Check out `open collective `_ to learn more about helping to fund Sanic. Installation ------------ diff --git a/docs/conf.py b/docs/conf.py index 62f6ae4aa2..30a01e4c85 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,10 +10,8 @@ import os import sys -# Add support for auto-doc -import recommonmark -from recommonmark.transform import AutoStructify +# Add support for auto-doc # Ensure that sanic is present in the path, to allow sphinx-apidoc to @@ -26,7 +24,7 @@ # -- General configuration ------------------------------------------------ -extensions = ["sphinx.ext.autodoc", "recommonmark"] +extensions = ["sphinx.ext.autodoc", "m2r2"] templates_path = ["_templates"] @@ -162,20 +160,6 @@ "member-order": "groupwise", } - -# app setup hook -def setup(app): - app.add_config_value( - "recommonmark_config", - { - "enable_eval_rst": True, - "enable_auto_doc_ref": False, - }, - True, - ) - app.add_transform(AutoStructify) - - html_theme_options = { "style_external_links": False, } diff --git a/docs/sanic/changelog.rst b/docs/sanic/changelog.rst index fb389e054e..516b858762 100644 --- a/docs/sanic/changelog.rst +++ b/docs/sanic/changelog.rst @@ -1,4 +1,6 @@ 📜 Changelog ============ +.. mdinclude:: ./releases/21.9.md + .. include:: ../../CHANGELOG.rst diff --git a/docs/sanic/contributing.rst b/docs/sanic/contributing.rst index 91cfd11e8b..5d21caa23a 100644 --- a/docs/sanic/contributing.rst +++ b/docs/sanic/contributing.rst @@ -1,4 +1,4 @@ ♥️ Contributing -============== +=============== .. include:: ../../CONTRIBUTING.rst diff --git a/docs/sanic/releases/21.9.md b/docs/sanic/releases/21.9.md new file mode 100644 index 0000000000..8900340d5c --- /dev/null +++ b/docs/sanic/releases/21.9.md @@ -0,0 +1,40 @@ +## Version 21.9 + +### Features +- [#2158](https://github.com/sanic-org/sanic/pull/2158), [#2248](https://github.com/sanic-org/sanic/pull/2248) Complete overhaul of I/O to websockets +- [#2160](https://github.com/sanic-org/sanic/pull/2160) Add new 17 signals into server and request lifecycles +- [#2162](https://github.com/sanic-org/sanic/pull/2162) Smarter `auto` fallback formatting upon exception +- [#2184](https://github.com/sanic-org/sanic/pull/2184) Introduce implementation for copying a Blueprint +- [#2200](https://github.com/sanic-org/sanic/pull/2200) Accept header parsing +- [#2207](https://github.com/sanic-org/sanic/pull/2207) Log remote address if available +- [#2209](https://github.com/sanic-org/sanic/pull/2209) Add convenience methods to BP groups +- [#2216](https://github.com/sanic-org/sanic/pull/2216) Add default messages to SanicExceptions +- [#2225](https://github.com/sanic-org/sanic/pull/2225) Type annotation convenience for annotated handlers with path parameters +- [#2236](https://github.com/sanic-org/sanic/pull/2236) Allow Falsey (but not-None) responses from route handlers +- [#2238](https://github.com/sanic-org/sanic/pull/2238) Add `exception` decorator to Blueprint Groups +- [#2244](https://github.com/sanic-org/sanic/pull/2244) Explicit static directive for serving file or dir (ex: `static(..., resource_type="file")`) +- [#2245](https://github.com/sanic-org/sanic/pull/2245) Close HTTP loop when connection task cancelled + +### Bugfixes +- [#2188](https://github.com/sanic-org/sanic/pull/2188) Fix the handling of the end of a chunked request +- [#2195](https://github.com/sanic-org/sanic/pull/2195) Resolve unexpected error handling on static requests +- [#2208](https://github.com/sanic-org/sanic/pull/2208) Make blueprint-based exceptions attach and trigger in a more intuitive manner +- [#2211](https://github.com/sanic-org/sanic/pull/2211) Fixed for handling exceptions of asgi app call +- [#2213](https://github.com/sanic-org/sanic/pull/2213) Fix bug where ws exceptions not being logged +- [#2231](https://github.com/sanic-org/sanic/pull/2231) Cleaner closing of tasks by using `abort()` in strategic places to avoid dangling sockets +- [#2247](https://github.com/sanic-org/sanic/pull/2247) Fix logging of auto-reload status in debug mode +- [#2246](https://github.com/sanic-org/sanic/pull/2246) Account for BP with exception handler but no routes + +### Developer infrastructure +- [#2194](https://github.com/sanic-org/sanic/pull/2194) HTTP unit tests with raw client +- [#2199](https://github.com/sanic-org/sanic/pull/2199) Switch to codeclimate +- [#2214](https://github.com/sanic-org/sanic/pull/2214) Try Reopening Windows Tests +- [#2229](https://github.com/sanic-org/sanic/pull/2229) Refactor `HttpProtocol` into a base class +- [#2230](https://github.com/sanic-org/sanic/pull/2230) Refactor `server.py` into multi-file module + +### Miscellaneous +- [#2173](https://github.com/sanic-org/sanic/pull/2173) Remove Duplicated Dependencies and PEP 517 Support +- [#2193](https://github.com/sanic-org/sanic/pull/2193), [#2196](https://github.com/sanic-org/sanic/pull/2196), [#2217](https://github.com/sanic-org/sanic/pull/2217) Type annotation changes + + + diff --git a/examples/run_async_advanced.py b/examples/run_async_advanced.py index 36027c2f08..27f86f3f64 100644 --- a/examples/run_async_advanced.py +++ b/examples/run_async_advanced.py @@ -1,29 +1,44 @@ -from sanic import Sanic -from sanic import response -from signal import signal, SIGINT import asyncio + +from signal import SIGINT, signal + import uvloop +from sanic import Sanic, response +from sanic.server import AsyncioServer + + app = Sanic(__name__) -@app.listener('after_server_start') + +@app.listener("after_server_start") async def after_start_test(app, loop): print("Async Server Started!") + @app.route("/") async def test(request): return response.json({"answer": "42"}) + asyncio.set_event_loop(uvloop.new_event_loop()) -serv_coro = app.create_server(host="0.0.0.0", port=8000, return_asyncio_server=True) +serv_coro = app.create_server( + host="0.0.0.0", port=8000, return_asyncio_server=True +) loop = asyncio.get_event_loop() serv_task = asyncio.ensure_future(serv_coro, loop=loop) signal(SIGINT, lambda s, f: loop.stop()) -server = loop.run_until_complete(serv_task) +server: AsyncioServer = loop.run_until_complete(serv_task) # type: ignore +server.startup() + +# When using app.run(), this actually triggers before the serv_coro. +# But, in this example, we are using the convenience method, even if it is +# out of order. +server.before_start() server.after_start() try: loop.run_forever() -except KeyboardInterrupt as e: +except KeyboardInterrupt: loop.stop() finally: server.before_stop() diff --git a/examples/websocket.py b/examples/websocket.py index 9cba083cfc..92f713756b 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,13 +1,14 @@ from sanic import Sanic -from sanic.response import file +from sanic.response import redirect app = Sanic(__name__) -@app.route('/') -async def index(request): - return await file('websocket.html') +app.static('index.html', "websocket.html") +@app.route('/') +def index(request): + return redirect("index.html") @app.websocket('/feed') async def feed(request, ws): diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..9787c3bdf0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/sanic/__version__.py b/sanic/__version__.py index 74a495e2fe..325664388b 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "21.6.2" +__version__ = "21.9.0" diff --git a/sanic/app.py b/sanic/app.py index ec9027b512..01aa07cb0b 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import logging import logging.config import os import re from asyncio import ( + AbstractEventLoop, CancelledError, Protocol, ensure_future, @@ -21,6 +24,7 @@ from types import SimpleNamespace from typing import ( Any, + AnyStr, Awaitable, Callable, Coroutine, @@ -30,6 +34,7 @@ List, Optional, Set, + Tuple, Type, Union, ) @@ -69,20 +74,29 @@ from sanic.server import AsyncioServer, HttpProtocol from sanic.server import Signal as ServerSignal from sanic.server import serve, serve_multiple, serve_single +from sanic.server.protocols.websocket_protocol import WebSocketProtocol +from sanic.server.websockets.impl import ConnectionClosed from sanic.signals import Signal, SignalRouter -from sanic.websocket import ConnectionClosed, WebSocketProtocol +from sanic.touchup import TouchUp, TouchUpMeta -class Sanic(BaseSanic): +class Sanic(BaseSanic, metaclass=TouchUpMeta): """ The main application instance """ + __touchup__ = ( + "handle_request", + "handle_exception", + "_run_response_middleware", + "_run_request_middleware", + ) __fake_slots__ = ( "_asgi_app", "_app_registry", "_asgi_client", "_blueprint_order", + "_delayed_tasks", "_future_routes", "_future_statics", "_future_middleware", @@ -137,7 +151,7 @@ def __init__( log_config: Optional[Dict[str, Any]] = None, configure_logging: bool = True, register: Optional[bool] = None, - dumps: Optional[Callable[..., str]] = None, + dumps: Optional[Callable[..., AnyStr]] = None, ) -> None: super().__init__(name=name) @@ -153,6 +167,7 @@ def __init__( self._asgi_client = None self._blueprint_order: List[Blueprint] = [] + self._delayed_tasks: List[str] = [] self._test_client = None self._test_manager = None self.asgi = False @@ -164,7 +179,9 @@ def __init__( self.configure_logging = configure_logging self.ctx = ctx or SimpleNamespace() self.debug = None - self.error_handler = error_handler or ErrorHandler() + self.error_handler = error_handler or ErrorHandler( + fallback=self.config.FALLBACK_ERROR_FORMAT, + ) self.is_running = False self.is_stopping = False self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) @@ -190,9 +207,10 @@ def __init__( self.__class__.register_app(self) self.router.ctx.app = self + self.signal_router.ctx.app = self if dumps: - BaseHTTPResponse._dumps = dumps + BaseHTTPResponse._dumps = dumps # type: ignore @property def loop(self): @@ -230,9 +248,12 @@ def add_task(self, task) -> None: loop = self.loop # Will raise SanicError if loop is not started self._loop_add_task(task, self, loop) except SanicException: - self.listener("before_server_start")( - partial(self._loop_add_task, task) - ) + task_name = f"sanic.delayed_task.{hash(task)}" + if not self._delayed_tasks: + self.after_server_start(partial(self.dispatch_delayed_tasks)) + + self.signal(task_name)(partial(self.run_delayed_task, task=task)) + self._delayed_tasks.append(task_name) def register_listener(self, listener: Callable, event: str) -> Any: """ @@ -244,12 +265,20 @@ def register_listener(self, listener: Callable, event: str) -> Any: """ try: - _event = ListenerEvent(event) - except ValueError: - valid = ", ".join(ListenerEvent.__members__.values()) + _event = ListenerEvent[event.upper()] + except (ValueError, AttributeError): + valid = ", ".join( + map(lambda x: x.lower(), ListenerEvent.__members__.keys()) + ) raise InvalidUsage(f"Invalid event: {event}. Use one of: {valid}") - self.listeners[_event].append(listener) + if "." in _event: + self.signal(_event.value)( + partial(self._listener, listener=listener) + ) + else: + self.listeners[_event.value].append(listener) + return listener def register_middleware(self, middleware, attach_to: str = "request"): @@ -308,7 +337,11 @@ def register_named_middleware( self.named_response_middleware[_rn].appendleft(middleware) return middleware - def _apply_exception_handler(self, handler: FutureException): + def _apply_exception_handler( + self, + handler: FutureException, + route_names: Optional[List[str]] = None, + ): """Decorate a function to be registered as a handler for exceptions :param exceptions: exceptions @@ -318,9 +351,9 @@ def _apply_exception_handler(self, handler: FutureException): for exception in handler.exceptions: if isinstance(exception, (tuple, list)): for e in exception: - self.error_handler.add(e, handler.handler) + self.error_handler.add(e, handler.handler, route_names) else: - self.error_handler.add(exception, handler.handler) + self.error_handler.add(exception, handler.handler, route_names) return handler.handler def _apply_listener(self, listener: FutureListener): @@ -377,11 +410,17 @@ def dispatch( *, condition: Optional[Dict[str, str]] = None, context: Optional[Dict[str, Any]] = None, + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, ) -> Coroutine[Any, Any, Awaitable[Any]]: return self.signal_router.dispatch( event, context=context, condition=condition, + inline=inline, + reverse=reverse, + fail_not_found=fail_not_found, ) async def event( @@ -411,7 +450,13 @@ def enable_websocket(self, enable=True): self.websocket_enabled = enable - def blueprint(self, blueprint, **options): + def blueprint( + self, + blueprint: Union[ + Blueprint, List[Blueprint], Tuple[Blueprint], BlueprintGroup + ], + **options: Any, + ): """Register a blueprint on the application. :param blueprint: Blueprint object or (list, tuple) thereof @@ -651,7 +696,7 @@ def url_for(self, view_name: str, **kwargs): async def handle_exception( self, request: Request, exception: BaseException - ): + ): # no cov """ A handler that catches specific exceptions and outputs a response. @@ -661,6 +706,12 @@ async def handle_exception( :type exception: BaseException :raises ServerError: response 500 """ + await self.dispatch( + "http.lifecycle.exception", + inline=True, + context={"request": request, "exception": exception}, + ) + # -------------------------------------------- # # Request Middleware # -------------------------------------------- # @@ -707,7 +758,7 @@ async def handle_exception( f"Invalid response type {response!r} (need HTTPResponse)" ) - async def handle_request(self, request: Request): + async def handle_request(self, request: Request): # no cov """Take a request from the HTTP Server and return a response object to be sent back The HTTP Server only expects a response object, so exception handling must be done here @@ -715,10 +766,22 @@ async def handle_request(self, request: Request): :param request: HTTP Request object :return: Nothing """ + await self.dispatch( + "http.lifecycle.handle", + inline=True, + context={"request": request}, + ) + # Define `response` var here to remove warnings about # allocation before assignment below. response = None try: + + await self.dispatch( + "http.routing.before", + inline=True, + context={"request": request}, + ) # Fetch handler from router route, handler, kwargs = self.router.get( request.path, @@ -726,19 +789,29 @@ async def handle_request(self, request: Request): request.headers.getone("host", None), ) - request._match_info = kwargs + request._match_info = {**kwargs} request.route = route + await self.dispatch( + "http.routing.after", + inline=True, + context={ + "request": request, + "route": route, + "kwargs": kwargs, + "handler": handler, + }, + ) + if ( - request.stream.request_body # type: ignore + request.stream + and request.stream.request_body and not route.ctx.ignore_body ): if hasattr(handler, "is_stream"): # Streaming handler: lift the size limit - request.stream.request_max_size = float( # type: ignore - "inf" - ) + request.stream.request_max_size = float("inf") else: # Non-streaming handler: preload body await request.receive_body() @@ -765,17 +838,25 @@ async def handle_request(self, request: Request): ) # Run response handler - response = handler(request, **kwargs) + response = handler(request, **request.match_info) if isawaitable(response): response = await response - if response: + if response is not None: response = await request.respond(response) elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore # Make sure that response is finished / run StreamingHTTP callback if isinstance(response, BaseHTTPResponse): + await self.dispatch( + "http.lifecycle.response", + inline=True, + context={ + "request": request, + "response": response, + }, + ) await response.send(end_stream=True) else: if not hasattr(handler, "is_websocket"): @@ -793,23 +874,11 @@ async def handle_request(self, request: Request): async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs ): - request.app = self - if not getattr(handler, "__blueprintname__", False): - request._name = handler.__name__ - else: - request._name = ( - getattr(handler, "__blueprintname__", "") + handler.__name__ - ) - - pass - if self.asgi: ws = request.transport.get_websocket_connection() await ws.accept(subprotocols) else: protocol = request.transport.get_protocol() - protocol.app = self - ws = await protocol.websocket_handshake(request, subprotocols) # schedule the application handler @@ -817,13 +886,19 @@ async def _websocket_handler( # needs to be cancelled due to the server being stopped fut = ensure_future(handler(request, ws, *args, **kwargs)) self.websocket_tasks.add(fut) + cancelled = False try: await fut + except Exception as e: + self.error_handler.log(request, e) except (CancelledError, ConnectionClosed): - pass + cancelled = True finally: self.websocket_tasks.remove(fut) - await ws.close() + if cancelled: + ws.end_connection(1000) + else: + await ws.close() # -------------------------------------------------------------------- # # Testing @@ -869,7 +944,7 @@ def run( *, debug: bool = False, auto_reload: Optional[bool] = None, - ssl: Union[dict, SSLContext, None] = None, + ssl: Union[Dict[str, str], SSLContext, None] = None, sock: Optional[socket] = None, workers: int = 1, protocol: Optional[Type[Protocol]] = None, @@ -999,7 +1074,7 @@ async def create_server( port: Optional[int] = None, *, debug: bool = False, - ssl: Union[dict, SSLContext, None] = None, + ssl: Union[Dict[str, str], SSLContext, None] = None, sock: Optional[socket] = None, protocol: Type[Protocol] = None, backlog: int = 100, @@ -1071,11 +1146,6 @@ async def create_server( run_async=return_asyncio_server, ) - # Trigger before_start events - await self.trigger_events( - server_settings.get("before_start", []), - server_settings.get("loop"), - ) main_start = server_settings.pop("main_start", None) main_stop = server_settings.pop("main_stop", None) if main_start or main_stop: @@ -1088,17 +1158,9 @@ async def create_server( asyncio_server_kwargs=asyncio_server_kwargs, **server_settings ) - async def trigger_events(self, events, loop): - """Trigger events (functions or async) - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - for event in events: - result = event(loop) - if isawaitable(result): - await result - - async def _run_request_middleware(self, request, request_name=None): + async def _run_request_middleware( + self, request, request_name=None + ): # no cov # The if improves speed. I don't know why named_middleware = self.named_request_middleware.get( request_name, deque() @@ -1111,25 +1173,67 @@ async def _run_request_middleware(self, request, request_name=None): request.request_middleware_started = True for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + response = middleware(request) if isawaitable(response): response = await response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": None, + }, + condition={"attach_to": "request"}, + ) + if response: return response return None async def _run_response_middleware( self, request, response, request_name=None - ): + ): # no cov named_middleware = self.named_response_middleware.get( request_name, deque() ) applicable_middleware = self.response_middleware + named_middleware if applicable_middleware: for middleware in applicable_middleware: + await self.dispatch( + "http.middleware.before", + inline=True, + context={ + "request": request, + "response": response, + }, + condition={"attach_to": "response"}, + ) + _response = middleware(request, response) if isawaitable(_response): _response = await _response + + await self.dispatch( + "http.middleware.after", + inline=True, + context={ + "request": request, + "response": _response if _response else response, + }, + condition={"attach_to": "response"}, + ) + if _response: response = _response if isinstance(response, BaseHTTPResponse): @@ -1155,10 +1259,6 @@ def _helper( ): """Helper function used by `run` and `create_server`.""" - self.listeners["before_server_start"] = [ - self.finalize - ] + self.listeners["before_server_start"] - if isinstance(ssl, dict): # try common aliaseses cert = ssl.get("cert") or ssl.get("certificate") @@ -1195,10 +1295,6 @@ def _helper( # Register start/stop events for event_name, settings_name, reverse in ( - ("before_server_start", "before_start", False), - ("after_server_start", "after_start", False), - ("before_server_stop", "before_stop", True), - ("after_server_stop", "after_stop", True), ("main_process_start", "main_start", False), ("main_process_stop", "main_stop", True), ): @@ -1236,7 +1332,8 @@ def _helper( logger.info(f"Goin' Fast @ {proto}://{host}:{port}") debug_mode = "enabled" if self.debug else "disabled" - logger.debug("Sanic auto-reload: enabled") + reload_mode = "enabled" if auto_reload else "disabled" + logger.debug(f"Sanic auto-reload: {reload_mode}") logger.debug(f"Sanic debug mode: {debug_mode}") return server_settings @@ -1246,20 +1343,44 @@ def _build_endpoint_name(self, *parts): return ".".join(parts) @classmethod - def _loop_add_task(cls, task, app, loop): + def _prep_task(cls, task, app, loop): if callable(task): try: - loop.create_task(task(app)) + task = task(app) except TypeError: - loop.create_task(task()) - else: - loop.create_task(task) + task = task() + + return task + + @classmethod + def _loop_add_task(cls, task, app, loop): + prepped = cls._prep_task(task, app, loop) + loop.create_task(prepped) @classmethod def _cancel_websocket_tasks(cls, app, loop): for task in app.websocket_tasks: task.cancel() + @staticmethod + async def dispatch_delayed_tasks(app, loop): + for name in app._delayed_tasks: + await app.dispatch(name, context={"app": app, "loop": loop}) + app._delayed_tasks.clear() + + @staticmethod + async def run_delayed_task(app, loop, task): + prepped = app._prep_task(task, app, loop) + await prepped + + @staticmethod + async def _listener( + app: Sanic, loop: AbstractEventLoop, listener: ListenerType + ): + maybe_coro = listener(app, loop) + if maybe_coro and isawaitable(maybe_coro): + await maybe_coro + # -------------------------------------------------------------------- # # ASGI # -------------------------------------------------------------------- # @@ -1333,15 +1454,51 @@ def get_app( raise SanicException(f'Sanic app name "{name}" not found.') # -------------------------------------------------------------------- # - # Static methods + # Lifecycle # -------------------------------------------------------------------- # - @staticmethod - async def finalize(app, _): + def finalize(self): + try: + self.router.finalize() + except FinalizationError as e: + if not Sanic.test_mode: + raise e + + def signalize(self): try: - app.router.finalize() - if app.signal_router.routes: - app.signal_router.finalize() # noqa + self.signal_router.finalize() except FinalizationError as e: if not Sanic.test_mode: - raise e # noqa + raise e + + async def _startup(self): + self.signalize() + self.finalize() + TouchUp.run(self) + + async def _server_event( + self, + concern: str, + action: str, + loop: Optional[AbstractEventLoop] = None, + ) -> None: + event = f"server.{concern}.{action}" + if action not in ("before", "after") or concern not in ( + "init", + "shutdown", + ): + raise SanicException(f"Invalid server event: {event}") + logger.debug(f"Triggering server events: {event}") + reverse = concern == "shutdown" + if loop is None: + loop = self.loop + await self.dispatch( + event, + fail_not_found=False, + reverse=reverse, + inline=True, + context={ + "app": self, + "loop": loop, + }, + ) diff --git a/sanic/asgi.py b/sanic/asgi.py index 330ced5a4f..55c18d5cf5 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -1,6 +1,5 @@ import warnings -from inspect import isawaitable from typing import Optional from urllib.parse import quote @@ -11,21 +10,27 @@ from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.server import ConnInfo -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection class Lifespan: def __init__(self, asgi_app: "ASGIApp") -> None: self.asgi_app = asgi_app - if "before_server_start" in self.asgi_app.sanic_app.listeners: + if ( + "server.init.before" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "before_server_start" ' "in ASGI mode. " "It will be executed as early as possible, but not before " "the ASGI server is started." ) - if "after_server_stop" in self.asgi_app.sanic_app.listeners: + if ( + "server.shutdown.after" + in self.asgi_app.sanic_app.signal_router.name_index + ): warnings.warn( 'You have set a listener for "after_server_stop" ' "in ASGI mode. " @@ -42,19 +47,9 @@ async def startup(self) -> None: in sequence since the ASGI lifespan protocol only supports a single startup event. """ - self.asgi_app.sanic_app.router.finalize() - if self.asgi_app.sanic_app.signal_router.routes: - self.asgi_app.sanic_app.signal_router.finalize() - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_start", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_start", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._startup() + await self.asgi_app.sanic_app._server_event("init", "before") + await self.asgi_app.sanic_app._server_event("init", "after") async def shutdown(self) -> None: """ @@ -65,16 +60,8 @@ async def shutdown(self) -> None: in sequence since the ASGI lifespan protocol only supports a single shutdown event. """ - listeners = self.asgi_app.sanic_app.listeners.get( - "before_server_stop", [] - ) + self.asgi_app.sanic_app.listeners.get("after_server_stop", []) - - for handler in listeners: - response = handler( - self.asgi_app.sanic_app, self.asgi_app.sanic_app.loop - ) - if response and isawaitable(response): - await response + await self.asgi_app.sanic_app._server_event("shutdown", "before") + await self.asgi_app.sanic_app._server_event("shutdown", "after") async def __call__( self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend diff --git a/sanic/base.py b/sanic/base.py index ff4833fa94..5d1358d830 100644 --- a/sanic/base.py +++ b/sanic/base.py @@ -58,7 +58,7 @@ def __setattr__(self, name: str, value: Any) -> None: if name not in self.__fake_slots__: warn( f"Setting variables on {self.__class__.__name__} instances is " - "deprecated and will be removed in version 21.9. You should " + "deprecated and will be removed in version 21.12. You should " f"change your {self.__class__.__name__} instance to use " f"instance.ctx.{name} instead.", DeprecationWarning, diff --git a/sanic/blueprint_group.py b/sanic/blueprint_group.py index 45f3089481..8bec376d64 100644 --- a/sanic/blueprint_group.py +++ b/sanic/blueprint_group.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import MutableSequence +from functools import partial from typing import TYPE_CHECKING, List, Optional, Union @@ -196,6 +197,27 @@ def append(self, value: Blueprint) -> None: """ self._blueprints.append(value) + def exception(self, *exceptions, **kwargs): + """ + A decorator that can be used to implement a global exception handler + for all the Blueprints that belong to this Blueprint Group. + + In case of nested Blueprint Groups, the same handler is applied + across each of the Blueprints recursively. + + :param args: List of Python exceptions to be caught by the handler + :param kwargs: Additional optional arguments to be passed to the + exception handler + :return a decorated method to handle global exceptions for any + blueprint registered under this group. + """ + + def register_exception_handler_for_blueprints(fn): + for blueprint in self.blueprints: + blueprint.exception(*exceptions, **kwargs)(fn) + + return register_exception_handler_for_blueprints + def insert(self, index: int, item: Blueprint) -> None: """ The Abstract class `MutableSequence` leverages this insert method to @@ -229,3 +251,15 @@ def register_middleware_for_blueprints(fn): args = list(args)[1:] return register_middleware_for_blueprints(fn) return register_middleware_for_blueprints + + def on_request(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "request") + else: + return partial(self.middleware, attach_to="request") + + def on_response(self, middleware=None): + if callable(middleware): + return self.middleware(middleware, "response") + else: + return partial(self.middleware, attach_to="response") diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 2431f8497a..617ec6060b 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -3,6 +3,7 @@ import asyncio from collections import defaultdict +from copy import deepcopy from types import SimpleNamespace from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Union @@ -12,6 +13,7 @@ from sanic.base import BaseSanic from sanic.blueprint_group import BlueprintGroup from sanic.exceptions import SanicException +from sanic.helpers import Default, _default from sanic.models.futures import FutureRoute, FutureStatic from sanic.models.handler_types import ( ListenerType, @@ -40,7 +42,7 @@ class Blueprint(BaseSanic): :param host: IP Address of FQDN for the sanic server to use. :param version: Blueprint Version :param strict_slashes: Enforce the API urls are requested with a - training */* + trailing */* """ __fake_slots__ = ( @@ -76,15 +78,9 @@ def __init__( version_prefix: str = "/v", ): super().__init__(name=name) - - self._apps: Set[Sanic] = set() + self.reset() self.ctx = SimpleNamespace() - self.exceptions: List[RouteHandler] = [] self.host = host - self.listeners: Dict[str, List[ListenerType]] = {} - self.middlewares: List[MiddlewareType] = [] - self.routes: List[Route] = [] - self.statics: List[RouteHandler] = [] self.strict_slashes = strict_slashes self.url_prefix = ( url_prefix[:-1] @@ -93,7 +89,6 @@ def __init__( ) self.version = version self.version_prefix = version_prefix - self.websocket_routes: List[Route] = [] def __repr__(self) -> str: args = ", ".join( @@ -144,12 +139,87 @@ def signal(self, event: str, *args, **kwargs): kwargs["apply"] = False return super().signal(event, *args, **kwargs) + def reset(self): + self._apps: Set[Sanic] = set() + self.exceptions: List[RouteHandler] = [] + self.listeners: Dict[str, List[ListenerType]] = {} + self.middlewares: List[MiddlewareType] = [] + self.routes: List[Route] = [] + self.statics: List[RouteHandler] = [] + self.websocket_routes: List[Route] = [] + + def copy( + self, + name: str, + url_prefix: Optional[Union[str, Default]] = _default, + version: Optional[Union[int, str, float, Default]] = _default, + version_prefix: Union[str, Default] = _default, + strict_slashes: Optional[Union[bool, Default]] = _default, + with_registration: bool = True, + with_ctx: bool = False, + ): + """ + Copy a blueprint instance with some optional parameters to + override the values of attributes in the old instance. + + :param name: unique name of the blueprint + :param url_prefix: URL to be prefixed before all route URLs + :param version: Blueprint Version + :param version_prefix: the prefix of the version number shown in the + URL. + :param strict_slashes: Enforce the API urls are requested with a + trailing */* + :param with_registration: whether register new blueprint instance with + sanic apps that were registered with the old instance or not. + :param with_ctx: whether ``ctx`` will be copied or not. + """ + + attrs_backup = { + "_apps": self._apps, + "routes": self.routes, + "websocket_routes": self.websocket_routes, + "middlewares": self.middlewares, + "exceptions": self.exceptions, + "listeners": self.listeners, + "statics": self.statics, + } + + self.reset() + new_bp = deepcopy(self) + new_bp.name = name + + if not isinstance(url_prefix, Default): + new_bp.url_prefix = url_prefix + if not isinstance(version, Default): + new_bp.version = version + if not isinstance(strict_slashes, Default): + new_bp.strict_slashes = strict_slashes + if not isinstance(version_prefix, Default): + new_bp.version_prefix = version_prefix + + for key, value in attrs_backup.items(): + setattr(self, key, value) + + if with_registration and self._apps: + if new_bp._future_statics: + raise SanicException( + "Static routes registered with the old blueprint instance," + " cannot be registered again." + ) + for app in self._apps: + app.blueprint(new_bp) + + if not with_ctx: + new_bp.ctx = SimpleNamespace() + + return new_bp + @staticmethod def group( - *blueprints, - url_prefix="", - version=None, - strict_slashes=None, + *blueprints: Union[Blueprint, BlueprintGroup], + url_prefix: Optional[str] = None, + version: Optional[Union[int, str, float]] = None, + strict_slashes: Optional[bool] = None, version_prefix: str = "/v", ): """ @@ -196,6 +266,9 @@ def register(self, app, options): opt_version = options.get("version", None) opt_strict_slashes = options.get("strict_slashes", None) opt_version_prefix = options.get("version_prefix", self.version_prefix) + error_format = options.get( + "error_format", app.config.FALLBACK_ERROR_FORMAT + ) routes = [] middleware = [] @@ -243,6 +316,7 @@ def register(self, app, options): future.unquote, future.static, version_prefix, + error_format, ) route = app._apply_route(apply_route) @@ -261,19 +335,22 @@ def register(self, app, options): route_names = [route.name for route in routes if route] - # Middleware if route_names: + # Middleware for future in self._future_middleware: middleware.append(app._apply_middleware(future, route_names)) - # Exceptions - for future in self._future_exceptions: - exception_handlers.append(app._apply_exception_handler(future)) + # Exceptions + for future in self._future_exceptions: + exception_handlers.append( + app._apply_exception_handler(future, route_names) + ) # Event listeners for listener in self._future_listeners: listeners[listener.event].append(app._apply_listener(listener)) + # Signals for signal in self._future_signals: signal.condition.update({"blueprint": self.name}) app._apply_signal(signal) diff --git a/sanic/config.py b/sanic/config.py index 27699f8067..649d9414bc 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional, Union from warnings import warn +from sanic.errorpages import check_error_format from sanic.http import Http from .utils import load_module_from_file_location, str_to_bool @@ -20,7 +21,7 @@ DEFAULT_CONFIG = { "ACCESS_LOG": True, "EVENT_AUTOREGISTER": False, - "FALLBACK_ERROR_FORMAT": "html", + "FALLBACK_ERROR_FORMAT": "auto", "FORWARDED_FOR_HEADER": "X-Forwarded-For", "FORWARDED_SECRET": None, "GRACEFUL_SHUTDOWN_TIMEOUT": 15.0, # 15 sec @@ -35,12 +36,9 @@ "REQUEST_MAX_SIZE": 100000000, # 100 megabytes "REQUEST_TIMEOUT": 60, # 60 seconds "RESPONSE_TIMEOUT": 60, # 60 seconds - "WEBSOCKET_MAX_QUEUE": 32, "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte "WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_TIMEOUT": 20, - "WEBSOCKET_READ_LIMIT": 2 ** 16, - "WEBSOCKET_WRITE_LIMIT": 2 ** 16, } @@ -62,12 +60,10 @@ class Config(dict): REQUEST_MAX_SIZE: int REQUEST_TIMEOUT: int RESPONSE_TIMEOUT: int - WEBSOCKET_MAX_QUEUE: int + SERVER_NAME: str WEBSOCKET_MAX_SIZE: int WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_TIMEOUT: int - WEBSOCKET_READ_LIMIT: int - WEBSOCKET_WRITE_LIMIT: int def __init__( self, @@ -100,6 +96,7 @@ def __init__( self.load_environment_vars(SANIC_PREFIX) self._configure_header_size() + self._check_error_format() def __getattr__(self, attr): try: @@ -115,6 +112,8 @@ def __setattr__(self, attr, value): "REQUEST_MAX_SIZE", ): self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() def _configure_header_size(self): Http.set_header_max_size( @@ -123,6 +122,9 @@ def _configure_header_size(self): self.REQUEST_MAX_SIZE, ) + def _check_error_format(self): + check_error_format(self.FALLBACK_ERROR_FORMAT) + def load_environment_vars(self, prefix=SANIC_PREFIX): """ Looks for prefixed environment variables and applies diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 5fc10de153..82cdd57a5c 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -340,41 +340,138 @@ def escape(text): } RENDERERS_BY_CONTENT_TYPE = { - "multipart/form-data": HTMLRenderer, - "application/json": JSONRenderer, "text/plain": TextRenderer, + "application/json": JSONRenderer, + "multipart/form-data": HTMLRenderer, + "text/html": HTMLRenderer, +} +CONTENT_TYPE_BY_RENDERERS = { + v: k for k, v in RENDERERS_BY_CONTENT_TYPE.items() } +RESPONSE_MAPPING = { + "empty": "html", + "json": "json", + "text": "text", + "raw": "text", + "html": "html", + "file": "html", + "file_stream": "text", + "stream": "text", + "redirect": "html", + "text/plain": "text", + "text/html": "html", + "application/json": "json", +} + + +def check_error_format(format): + if format not in RENDERERS_BY_CONFIG and format != "auto": + raise SanicException(f"Unknown format: {format}") + def exception_response( request: Request, exception: Exception, debug: bool, + fallback: str, + base: t.Type[BaseRenderer], renderer: t.Type[t.Optional[BaseRenderer]] = None, ) -> HTTPResponse: """ Render a response for the default FALLBACK exception handler. """ + content_type = None if not renderer: - renderer = HTMLRenderer + # Make sure we have something set + renderer = base + render_format = fallback if request: - if request.app.config.FALLBACK_ERROR_FORMAT == "auto": + # If there is a request, try and get the format + # from the route + if request.route: try: - renderer = JSONRenderer if request.json else HTMLRenderer - except InvalidUsage: + render_format = request.route.ctx.error_format + except AttributeError: + ... + + content_type = request.headers.getone("content-type", "").split( + ";" + )[0] + + acceptable = request.accept + + # If the format is auto still, make a guess + if render_format == "auto": + # First, if there is an Accept header, check if text/html + # is the first option + # According to MDN Web Docs, all major browsers use text/html + # as the primary value in Accept (with the exception of IE 8, + # and, well, if you are supporting IE 8, then you have bigger + # problems to concern yourself with than what default exception + # renderer is used) + # Source: + # https://developer.mozilla.org/en-US/docs/Web/HTTP/Content_negotiation/List_of_default_Accept_values + + if acceptable and acceptable[0].match( + "text/html", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ): renderer = HTMLRenderer - content_type, *_ = request.headers.getone( - "content-type", "" - ).split(";") - renderer = RENDERERS_BY_CONTENT_TYPE.get( - content_type, renderer - ) + # Second, if there is an Accept header, check if + # application/json is an option, or if the content-type + # is application/json + elif ( + acceptable + and acceptable.match( + "application/json", + allow_type_wildcard=False, + allow_subtype_wildcard=False, + ) + or content_type == "application/json" + ): + renderer = JSONRenderer + + # Third, if there is no Accept header, assume we want text. + # The likely use case here is a raw socket. + elif not acceptable: + renderer = TextRenderer + else: + # Fourth, look to see if there was a JSON body + # When in this situation, the request is probably coming + # from curl, an API client like Postman or Insomnia, or a + # package like requests or httpx + try: + # Give them the benefit of the doubt if they did: + # $ curl localhost:8000 -d '{"foo": "bar"}' + # And provide them with JSONRenderer + renderer = JSONRenderer if request.json else base + except InvalidUsage: + renderer = base else: - render_format = request.app.config.FALLBACK_ERROR_FORMAT renderer = RENDERERS_BY_CONFIG.get(render_format, renderer) + # Lastly, if there is an Accept header, make sure + # our choice is okay + if acceptable: + type_ = CONTENT_TYPE_BY_RENDERERS.get(renderer) # type: ignore + if type_ and type_ not in acceptable: + # If the renderer selected is not in the Accept header + # look through what is in the Accept header, and select + # the first option that matches. Otherwise, just drop back + # to the original default + for accept in acceptable: + mtype = f"{accept.type_}/{accept.subtype}" + maybe = RENDERERS_BY_CONTENT_TYPE.get(mtype) + if maybe: + renderer = maybe + break + else: + renderer = base + renderer = t.cast(t.Type[BaseRenderer], renderer) return renderer(request, exception, debug).render() diff --git a/sanic/exceptions.py b/sanic/exceptions.py index 16cd684d5b..1bb06f1de6 100644 --- a/sanic/exceptions.py +++ b/sanic/exceptions.py @@ -4,16 +4,20 @@ class SanicException(Exception): + message: str = "" + def __init__( self, message: Optional[Union[str, bytes]] = None, status_code: Optional[int] = None, quiet: Optional[bool] = None, ) -> None: - - if message is None and status_code is not None: - msg: bytes = STATUS_CODES.get(status_code, b"") - message = msg.decode("utf8") + if message is None: + if self.message: + message = self.message + elif status_code is not None: + msg: bytes = STATUS_CODES.get(status_code, b"") + message = msg.decode("utf8") super().__init__(message) @@ -122,8 +126,11 @@ class HeaderNotFound(InvalidUsage): **Status**: 400 Bad Request """ - status_code = 400 - quiet = True + +class InvalidHeader(InvalidUsage): + """ + **Status**: 400 Bad Request + """ class ContentRangeError(SanicException): @@ -230,6 +237,11 @@ class InvalidSignal(SanicException): pass +class WebsocketClosed(SanicException): + quiet = True + message = "Client has closed the websocket connection" + + def abort(status_code: int, message: Optional[Union[str, bytes]] = None): """ Raise an exception based on SanicException. Returns the HTTP response diff --git a/sanic/handlers.py b/sanic/handlers.py index dd1fbac118..ffeb76b8d8 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,12 +1,13 @@ -from traceback import format_exc +from typing import Dict, List, Optional, Tuple, Type -from sanic.errorpages import exception_response +from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response from sanic.exceptions import ( ContentRangeError, HeaderNotFound, InvalidRangeType, ) from sanic.log import error_logger +from sanic.models.handler_types import RouteHandler from sanic.response import text @@ -23,15 +24,17 @@ class ErrorHandler: """ - handlers = None - cached_handlers = None - - def __init__(self): - self.handlers = [] - self.cached_handlers = {} + # Beginning in v22.3, the base renderer will be TextRenderer + def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): + self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] + self.cached_handlers: Dict[ + Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] + ] = {} self.debug = False + self.fallback = fallback + self.base = base - def add(self, exception, handler): + def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -44,11 +47,16 @@ def add(self, exception, handler): :return: None """ - # self.handlers to be deprecated and removed in version 21.12 + # self.handlers is deprecated and will be removed in version 22.3 self.handlers.append((exception, handler)) - self.cached_handlers[exception] = handler - def lookup(self, exception): + if route_names: + for route in route_names: + self.cached_handlers[(exception, route)] = handler + else: + self.cached_handlers[(exception, None)] = handler + + def lookup(self, exception, route_name: Optional[str]): """ Lookup the existing instance of :class:`ErrorHandler` and fetch the registered handler for a specific type of exception. @@ -63,17 +71,26 @@ def lookup(self, exception): :return: Registered function if found ``None`` otherwise """ exception_class = type(exception) - if exception_class in self.cached_handlers: - return self.cached_handlers[exception_class] - for ancestor in type.mro(exception_class): - if ancestor in self.cached_handlers: - handler = self.cached_handlers[ancestor] - self.cached_handlers[exception_class] = handler + for name in (route_name, None): + exception_key = (exception_class, name) + handler = self.cached_handlers.get(exception_key) + if handler: return handler - if ancestor is BaseException: - break - self.cached_handlers[exception_class] = None + + for name in (route_name, None): + for ancestor in type.mro(exception_class): + exception_key = (ancestor, name) + if exception_key in self.cached_handlers: + handler = self.cached_handlers[exception_key] + self.cached_handlers[ + (exception_class, route_name) + ] = handler + return handler + + if ancestor is BaseException: + break + self.cached_handlers[(exception_class, route_name)] = None handler = None return handler @@ -91,7 +108,8 @@ def response(self, request, exception): :return: Wrap the return value obtained from :func:`default` or registered handler for that type of exception. """ - handler = self.lookup(exception) + route_name = request.name if request else None + handler = self.lookup(exception, route_name) response = None try: if handler: @@ -99,7 +117,6 @@ def response(self, request, exception): if response is None: response = self.default(request, exception) except Exception: - self.log(format_exc()) try: url = repr(request.url) except AttributeError: @@ -115,11 +132,6 @@ def response(self, request, exception): return text("An error occurred while handling an error", 500) return response - def log(self, message, level="error"): - """ - Deprecated, do not use. - """ - def default(self, request, exception): """ Provide a default behavior for the objects of :class:`ErrorHandler`. @@ -135,6 +147,17 @@ def default(self, request, exception): :class:`Exception` :return: """ + self.log(request, exception) + return exception_response( + request, + exception, + debug=self.debug, + base=self.base, + fallback=self.fallback, + ) + + @staticmethod + def log(request, exception): quiet = getattr(exception, "quiet", False) if quiet is False: try: @@ -142,13 +165,10 @@ def default(self, request, exception): except AttributeError: url = "unknown" - self.log(format_exc()) error_logger.exception( "Exception occurred while handling uri: %s", url ) - return exception_response(request, exception, self.debug) - class ContentRangeHandler: """ diff --git a/sanic/headers.py b/sanic/headers.py index 6642744228..dbb8720f9f 100644 --- a/sanic/headers.py +++ b/sanic/headers.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import re from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import unquote +from sanic.exceptions import InvalidHeader from sanic.helpers import STATUS_CODES @@ -30,6 +33,175 @@ # For more information, consult ../tests/test_requests.py +def parse_arg_as_accept(f): + def func(self, other, *args, **kwargs): + if not isinstance(other, Accept) and other: + other = Accept.parse(other) + return f(self, other, *args, **kwargs) + + return func + + +class MediaType(str): + def __new__(cls, value: str): + return str.__new__(cls, value) + + def __init__(self, value: str) -> None: + self.value = value + self.is_wildcard = self.check_if_wildcard(value) + + def __eq__(self, other): + if self.is_wildcard: + return True + + if self.match(other): + return True + + other_is_wildcard = ( + other.is_wildcard + if isinstance(other, MediaType) + else self.check_if_wildcard(other) + ) + + return other_is_wildcard + + def match(self, other): + other_value = other.value if isinstance(other, MediaType) else other + return self.value == other_value + + @staticmethod + def check_if_wildcard(value): + return value == "*" + + +class Accept(str): + def __new__(cls, value: str, *args, **kwargs): + return str.__new__(cls, value) + + def __init__( + self, + value: str, + type_: MediaType, + subtype: MediaType, + *, + q: str = "1.0", + **kwargs: str, + ): + qvalue = float(q) + if qvalue > 1 or qvalue < 0: + raise InvalidHeader( + f"Accept header qvalue must be between 0 and 1, not: {qvalue}" + ) + self.value = value + self.type_ = type_ + self.subtype = subtype + self.qvalue = qvalue + self.params = kwargs + + def _compare(self, other, method): + try: + return method(self.qvalue, other.qvalue) + except (AttributeError, TypeError): + return NotImplemented + + @parse_arg_as_accept + def __lt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s < o) + + @parse_arg_as_accept + def __le__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s <= o) + + @parse_arg_as_accept + def __eq__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s == o) + + @parse_arg_as_accept + def __ge__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s >= o) + + @parse_arg_as_accept + def __gt__(self, other: Union[str, Accept]): + return self._compare(other, lambda s, o: s > o) + + @parse_arg_as_accept + def __ne__(self, other: Union[str, Accept]): # type: ignore + return self._compare(other, lambda s, o: s != o) + + @parse_arg_as_accept + def match( + self, + other, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + type_match = ( + self.type_ == other.type_ + if allow_type_wildcard + else ( + self.type_.match(other.type_) + and not self.type_.is_wildcard + and not other.type_.is_wildcard + ) + ) + subtype_match = ( + self.subtype == other.subtype + if allow_subtype_wildcard + else ( + self.subtype.match(other.subtype) + and not self.subtype.is_wildcard + and not other.subtype.is_wildcard + ) + ) + + return type_match and subtype_match + + @classmethod + def parse(cls, raw: str) -> Accept: + invalid = False + mtype = raw.strip() + + try: + media, *raw_params = mtype.split(";") + type_, subtype = media.split("/") + except ValueError: + invalid = True + + if invalid or not type_ or not subtype: + raise InvalidHeader(f"Header contains invalid Accept value: {raw}") + + params = dict( + [ + (key.strip(), value.strip()) + for key, value in (param.split("=", 1) for param in raw_params) + ] + ) + + return cls(mtype, MediaType(type_), MediaType(subtype), **params) + + +class AcceptContainer(list): + def __contains__(self, o: object) -> bool: + return any(item.match(o) for item in self) + + def match( + self, + o: object, + *, + allow_type_wildcard: bool = True, + allow_subtype_wildcard: bool = True, + ) -> bool: + return any( + item.match( + o, + allow_type_wildcard=allow_type_wildcard, + allow_subtype_wildcard=allow_subtype_wildcard, + ) + for item in self + ) + + def parse_content_header(value: str) -> Tuple[str, Options]: """Parse content-type and content-disposition header values. @@ -194,3 +366,31 @@ def format_http1_response(status: int, headers: HeaderBytesIterable) -> bytes: ret += b"%b: %b\r\n" % h ret += b"\r\n" return ret + + +def _sort_accept_value(accept: Accept): + return ( + accept.qvalue, + len(accept.params), + accept.subtype != "*", + accept.type_ != "*", + ) + + +def parse_accept(accept: str) -> AcceptContainer: + """Parse an Accept header and order the acceptable media types in + accorsing to RFC 7231, s. 5.3.2 + https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.2 + """ + media_types = accept.split(",") + accept_list: List[Accept] = [] + + for mtype in media_types: + if not mtype: + continue + + accept_list.append(Accept.parse(mtype)) + + return AcceptContainer( + sorted(accept_list, key=_sort_accept_value, reverse=True) + ) diff --git a/sanic/helpers.py b/sanic/helpers.py index 15ae7bf297..87d51b53ac 100644 --- a/sanic/helpers.py +++ b/sanic/helpers.py @@ -155,3 +155,17 @@ def import_string(module_name, package=None): if ismodule(obj): return obj return obj() + + +class Default: + """ + It is used to replace `None` or `object()` as a sentinel + that represents a default value. Sometimes we want to set + a value to `None` so we cannot use `None` to represent the + default value, and `object()` is hard to be typed. + """ + + pass + + +_default = Default() diff --git a/sanic/http.py b/sanic/http.py index 402238cbe2..d30e4c82b8 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -21,6 +21,7 @@ from sanic.headers import format_http1_response from sanic.helpers import has_message_body from sanic.log import access_logger, error_logger, logger +from sanic.touchup import TouchUpMeta class Stage(Enum): @@ -45,7 +46,7 @@ class Stage(Enum): HTTP_CONTINUE = b"HTTP/1.1 100 Continue\r\n\r\n" -class Http: +class Http(metaclass=TouchUpMeta): """ Internal helper for managing the HTTP request/response cycle @@ -67,9 +68,15 @@ class Http: HEADER_CEILING = 16_384 HEADER_MAX_SIZE = 0 + __touchup__ = ( + "http1_request_header", + "http1_response_header", + "read", + ) __slots__ = [ "_send", "_receive_more", + "dispatch", "recv_buffer", "protocol", "expecting_continue", @@ -97,6 +104,7 @@ def __init__(self, protocol): self.protocol = protocol self.keep_alive = True self.stage: Stage = Stage.IDLE + self.dispatch = self.protocol.app.dispatch self.init_for_request() def init_for_request(self): @@ -140,6 +148,12 @@ async def http1(self): await self.response.send(end_stream=True) except CancelledError: # Write an appropriate response before exiting + if not self.protocol.transport: + logger.info( + f"Request: {self.request.method} {self.request.url} " + "stopped. Transport is closed." + ) + return e = self.exception or ServiceUnavailable("Cancelled") self.exception = None self.keep_alive = False @@ -173,17 +187,17 @@ async def http1(self): if self.response: self.response.stream = None - self.init_for_request() - # Exit and disconnect if no more requests can be taken if self.stage is not Stage.IDLE or not self.keep_alive: break + self.init_for_request() + # Wait for the next request if not self.recv_buffer: await self._receive_more() - async def http1_request_header(self): + async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. """ @@ -212,6 +226,12 @@ async def http1_request_header(self): reqline, *split_headers = raw_headers.split("\r\n") method, self.url, protocol = reqline.split(" ") + await self.dispatch( + "http.lifecycle.read_head", + inline=True, + context={"head": bytes(head)}, + ) + if protocol == "HTTP/1.1": self.keep_alive = True elif protocol == "HTTP/1.0": @@ -250,6 +270,11 @@ async def http1_request_header(self): transport=self.protocol.transport, app=self.protocol.app, ) + await self.dispatch( + "http.lifecycle.request", + inline=True, + context={"request": request}, + ) # Prepare for request body self.request_bytes_left = self.request_bytes = 0 @@ -280,7 +305,7 @@ async def http1_request_header(self): async def http1_response_header( self, data: bytes, end_stream: bool - ) -> None: + ) -> None: # no cov res = self.response # Compatibility with simple response body @@ -452,8 +477,8 @@ def log_response(self) -> None: "request": "nil", } if req is not None: - if req.ip: - extra["host"] = f"{req.ip}:{req.port}" + if req.remote_addr or req.ip: + extra["host"] = f"{req.remote_addr or req.ip}:{req.port}" extra["request"] = f"{req.method} {req.url}" access_logger.info("", extra=extra) @@ -469,7 +494,7 @@ async def __aiter__(self): if data: yield data - async def read(self) -> Optional[bytes]: + async def read(self) -> Optional[bytes]: # no cov """ Read some bytes of request body. """ @@ -543,6 +568,12 @@ async def read(self) -> Optional[bytes]: self.request_bytes_left -= size + await self.dispatch( + "http.lifecycle.read_body", + inline=True, + context={"body": data}, + ) + return data # Response methods diff --git a/sanic/mixins/listeners.py b/sanic/mixins/listeners.py index c12326c455..ebf9b13115 100644 --- a/sanic/mixins/listeners.py +++ b/sanic/mixins/listeners.py @@ -1,18 +1,19 @@ from enum import Enum, auto from functools import partial -from typing import Any, Callable, Coroutine, List, Optional, Union +from typing import List, Optional, Union from sanic.models.futures import FutureListener +from sanic.models.handler_types import ListenerType class ListenerEvent(str, Enum): def _generate_next_value_(name: str, *args) -> str: # type: ignore return name.lower() - BEFORE_SERVER_START = auto() - AFTER_SERVER_START = auto() - BEFORE_SERVER_STOP = auto() - AFTER_SERVER_STOP = auto() + BEFORE_SERVER_START = "server.init.before" + AFTER_SERVER_START = "server.init.after" + BEFORE_SERVER_STOP = "server.shutdown.before" + AFTER_SERVER_STOP = "server.shutdown.after" MAIN_PROCESS_START = auto() MAIN_PROCESS_STOP = auto() @@ -26,9 +27,7 @@ def _apply_listener(self, listener: FutureListener): def listener( self, - listener_or_event: Union[ - Callable[..., Coroutine[Any, Any, None]], str - ], + listener_or_event: Union[ListenerType, str], event_or_none: Optional[str] = None, apply: bool = True, ): @@ -63,20 +62,20 @@ def register_listener(listener, event): else: return partial(register_listener, event=listener_or_event) - def main_process_start(self, listener): + def main_process_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "main_process_start") - def main_process_stop(self, listener): + def main_process_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "main_process_stop") - def before_server_start(self, listener): + def before_server_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "before_server_start") - def after_server_start(self, listener): + def after_server_start(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "after_server_start") - def before_server_stop(self, listener): + def before_server_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "before_server_stop") - def after_server_stop(self, listener): + def after_server_stop(self, listener: ListenerType) -> ListenerType: return self.listener(listener, "after_server_stop") diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 5af1610d49..8467a2e340 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -1,17 +1,20 @@ +from ast import NodeVisitor, Return, parse from functools import partial, wraps -from inspect import signature +from inspect import getsource, signature from mimetypes import guess_type from os import path from pathlib import PurePath from re import sub +from textwrap import dedent from time import gmtime, strftime -from typing import Iterable, List, Optional, Set, Union +from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import unquote from sanic_routing.route import Route # type: ignore from sanic.compat import stat_async from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE, HTTP_METHODS +from sanic.errorpages import RESPONSE_MAPPING from sanic.exceptions import ( ContentRangeError, FileNotFound, @@ -21,10 +24,16 @@ from sanic.handlers import ContentRangeHandler from sanic.log import error_logger from sanic.models.futures import FutureRoute, FutureStatic +from sanic.models.handler_types import RouteHandler from sanic.response import HTTPResponse, file, file_stream from sanic.views import CompositionView +RouteWrapper = Callable[ + [RouteHandler], Union[RouteHandler, Tuple[Route, RouteHandler]] +] + + class RouteMixin: name: str @@ -55,7 +64,8 @@ def route( unquote: bool = False, static: bool = False, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Decorate a function to be registered as a route @@ -97,6 +107,7 @@ def decorator(handler): nonlocal websocket nonlocal static nonlocal version_prefix + nonlocal error_format if isinstance(handler, tuple): # if a handler fn is already wrapped in a route, the handler @@ -115,10 +126,16 @@ def decorator(handler): "Expected either string or Iterable of host strings, " "not %s" % host ) - - if isinstance(subprotocols, (list, tuple, set)): + if isinstance(subprotocols, list): + # Ordered subprotocols, maintain order + subprotocols = tuple(subprotocols) + elif isinstance(subprotocols, set): + # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) + if not error_format or error_format == "auto": + error_format = self._determine_error_format(handler) + route = FutureRoute( handler, uri, @@ -134,6 +151,7 @@ def decorator(handler): unquote, static, version_prefix, + error_format, ) self._future_routes.add(route) @@ -168,7 +186,7 @@ def decorator(handler): def add_route( self, - handler, + handler: RouteHandler, uri: str, methods: Iterable[str] = frozenset({"GET"}), host: Optional[str] = None, @@ -177,7 +195,8 @@ def add_route( name: Optional[str] = None, stream: bool = False, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteHandler: """A helper method to register class instance or functions as a handler to the application url routes. @@ -200,7 +219,8 @@ def add_route( methods = set() for method in HTTP_METHODS: - _handler = getattr(handler.view_class, method.lower(), None) + view_class = getattr(handler, "view_class") + _handler = getattr(view_class, method.lower(), None) if _handler: methods.add(method) if hasattr(_handler, "is_stream"): @@ -226,6 +246,7 @@ def add_route( version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) return handler @@ -239,7 +260,8 @@ def get( name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **GET** *HTTP* method @@ -262,6 +284,7 @@ def get( name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def post( @@ -273,7 +296,8 @@ def post( version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **POST** *HTTP* method @@ -296,6 +320,7 @@ def post( version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def put( @@ -307,7 +332,8 @@ def put( version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **PUT** *HTTP* method @@ -330,6 +356,7 @@ def put( version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def head( @@ -341,7 +368,8 @@ def head( name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **HEAD** *HTTP* method @@ -372,6 +400,7 @@ def head( name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def options( @@ -383,7 +412,8 @@ def options( name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **OPTIONS** *HTTP* method @@ -414,6 +444,7 @@ def options( name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def patch( @@ -425,7 +456,8 @@ def patch( version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **PATCH** *HTTP* method @@ -458,6 +490,7 @@ def patch( version=version, name=name, version_prefix=version_prefix, + error_format=error_format, ) def delete( @@ -469,7 +502,8 @@ def delete( name: Optional[str] = None, ignore_body: bool = True, version_prefix: str = "/v", - ): + error_format: Optional[str] = None, + ) -> RouteWrapper: """ Add an API URL under the **DELETE** *HTTP* method @@ -492,6 +526,7 @@ def delete( name=name, ignore_body=ignore_body, version_prefix=version_prefix, + error_format=error_format, ) def websocket( @@ -504,6 +539,7 @@ def websocket( name: Optional[str] = None, apply: bool = True, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ Decorate a function to be registered as a websocket route @@ -530,6 +566,7 @@ def websocket( subprotocols=subprotocols, websocket=True, version_prefix=version_prefix, + error_format=error_format, ) def add_websocket_route( @@ -542,6 +579,7 @@ def add_websocket_route( version: Optional[int] = None, name: Optional[str] = None, version_prefix: str = "/v", + error_format: Optional[str] = None, ): """ A helper method to register a function as a websocket route. @@ -570,6 +608,7 @@ def add_websocket_route( version=version, name=name, version_prefix=version_prefix, + error_format=error_format, )(handler) def static( @@ -585,6 +624,7 @@ def static( strict_slashes=None, content_type=None, apply=True, + resource_type=None, ): """ Register a root to serve files from. The input can either be a @@ -634,6 +674,7 @@ def static( host, strict_slashes, content_type, + resource_type, ) self._future_statics.add(static) @@ -777,10 +818,11 @@ async def _static_request_handler( ) except Exception: error_logger.exception( - f"Exception in static request handler:\ - path={file_or_directory}, " + f"Exception in static request handler: " + f"path={file_or_directory}, " f"relative_url={__file_uri__}" ) + raise def _register_static( self, @@ -828,8 +870,27 @@ def _register_static( name = static.name # If we're not trying to match a file directly, # serve from the folder - if not path.isfile(file_or_directory): + if not static.resource_type: + if not path.isfile(file_or_directory): + uri += "/<__file_uri__:path>" + elif static.resource_type == "dir": + if path.isfile(file_or_directory): + raise TypeError( + "Resource type improperly identified as directory. " + f"'{file_or_directory}'" + ) uri += "/<__file_uri__:path>" + elif static.resource_type == "file" and not path.isfile( + file_or_directory + ): + raise TypeError( + "Resource type improperly identified as file. " + f"'{file_or_directory}'" + ) + elif static.resource_type != "file": + raise ValueError( + "The resource_type should be set to 'file' or 'dir'" + ) # special prefix for static files # if not static.name.startswith("_static_"): @@ -846,7 +907,7 @@ def _register_static( ) ) - route, _ = self.route( + route, _ = self.route( # type: ignore uri=uri, methods=["GET", "HEAD"], name=name, @@ -856,3 +917,43 @@ def _register_static( )(_handler) return route + + def _determine_error_format(self, handler) -> str: + if not isinstance(handler, CompositionView): + try: + src = dedent(getsource(handler)) + tree = parse(src) + http_response_types = self._get_response_types(tree) + + if len(http_response_types) == 1: + return next(iter(http_response_types)) + except (OSError, TypeError): + ... + + return "auto" + + def _get_response_types(self, node): + types = set() + + class HttpResponseVisitor(NodeVisitor): + def visit_Return(self, node: Return) -> Any: + nonlocal types + + try: + checks = [node.value.func.id] # type: ignore + if node.value.keywords: # type: ignore + checks += [ + k.value + for k in node.value.keywords # type: ignore + if k.arg == "content_type" + ] + + for check in checks: + if check in RESPONSE_MAPPING: + types.add(RESPONSE_MAPPING[check]) + except AttributeError: + ... + + HttpResponseVisitor().visit(node) + + return types diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index e849e562ad..2be9fee2e6 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -23,7 +23,7 @@ def signal( *, apply: bool = True, condition: Dict[str, Any] = None, - ) -> Callable[[SignalHandler], FutureSignal]: + ) -> Callable[[SignalHandler], SignalHandler]: """ For creating a signal handler, used similar to a route handler: @@ -54,7 +54,7 @@ def decorator(handler: SignalHandler): if apply: self._apply_signal(future_signal) - return future_signal + return handler return decorator diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 595b05532a..1b707ebc03 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -3,7 +3,7 @@ from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from sanic.exceptions import InvalidUsage -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection ASGIScope = MutableMapping[str, Any] diff --git a/sanic/models/futures.py b/sanic/models/futures.py index 2350bedb9a..fe7d77ebcc 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -24,6 +24,7 @@ class FutureRoute(NamedTuple): unquote: bool static: bool version_prefix: str + error_format: Optional[str] class FutureListener(NamedTuple): @@ -52,6 +53,7 @@ class FutureStatic(NamedTuple): host: Optional[str] strict_slashes: Optional[bool] content_type: Optional[bool] + resource_type: Optional[str] class FutureSignal(NamedTuple): diff --git a/sanic/models/handler_types.py b/sanic/models/handler_types.py index 704def7a35..0144c964d8 100644 --- a/sanic/models/handler_types.py +++ b/sanic/models/handler_types.py @@ -21,5 +21,5 @@ ListenerType = Callable[ [Sanic, AbstractEventLoop], Optional[Coroutine[Any, Any, None]] ] -RouteHandler = Callable[..., Coroutine[Any, Any, HTTPResponse]] +RouteHandler = Callable[..., Coroutine[Any, Any, Optional[HTTPResponse]]] SignalHandler = Callable[..., Coroutine[Any, Any, None]] diff --git a/sanic/models/server_types.py b/sanic/models/server_types.py new file mode 100644 index 0000000000..f0ced2475c --- /dev/null +++ b/sanic/models/server_types.py @@ -0,0 +1,52 @@ +from types import SimpleNamespace + +from sanic.models.protocol_types import TransportProtocol + + +class Signal: + stopped = False + + +class ConnInfo: + """ + Local and remote addresses and SSL status info. + """ + + __slots__ = ( + "client_port", + "client", + "client_ip", + "ctx", + "peername", + "server_port", + "server", + "sockname", + "ssl", + ) + + def __init__(self, transport: TransportProtocol, unix=None): + self.ctx = SimpleNamespace() + self.peername = None + self.server = self.client = "" + self.server_port = self.client_port = 0 + self.client_ip = "" + self.sockname = addr = transport.get_extra_info("sockname") + self.ssl: bool = bool(transport.get_extra_info("sslcontext")) + + if isinstance(addr, str): # UNIX socket + self.server = unix or addr + return + + # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) + if isinstance(addr, tuple): + self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.server_port = addr[1] + # self.server gets non-standard port appended + if addr[1] != (443 if self.ssl else 80): + self.server = f"{self.server}:{addr[1]}" + self.peername = addr = transport.get_extra_info("peername") + + if isinstance(addr, tuple): + self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" + self.client_ip = addr[0] + self.client_port = addr[1] diff --git a/sanic/request.py b/sanic/request.py index 177df637e6..c744e3c327 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -34,7 +34,9 @@ from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.exceptions import InvalidUsage from sanic.headers import ( + AcceptContainer, Options, + parse_accept, parse_content_header, parse_forwarded, parse_host, @@ -94,6 +96,7 @@ class Request: "head", "headers", "method", + "parsed_accept", "parsed_args", "parsed_not_grouped_args", "parsed_files", @@ -136,6 +139,7 @@ def __init__( self.conn_info: Optional[ConnInfo] = None self.ctx = SimpleNamespace() self.parsed_forwarded: Optional[Options] = None + self.parsed_accept: Optional[AcceptContainer] = None self.parsed_json = None self.parsed_form = None self.parsed_files = None @@ -296,6 +300,13 @@ def load_json(self, loads=json_loads): return self.parsed_json + @property + def accept(self) -> AcceptContainer: + if self.parsed_accept is None: + accept_header = self.headers.getone("accept", "") + self.parsed_accept = parse_accept(accept_header) + return self.parsed_accept + @property def token(self): """Attempt to return the auth header token. @@ -497,6 +508,10 @@ def match_info(self): """ return self._match_info + @match_info.setter + def match_info(self, value): + self._match_info = value + # Transport properties (obtained from local interface only) @property diff --git a/sanic/router.py b/sanic/router.py index 0973a3faa0..6995ed6da4 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -1,5 +1,9 @@ +from __future__ import annotations + from functools import lru_cache +from inspect import signature from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from uuid import UUID from sanic_routing import BaseRouter # type: ignore from sanic_routing.exceptions import NoMethod # type: ignore @@ -9,6 +13,7 @@ from sanic_routing.route import Route # type: ignore from sanic.constants import HTTP_METHODS +from sanic.errorpages import check_error_format from sanic.exceptions import MethodNotSupported, NotFound, SanicException from sanic.models.handler_types import RouteHandler @@ -74,6 +79,7 @@ def add( # type: ignore unquote: bool = False, static: bool = False, version_prefix: str = "/v", + error_format: Optional[str] = None, ) -> Union[Route, List[Route]]: """ Add a handler to the router @@ -106,6 +112,8 @@ def add( # type: ignore version = str(version).strip("/").lstrip("v") uri = "/".join([f"{version_prefix}{version}", uri.lstrip("/")]) + uri = self._normalize(uri, handler) + params = dict( path=uri, handler=handler, @@ -131,6 +139,11 @@ def add( # type: ignore route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static + route.ctx.error_format = ( + error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT + ) + + check_error_format(route.ctx.error_format) routes.append(route) @@ -187,3 +200,24 @@ def finalize(self, *args, **kwargs): raise SanicException( f"Invalid route: {route}. Parameter names cannot use '__'." ) + + def _normalize(self, uri: str, handler: RouteHandler) -> str: + if "<" not in uri: + return uri + + sig = signature(handler) + mapping = { + param.name: param.annotation.__name__.lower() + for param in sig.parameters.values() + if param.annotation in (str, int, float, UUID) + } + + reconstruction = [] + for part in uri.split("/"): + if part.startswith("<") and ":" not in part: + name = part[1:-1] + annotation = mapping.get(name) + if annotation: + part = f"<{name}:{annotation}>" + reconstruction.append(part) + return "/".join(reconstruction) diff --git a/sanic/server.py b/sanic/server.py deleted file mode 100644 index 4ec83f9c18..0000000000 --- a/sanic/server.py +++ /dev/null @@ -1,793 +0,0 @@ -from __future__ import annotations - -from ssl import SSLContext -from types import SimpleNamespace -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - Optional, - Type, - Union, -) - -from sanic.models.handler_types import ListenerType - - -if TYPE_CHECKING: - from sanic.app import Sanic - -import asyncio -import multiprocessing -import os -import secrets -import socket -import stat - -from asyncio import CancelledError -from asyncio.transports import Transport -from functools import partial -from inspect import isawaitable -from ipaddress import ip_address -from signal import SIG_IGN, SIGINT, SIGTERM, Signals -from signal import signal as signal_func -from time import monotonic as current_time - -from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows -from sanic.config import Config -from sanic.exceptions import RequestTimeout, ServiceUnavailable -from sanic.http import Http, Stage -from sanic.log import error_logger, logger -from sanic.models.protocol_types import TransportProtocol -from sanic.request import Request - - -try: - import uvloop # type: ignore - - if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -except ImportError: - pass - - -class Signal: - stopped = False - - -class ConnInfo: - """ - Local and remote addresses and SSL status info. - """ - - __slots__ = ( - "client_port", - "client", - "client_ip", - "ctx", - "peername", - "server_port", - "server", - "sockname", - "ssl", - ) - - def __init__(self, transport: TransportProtocol, unix=None): - self.ctx = SimpleNamespace() - self.peername = None - self.server = self.client = "" - self.server_port = self.client_port = 0 - self.client_ip = "" - self.sockname = addr = transport.get_extra_info("sockname") - self.ssl: bool = bool(transport.get_extra_info("sslcontext")) - - if isinstance(addr, str): # UNIX socket - self.server = unix or addr - return - - # IPv4 (ip, port) or IPv6 (ip, port, flowinfo, scopeid) - if isinstance(addr, tuple): - self.server = addr[0] if len(addr) == 2 else f"[{addr[0]}]" - self.server_port = addr[1] - # self.server gets non-standard port appended - if addr[1] != (443 if self.ssl else 80): - self.server = f"{self.server}:{addr[1]}" - self.peername = addr = transport.get_extra_info("peername") - - if isinstance(addr, tuple): - self.client = addr[0] if len(addr) == 2 else f"[{addr[0]}]" - self.client_ip = addr[0] - self.client_port = addr[1] - - -class HttpProtocol(asyncio.Protocol): - """ - This class provides a basic HTTP implementation of the sanic framework. - """ - - __slots__ = ( - # app - "app", - # event loop, connection - "loop", - "transport", - "connections", - "signal", - "conn_info", - "ctx", - # request params - "request", - # request config - "request_handler", - "request_timeout", - "response_timeout", - "keep_alive_timeout", - "request_max_size", - "request_class", - "error_handler", - # enable or disable access log purpose - "access_log", - # connection management - "state", - "url", - "_handler_task", - "_can_write", - "_data_received", - "_time", - "_task", - "_http", - "_exception", - "recv_buffer", - "_unix", - ) - - def __init__( - self, - *, - loop, - app: Sanic, - signal=None, - connections=None, - state=None, - unix=None, - **kwargs, - ): - asyncio.set_event_loop(loop) - self.loop = loop - self.app: Sanic = app - self.url = None - self.transport: Optional[Transport] = None - self.conn_info: Optional[ConnInfo] = None - self.request: Optional[Request] = None - self.signal = signal or Signal() - self.access_log = self.app.config.ACCESS_LOG - self.connections = connections if connections is not None else set() - self.request_handler = self.app.handle_request - self.error_handler = self.app.error_handler - self.request_timeout = self.app.config.REQUEST_TIMEOUT - self.response_timeout = self.app.config.RESPONSE_TIMEOUT - self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT - self.request_max_size = self.app.config.REQUEST_MAX_SIZE - self.request_class = self.app.request_class or Request - self.state = state if state else {} - if "requests_count" not in self.state: - self.state["requests_count"] = 0 - self._data_received = asyncio.Event() - self._can_write = asyncio.Event() - self._can_write.set() - self._exception = None - self._unix = unix - - def _setup_connection(self): - self._http = Http(self) - self._time = current_time() - self.check_timeouts() - - async def connection_task(self): - """ - Run a HTTP connection. - - Timeouts and some additional error handling occur here, while most of - everything else happens in class Http or in code called from there. - """ - try: - self._setup_connection() - await self._http.http1() - except CancelledError: - pass - except Exception: - error_logger.exception("protocol.connection_task uncaught") - finally: - if self.app.debug and self._http: - ip = self.transport.get_extra_info("peername") - error_logger.error( - "Connection lost before response written" - f" @ {ip} {self._http.request}" - ) - self._http = None - self._task = None - try: - self.close() - except BaseException: - error_logger.exception("Closing failed") - - async def receive_more(self): - """ - Wait until more data is received into the Server protocol's buffer - """ - self.transport.resume_reading() - self._data_received.clear() - await self._data_received.wait() - - def check_timeouts(self): - """ - Runs itself periodically to enforce any expired timeouts. - """ - try: - if not self._task: - return - duration = current_time() - self._time - stage = self._http.stage - if stage is Stage.IDLE and duration > self.keep_alive_timeout: - logger.debug("KeepAlive Timeout. Closing connection.") - elif stage is Stage.REQUEST and duration > self.request_timeout: - logger.debug("Request Timeout. Closing connection.") - self._http.exception = RequestTimeout("Request Timeout") - elif stage is Stage.HANDLER and self._http.upgrade_websocket: - logger.debug("Handling websocket. Timeouts disabled.") - return - elif ( - stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) - and duration > self.response_timeout - ): - logger.debug("Response Timeout. Closing connection.") - self._http.exception = ServiceUnavailable("Response Timeout") - else: - interval = ( - min( - self.keep_alive_timeout, - self.request_timeout, - self.response_timeout, - ) - / 2 - ) - self.loop.call_later(max(0.1, interval), self.check_timeouts) - return - self._task.cancel() - except Exception: - error_logger.exception("protocol.check_timeouts") - - async def send(self, data): - """ - Writes data with backpressure control. - """ - await self._can_write.wait() - if self.transport.is_closing(): - raise CancelledError - self.transport.write(data) - self._time = current_time() - - def close_if_idle(self) -> bool: - """ - Close the connection if a request is not being sent or received - - :return: boolean - True if closed, false if staying open - """ - if self._http is None or self._http.stage is Stage.IDLE: - self.close() - return True - return False - - def close(self): - """ - Force close the connection. - """ - # Cause a call to connection_lost where further cleanup occurs - if self.transport: - self.transport.close() - self.transport = None - - # -------------------------------------------- # - # Only asyncio.Protocol callbacks below this - # -------------------------------------------- # - - def connection_made(self, transport): - try: - # TODO: Benchmark to find suitable write buffer limits - transport.set_write_buffer_limits(low=16384, high=65536) - self.connections.add(self) - self.transport = transport - self._task = self.loop.create_task(self.connection_task()) - self.recv_buffer = bytearray() - self.conn_info = ConnInfo(self.transport, unix=self._unix) - except Exception: - error_logger.exception("protocol.connect_made") - - def connection_lost(self, exc): - try: - self.connections.discard(self) - self.resume_writing() - if self._task: - self._task.cancel() - except Exception: - error_logger.exception("protocol.connection_lost") - - def pause_writing(self): - self._can_write.clear() - - def resume_writing(self): - self._can_write.set() - - def data_received(self, data: bytes): - try: - self._time = current_time() - if not data: - return self.close() - self.recv_buffer += data - - if ( - len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE - and self.transport - ): - self.transport.pause_reading() - - if self._data_received: - self._data_received.set() - except Exception: - error_logger.exception("protocol.data_received") - - -def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): - """ - Trigger event callbacks (functions or async) - - :param events: one or more sync or async functions to execute - :param loop: event loop - """ - if events: - for event in events: - result = event(loop) - if isawaitable(result): - loop.run_until_complete(result) - - -class AsyncioServer: - """ - Wraps an asyncio server with functionality that might be useful to - a user who needs to manage the server lifecycle manually. - """ - - __slots__ = ( - "loop", - "serve_coro", - "_after_start", - "_before_stop", - "_after_stop", - "server", - "connections", - ) - - def __init__( - self, - loop, - serve_coro, - connections, - after_start: Optional[Iterable[ListenerType]], - before_stop: Optional[Iterable[ListenerType]], - after_stop: Optional[Iterable[ListenerType]], - ): - # Note, Sanic already called "before_server_start" events - # before this helper was even created. So we don't need it here. - self.loop = loop - self.serve_coro = serve_coro - self._after_start = after_start - self._before_stop = before_stop - self._after_stop = after_stop - self.server = None - self.connections = connections - - def after_start(self): - """ - Trigger "after_server_start" events - """ - trigger_events(self._after_start, self.loop) - - def before_stop(self): - """ - Trigger "before_server_stop" events - """ - trigger_events(self._before_stop, self.loop) - - def after_stop(self): - """ - Trigger "after_server_stop" events - """ - trigger_events(self._after_stop, self.loop) - - def is_serving(self) -> bool: - if self.server: - return self.server.is_serving() - return False - - def wait_closed(self): - if self.server: - return self.server.wait_closed() - - def close(self): - if self.server: - self.server.close() - coro = self.wait_closed() - task = asyncio.ensure_future(coro, loop=self.loop) - return task - - def start_serving(self): - if self.server: - try: - return self.server.start_serving() - except AttributeError: - raise NotImplementedError( - "server.start_serving not available in this version " - "of asyncio or uvloop." - ) - - def serve_forever(self): - if self.server: - try: - return self.server.serve_forever() - except AttributeError: - raise NotImplementedError( - "server.serve_forever not available in this version " - "of asyncio or uvloop." - ) - - def __await__(self): - """ - Starts the asyncio server, returns AsyncServerCoro - """ - task = asyncio.ensure_future(self.serve_coro) - while not task.done(): - yield - self.server = task.result() - return self - - -def serve( - host, - port, - app, - before_start: Optional[Iterable[ListenerType]] = None, - after_start: Optional[Iterable[ListenerType]] = None, - before_stop: Optional[Iterable[ListenerType]] = None, - after_stop: Optional[Iterable[ListenerType]] = None, - ssl: Optional[SSLContext] = None, - sock: Optional[socket.socket] = None, - unix: Optional[str] = None, - reuse_port: bool = False, - loop=None, - protocol: Type[asyncio.Protocol] = HttpProtocol, - backlog: int = 100, - register_sys_signals: bool = True, - run_multiple: bool = False, - run_async: bool = False, - connections=None, - signal=Signal(), - state=None, - asyncio_server_kwargs=None, -): - """Start asynchronous HTTP Server on an individual process. - - :param host: Address to host on - :param port: Port to host on - :param before_start: function to be executed before the server starts - listening. Takes arguments `app` instance and `loop` - :param after_start: function to be executed after the server starts - listening. Takes arguments `app` instance and `loop` - :param before_stop: function to be executed when a stop signal is - received before it is respected. Takes arguments - `app` instance and `loop` - :param after_stop: function to be executed when a stop signal is - received after it is respected. Takes arguments - `app` instance and `loop` - :param ssl: SSLContext - :param sock: Socket for the server to accept connections from - :param unix: Unix socket to listen on instead of TCP port - :param reuse_port: `True` for multiple workers - :param loop: asyncio compatible event loop - :param run_async: bool: Do not create a new event loop for the server, - and return an AsyncServer object rather than running it - :param asyncio_server_kwargs: key-value args for asyncio/uvloop - create_server method - :return: Nothing - """ - if not run_async and not loop: - # create new event_loop after fork - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - if app.debug: - loop.set_debug(app.debug) - - app.asgi = False - - connections = connections if connections is not None else set() - protocol_kwargs = _build_protocol_kwargs(protocol, app.config) - server = partial( - protocol, - loop=loop, - connections=connections, - signal=signal, - app=app, - state=state, - unix=unix, - **protocol_kwargs, - ) - asyncio_server_kwargs = ( - asyncio_server_kwargs if asyncio_server_kwargs else {} - ) - # UNIX sockets are always bound by us (to preserve semantics between modes) - if unix: - sock = bind_unix_socket(unix, backlog=backlog) - server_coroutine = loop.create_server( - server, - None if sock else host, - None if sock else port, - ssl=ssl, - reuse_port=reuse_port, - sock=sock, - backlog=backlog, - **asyncio_server_kwargs, - ) - - if run_async: - return AsyncioServer( - loop=loop, - serve_coro=server_coroutine, - connections=connections, - after_start=after_start, - before_stop=before_stop, - after_stop=after_stop, - ) - - trigger_events(before_start, loop) - - try: - http_server = loop.run_until_complete(server_coroutine) - except BaseException: - error_logger.exception("Unable to start server") - return - - trigger_events(after_start, loop) - - # Ignore SIGINT when run_multiple - if run_multiple: - signal_func(SIGINT, SIG_IGN) - - # Register signals for graceful termination - if register_sys_signals: - if OS_IS_WINDOWS: - ctrlc_workaround_for_windows(app) - else: - for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: - loop.add_signal_handler(_signal, app.stop) - pid = os.getpid() - try: - logger.info("Starting worker [%s]", pid) - loop.run_forever() - finally: - logger.info("Stopping worker [%s]", pid) - - # Run the on_stop function if provided - trigger_events(before_stop, loop) - - # Wait for event loop to finish and all connections to drain - http_server.close() - loop.run_until_complete(http_server.wait_closed()) - - # Complete all tasks on the loop - signal.stopped = True - for connection in connections: - connection.close_if_idle() - - # Gracefully shutdown timeout. - # We should provide graceful_shutdown_timeout, - # instead of letting connection hangs forever. - # Let's roughly calcucate time. - graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT - start_shutdown: float = 0 - while connections and (start_shutdown < graceful): - loop.run_until_complete(asyncio.sleep(0.1)) - start_shutdown = start_shutdown + 0.1 - - # Force close non-idle connection after waiting for - # graceful_shutdown_timeout - coros = [] - for conn in connections: - if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) - else: - conn.close() - - _shutdown = asyncio.gather(*coros) - loop.run_until_complete(_shutdown) - - trigger_events(after_stop, loop) - - remove_unix_socket(unix) - - -def _build_protocol_kwargs( - protocol: Type[asyncio.Protocol], config: Config -) -> Dict[str, Union[int, float]]: - if hasattr(protocol, "websocket_handshake"): - return { - "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, - "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, - "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, - "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, - "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, - } - return {} - - -def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: - """Create TCP server socket. - :param host: IPv4, IPv6 or hostname may be specified - :param port: TCP port number - :param backlog: Maximum number of connections to queue - :return: socket.socket object - """ - try: # IP address: family must be specified for IPv6 at least - ip = ip_address(host) - host = str(ip) - sock = socket.socket( - socket.AF_INET6 if ip.version == 6 else socket.AF_INET - ) - except ValueError: # Hostname, may become AF_INET or AF_INET6 - sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((host, port)) - sock.listen(backlog) - return sock - - -def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: - """Create unix socket. - :param path: filesystem path - :param backlog: Maximum number of connections to queue - :return: socket.socket object - """ - """Open or atomically replace existing socket with zero downtime.""" - # Sanitise and pre-verify socket path - path = os.path.abspath(path) - folder = os.path.dirname(path) - if not os.path.isdir(folder): - raise FileNotFoundError(f"Socket folder does not exist: {folder}") - try: - if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): - raise FileExistsError(f"Existing file is not a socket: {path}") - except FileNotFoundError: - pass - # Create new socket with a random temporary name - tmp_path = f"{path}.{secrets.token_urlsafe()}" - sock = socket.socket(socket.AF_UNIX) - try: - # Critical section begins (filename races) - sock.bind(tmp_path) - try: - os.chmod(tmp_path, mode) - # Start listening before rename to avoid connection failures - sock.listen(backlog) - os.rename(tmp_path, path) - except: # noqa: E722 - try: - os.unlink(tmp_path) - finally: - raise - except: # noqa: E722 - try: - sock.close() - finally: - raise - return sock - - -def remove_unix_socket(path: Optional[str]) -> None: - """Remove dead unix socket during server exit.""" - if not path: - return - try: - if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): - # Is it actually dead (doesn't belong to a new server instance)? - with socket.socket(socket.AF_UNIX) as testsock: - try: - testsock.connect(path) - except ConnectionRefusedError: - os.unlink(path) - except FileNotFoundError: - pass - - -def serve_single(server_settings): - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - - if not server_settings.get("run_async"): - # create new event_loop after fork - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - server_settings["loop"] = loop - - trigger_events(main_start, server_settings["loop"]) - serve(**server_settings) - trigger_events(main_stop, server_settings["loop"]) - - server_settings["loop"].close() - - -def serve_multiple(server_settings, workers): - """Start multiple server processes simultaneously. Stop on interrupt - and terminate signals, and drain connections when complete. - - :param server_settings: kw arguments to be passed to the serve function - :param workers: number of workers to launch - :param stop_event: if provided, is used as a stop signal - :return: - """ - server_settings["reuse_port"] = True - server_settings["run_multiple"] = True - - main_start = server_settings.pop("main_start", None) - main_stop = server_settings.pop("main_stop", None) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - trigger_events(main_start, loop) - - # Create a listening socket or use the one in settings - sock = server_settings.get("sock") - unix = server_settings["unix"] - backlog = server_settings["backlog"] - if unix: - sock = bind_unix_socket(unix, backlog=backlog) - server_settings["unix"] = unix - if sock is None: - sock = bind_socket( - server_settings["host"], server_settings["port"], backlog=backlog - ) - sock.set_inheritable(True) - server_settings["sock"] = sock - server_settings["host"] = None - server_settings["port"] = None - - processes = [] - - def sig_handler(signal, frame): - logger.info("Received signal %s. Shutting down.", Signals(signal).name) - for process in processes: - os.kill(process.pid, SIGTERM) - - signal_func(SIGINT, lambda s, f: sig_handler(s, f)) - signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) - mp = multiprocessing.get_context("fork") - - for _ in range(workers): - process = mp.Process(target=serve, kwargs=server_settings) - process.daemon = True - process.start() - processes.append(process) - - for process in processes: - process.join() - - # the above processes will block this until they're stopped - for process in processes: - process.terminate() - - trigger_events(main_stop, loop) - - sock.close() - loop.close() - remove_unix_socket(unix) diff --git a/sanic/server/__init__.py b/sanic/server/__init__.py new file mode 100644 index 0000000000..8e26dcd021 --- /dev/null +++ b/sanic/server/__init__.py @@ -0,0 +1,26 @@ +import asyncio + +from sanic.models.server_types import ConnInfo, Signal +from sanic.server.async_server import AsyncioServer +from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.runners import serve, serve_multiple, serve_single + + +try: + import uvloop # type: ignore + + if not isinstance(asyncio.get_event_loop_policy(), uvloop.EventLoopPolicy): + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +__all__ = ( + "AsyncioServer", + "ConnInfo", + "HttpProtocol", + "Signal", + "serve", + "serve_multiple", + "serve_single", +) diff --git a/sanic/server/async_server.py b/sanic/server/async_server.py new file mode 100644 index 0000000000..33b8b4c0f9 --- /dev/null +++ b/sanic/server/async_server.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import asyncio + +from sanic.exceptions import SanicException + + +class AsyncioServer: + """ + Wraps an asyncio server with functionality that might be useful to + a user who needs to manage the server lifecycle manually. + """ + + __slots__ = ("app", "connections", "loop", "serve_coro", "server", "init") + + def __init__( + self, + app, + loop, + serve_coro, + connections, + ): + # Note, Sanic already called "before_server_start" events + # before this helper was even created. So we don't need it here. + self.app = app + self.connections = connections + self.loop = loop + self.serve_coro = serve_coro + self.server = None + self.init = False + + def startup(self): + """ + Trigger "before_server_start" events + """ + self.init = True + return self.app._startup() + + def before_start(self): + """ + Trigger "before_server_start" events + """ + return self._server_event("init", "before") + + def after_start(self): + """ + Trigger "after_server_start" events + """ + return self._server_event("init", "after") + + def before_stop(self): + """ + Trigger "before_server_stop" events + """ + return self._server_event("shutdown", "before") + + def after_stop(self): + """ + Trigger "after_server_stop" events + """ + return self._server_event("shutdown", "after") + + def is_serving(self) -> bool: + if self.server: + return self.server.is_serving() + return False + + def wait_closed(self): + if self.server: + return self.server.wait_closed() + + def close(self): + if self.server: + self.server.close() + coro = self.wait_closed() + task = asyncio.ensure_future(coro, loop=self.loop) + return task + + def start_serving(self): + if self.server: + try: + return self.server.start_serving() + except AttributeError: + raise NotImplementedError( + "server.start_serving not available in this version " + "of asyncio or uvloop." + ) + + def serve_forever(self): + if self.server: + try: + return self.server.serve_forever() + except AttributeError: + raise NotImplementedError( + "server.serve_forever not available in this version " + "of asyncio or uvloop." + ) + + def _server_event(self, concern: str, action: str): + if not self.init: + raise SanicException( + "Cannot dispatch server event without " + "first running server.startup()" + ) + return self.app._server_event(concern, action, loop=self.loop) + + def __await__(self): + """ + Starts the asyncio server, returns AsyncServerCoro + """ + task = asyncio.ensure_future(self.serve_coro) + while not task.done(): + yield + self.server = task.result() + return self diff --git a/sanic/server/events.py b/sanic/server/events.py new file mode 100644 index 0000000000..3b71281d9e --- /dev/null +++ b/sanic/server/events.py @@ -0,0 +1,16 @@ +from inspect import isawaitable +from typing import Any, Callable, Iterable, Optional + + +def trigger_events(events: Optional[Iterable[Callable[..., Any]]], loop): + """ + Trigger event callbacks (functions or async) + + :param events: one or more sync or async functions to execute + :param loop: event loop + """ + if events: + for event in events: + result = event(loop) + if isawaitable(result): + loop.run_until_complete(result) diff --git a/sanic/server/protocols/__init__.py b/sanic/server/protocols/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sanic/server/protocols/base_protocol.py b/sanic/server/protocols/base_protocol.py new file mode 100644 index 0000000000..63d4bfb5b7 --- /dev/null +++ b/sanic/server/protocols/base_protocol.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + + +if TYPE_CHECKING: + from sanic.app import Sanic + +import asyncio + +from asyncio import CancelledError +from asyncio.transports import Transport +from time import monotonic as current_time + +from sanic.log import error_logger +from sanic.models.server_types import ConnInfo, Signal + + +class SanicProtocol(asyncio.Protocol): + __slots__ = ( + "app", + # event loop, connection + "loop", + "transport", + "connections", + "conn_info", + "signal", + "_can_write", + "_time", + "_task", + "_unix", + "_data_received", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + unix=None, + **kwargs, + ): + asyncio.set_event_loop(loop) + self.loop = loop + self.app: Sanic = app + self.signal = signal or Signal() + self.transport: Optional[Transport] = None + self.connections = connections if connections is not None else set() + self.conn_info: Optional[ConnInfo] = None + self._can_write = asyncio.Event() + self._can_write.set() + self._unix = unix + self._time = 0.0 # type: float + self._task = None # type: Optional[asyncio.Task] + self._data_received = asyncio.Event() + + @property + def ctx(self): + if self.conn_info is not None: + return self.conn_info.ctx + else: + return None + + async def send(self, data): + """ + Generic data write implementation with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + self.transport.write(data) + self._time = current_time() + + async def receive_more(self): + """ + Wait until more data is received into the Server protocol's buffer + """ + self.transport.resume_reading() + self._data_received.clear() + await self._data_received.wait() + + def close(self, timeout: Optional[float] = None): + """ + Attempt close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + if timeout is None: + timeout = self.app.config.GRACEFUL_SHUTDOWN_TIMEOUT + self.loop.call_later(timeout, self.abort) + + def abort(self): + """ + Force close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.abort() + self.transport = None + + # asyncio.Protocol API Callbacks # + # ------------------------------ # + def connection_made(self, transport): + """ + Generic connection-made, with no connection_task, and no recv_buffer. + Override this for protocol-specific connection implementations. + """ + try: + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except BaseException: + error_logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data: bytes): + try: + self._time = current_time() + if not data: + return self.close() + + if self._data_received: + self._data_received.set() + except BaseException: + error_logger.exception("protocol.data_received") diff --git a/sanic/server/protocols/http_protocol.py b/sanic/server/protocols/http_protocol.py new file mode 100644 index 0000000000..409f5e4b2f --- /dev/null +++ b/sanic/server/protocols/http_protocol.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from sanic.touchup.meta import TouchUpMeta + + +if TYPE_CHECKING: + from sanic.app import Sanic + +from asyncio import CancelledError +from time import monotonic as current_time + +from sanic.exceptions import RequestTimeout, ServiceUnavailable +from sanic.http import Http, Stage +from sanic.log import error_logger, logger +from sanic.models.server_types import ConnInfo +from sanic.request import Request +from sanic.server.protocols.base_protocol import SanicProtocol + + +class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta): + """ + This class provides implements the HTTP 1.1 protocol on top of our + Sanic Server transport + """ + + __touchup__ = ( + "send", + "connection_task", + ) + __slots__ = ( + # request params + "request", + # request config + "request_handler", + "request_timeout", + "response_timeout", + "keep_alive_timeout", + "request_max_size", + "request_class", + "error_handler", + # enable or disable access log purpose + "access_log", + # connection management + "state", + "url", + "_handler_task", + "_http", + "_exception", + "recv_buffer", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + state=None, + unix=None, + **kwargs, + ): + super().__init__( + loop=loop, + app=app, + signal=signal, + connections=connections, + unix=unix, + ) + self.url = None + self.request: Optional[Request] = None + self.access_log = self.app.config.ACCESS_LOG + self.request_handler = self.app.handle_request + self.error_handler = self.app.error_handler + self.request_timeout = self.app.config.REQUEST_TIMEOUT + self.response_timeout = self.app.config.RESPONSE_TIMEOUT + self.keep_alive_timeout = self.app.config.KEEP_ALIVE_TIMEOUT + self.request_max_size = self.app.config.REQUEST_MAX_SIZE + self.request_class = self.app.request_class or Request + self.state = state if state else {} + if "requests_count" not in self.state: + self.state["requests_count"] = 0 + self._exception = None + + def _setup_connection(self): + self._http = Http(self) + self._time = current_time() + self.check_timeouts() + + async def connection_task(self): # no cov + """ + Run a HTTP connection. + + Timeouts and some additional error handling occur here, while most of + everything else happens in class Http or in code called from there. + """ + try: + self._setup_connection() + await self.app.dispatch( + "http.lifecycle.begin", + inline=True, + context={"conn_info": self.conn_info}, + ) + await self._http.http1() + except CancelledError: + pass + except Exception: + error_logger.exception("protocol.connection_task uncaught") + finally: + if ( + self.app.debug + and self._http + and self.transport + and not self._http.upgrade_websocket + ): + ip = self.transport.get_extra_info("peername") + error_logger.error( + "Connection lost before response written" + f" @ {ip} {self._http.request}" + ) + self._http = None + self._task = None + try: + self.close() + except BaseException: + error_logger.exception("Closing failed") + finally: + await self.app.dispatch( + "http.lifecycle.complete", + inline=True, + context={"conn_info": self.conn_info}, + ) + # Important to keep this Ellipsis here for the TouchUp module + ... + + def check_timeouts(self): + """ + Runs itself periodically to enforce any expired timeouts. + """ + try: + if not self._task: + return + duration = current_time() - self._time + stage = self._http.stage + if stage is Stage.IDLE and duration > self.keep_alive_timeout: + logger.debug("KeepAlive Timeout. Closing connection.") + elif stage is Stage.REQUEST and duration > self.request_timeout: + logger.debug("Request Timeout. Closing connection.") + self._http.exception = RequestTimeout("Request Timeout") + elif stage is Stage.HANDLER and self._http.upgrade_websocket: + logger.debug("Handling websocket. Timeouts disabled.") + return + elif ( + stage in (Stage.HANDLER, Stage.RESPONSE, Stage.FAILED) + and duration > self.response_timeout + ): + logger.debug("Response Timeout. Closing connection.") + self._http.exception = ServiceUnavailable("Response Timeout") + else: + interval = ( + min( + self.keep_alive_timeout, + self.request_timeout, + self.response_timeout, + ) + / 2 + ) + self.loop.call_later(max(0.1, interval), self.check_timeouts) + return + self._task.cancel() + except Exception: + error_logger.exception("protocol.check_timeouts") + + async def send(self, data): # no cov + """ + Writes HTTP data with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + await self.app.dispatch( + "http.lifecycle.send", + inline=True, + context={"data": data}, + ) + self.transport.write(data) + self._time = current_time() + + def close_if_idle(self) -> bool: + """ + Close the connection if a request is not being sent or received + + :return: boolean - True if closed, false if staying open + """ + if self._http is None or self._http.stage is Stage.IDLE: + self.close() + return True + return False + + # -------------------------------------------- # + # Only asyncio.Protocol callbacks below this + # -------------------------------------------- # + + def connection_made(self, transport): + """ + HTTP-protocol-specific new connection handler + """ + try: + # TODO: Benchmark to find suitable write buffer limits + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self._task = self.loop.create_task(self.connection_task()) + self.recv_buffer = bytearray() + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def data_received(self, data: bytes): + + try: + self._time = current_time() + if not data: + return self.close() + self.recv_buffer += data + + if ( + len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE + and self.transport + ): + self.transport.pause_reading() + + if self._data_received: + self._data_received.set() + except Exception: + error_logger.exception("protocol.data_received") diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py new file mode 100644 index 0000000000..457f1cd065 --- /dev/null +++ b/sanic/server/protocols/websocket_protocol.py @@ -0,0 +1,164 @@ +from typing import TYPE_CHECKING, Optional, Sequence + +from websockets.connection import CLOSED, CLOSING, OPEN +from websockets.server import ServerConnection + +from sanic.exceptions import ServerError +from sanic.log import error_logger +from sanic.server import HttpProtocol + +from ..websockets.impl import WebsocketImplProtocol + + +if TYPE_CHECKING: + from websockets import http11 + + +class WebSocketProtocol(HttpProtocol): + + websocket: Optional[WebsocketImplProtocol] + websocket_timeout: float + websocket_max_size = Optional[int] + websocket_ping_interval = Optional[float] + websocket_ping_timeout = Optional[float] + + def __init__( + self, + *args, + websocket_timeout: float = 10.0, + websocket_max_size: Optional[int] = None, + websocket_max_queue: Optional[int] = None, # max_queue is deprecated + websocket_read_limit: Optional[int] = None, # read_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit deprecated + websocket_ping_interval: Optional[float] = 20.0, + websocket_ping_timeout: Optional[float] = 20.0, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.websocket = None + self.websocket_timeout = websocket_timeout + self.websocket_max_size = websocket_max_size + if websocket_max_queue is not None and websocket_max_queue > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses queueing, so websocket_max_queue" + " is no longer required." + ) + ) + if websocket_read_limit is not None and websocket_read_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses read buffers, so " + "websocket_read_limit is not required." + ) + ) + if websocket_write_limit is not None and websocket_write_limit > 0: + # TODO: Reminder remove this warning in v22.3 + error_logger.warning( + DeprecationWarning( + "Websocket no longer uses write buffers, so " + "websocket_write_limit is not required." + ) + ) + self.websocket_ping_interval = websocket_ping_interval + self.websocket_ping_timeout = websocket_ping_timeout + + def connection_lost(self, exc): + if self.websocket is not None: + self.websocket.connection_lost(exc) + super().connection_lost(exc) + + def data_received(self, data): + if self.websocket is not None: + self.websocket.data_received(data) + else: + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. + super().data_received(data) + + def eof_received(self) -> Optional[bool]: + if self.websocket is not None: + return self.websocket.eof_received() + else: + return False + + def close(self, timeout: Optional[float] = None): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closing + if self.websocket is not None: + # Note, we don't want to use websocket.close() + # That is used for user's application code to send a + # websocket close packet. This is different. + self.websocket.end_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + elif self.websocket.loop is not None: + self.websocket.loop.create_task(self.websocket.close(1001)) + else: + self.websocket.end_connection(1001) + else: + return super().close_if_idle() + + async def websocket_handshake( + self, request, subprotocols=Optional[Sequence[str]] + ): + # let the websockets package do the handshake with the client + try: + if subprotocols is not None: + # subprotocols can be a set or frozenset, + # but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_conn = ServerConnection( + max_size=self.websocket_max_size, + subprotocols=subprotocols, + state=OPEN, + logger=error_logger, + ) + resp: "http11.Response" = ws_conn.accept(request) + except Exception: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise ServerError(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbody = "".join( + [ + "HTTP/1.1 ", + str(resp.status_code), + " ", + resp.reason_phrase, + "\r\n", + ] + ) + rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + if resp.body is not None: + rbody += f"\r\n{resp.body}\r\n\r\n" + else: + rbody += "\r\n" + await super().send(rbody.encode()) + else: + raise ServerError(resp.body, resp.status_code) + self.websocket = WebsocketImplProtocol( + ws_conn, + ping_interval=self.websocket_ping_interval, + ping_timeout=self.websocket_ping_timeout, + close_timeout=self.websocket_timeout, + ) + loop = ( + request.transport.loop + if hasattr(request, "transport") + and hasattr(request.transport, "loop") + else None + ) + await self.websocket.connection_made(self, loop=loop) + return self.websocket diff --git a/sanic/server/runners.py b/sanic/server/runners.py new file mode 100644 index 0000000000..f0bebb030c --- /dev/null +++ b/sanic/server/runners.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from ssl import SSLContext +from typing import TYPE_CHECKING, Dict, Optional, Type, Union + +from sanic.config import Config +from sanic.server.events import trigger_events + + +if TYPE_CHECKING: + from sanic.app import Sanic + +import asyncio +import multiprocessing +import os +import socket + +from functools import partial +from signal import SIG_IGN, SIGINT, SIGTERM, Signals +from signal import signal as signal_func + +from sanic.compat import OS_IS_WINDOWS, ctrlc_workaround_for_windows +from sanic.log import error_logger, logger +from sanic.models.server_types import Signal +from sanic.server.async_server import AsyncioServer +from sanic.server.protocols.http_protocol import HttpProtocol +from sanic.server.socket import ( + bind_socket, + bind_unix_socket, + remove_unix_socket, +) + + +def serve( + host, + port, + app: Sanic, + ssl: Optional[SSLContext] = None, + sock: Optional[socket.socket] = None, + unix: Optional[str] = None, + reuse_port: bool = False, + loop=None, + protocol: Type[asyncio.Protocol] = HttpProtocol, + backlog: int = 100, + register_sys_signals: bool = True, + run_multiple: bool = False, + run_async: bool = False, + connections=None, + signal=Signal(), + state=None, + asyncio_server_kwargs=None, +): + """Start asynchronous HTTP Server on an individual process. + + :param host: Address to host on + :param port: Port to host on + :param before_start: function to be executed before the server starts + listening. Takes arguments `app` instance and `loop` + :param after_start: function to be executed after the server starts + listening. Takes arguments `app` instance and `loop` + :param before_stop: function to be executed when a stop signal is + received before it is respected. Takes arguments + `app` instance and `loop` + :param after_stop: function to be executed when a stop signal is + received after it is respected. Takes arguments + `app` instance and `loop` + :param ssl: SSLContext + :param sock: Socket for the server to accept connections from + :param unix: Unix socket to listen on instead of TCP port + :param reuse_port: `True` for multiple workers + :param loop: asyncio compatible event loop + :param run_async: bool: Do not create a new event loop for the server, + and return an AsyncServer object rather than running it + :param asyncio_server_kwargs: key-value args for asyncio/uvloop + create_server method + :return: Nothing + """ + if not run_async and not loop: + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if app.debug: + loop.set_debug(app.debug) + + app.asgi = False + + connections = connections if connections is not None else set() + protocol_kwargs = _build_protocol_kwargs(protocol, app.config) + server = partial( + protocol, + loop=loop, + connections=connections, + signal=signal, + app=app, + state=state, + unix=unix, + **protocol_kwargs, + ) + asyncio_server_kwargs = ( + asyncio_server_kwargs if asyncio_server_kwargs else {} + ) + # UNIX sockets are always bound by us (to preserve semantics between modes) + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_coroutine = loop.create_server( + server, + None if sock else host, + None if sock else port, + ssl=ssl, + reuse_port=reuse_port, + sock=sock, + backlog=backlog, + **asyncio_server_kwargs, + ) + + if run_async: + return AsyncioServer( + app=app, + loop=loop, + serve_coro=server_coroutine, + connections=connections, + ) + + loop.run_until_complete(app._startup()) + loop.run_until_complete(app._server_event("init", "before")) + + try: + http_server = loop.run_until_complete(server_coroutine) + except BaseException: + error_logger.exception("Unable to start server") + return + + # Ignore SIGINT when run_multiple + if run_multiple: + signal_func(SIGINT, SIG_IGN) + + # Register signals for graceful termination + if register_sys_signals: + if OS_IS_WINDOWS: + ctrlc_workaround_for_windows(app) + else: + for _signal in [SIGTERM] if run_multiple else [SIGINT, SIGTERM]: + loop.add_signal_handler(_signal, app.stop) + + loop.run_until_complete(app._server_event("init", "after")) + pid = os.getpid() + try: + logger.info("Starting worker [%s]", pid) + loop.run_forever() + finally: + logger.info("Stopping worker [%s]", pid) + + # Run the on_stop function if provided + loop.run_until_complete(app._server_event("shutdown", "before")) + + # Wait for event loop to finish and all connections to drain + http_server.close() + loop.run_until_complete(http_server.wait_closed()) + + # Complete all tasks on the loop + signal.stopped = True + for connection in connections: + connection.close_if_idle() + + # Gracefully shutdown timeout. + # We should provide graceful_shutdown_timeout, + # instead of letting connection hangs forever. + # Let's roughly calcucate time. + graceful = app.config.GRACEFUL_SHUTDOWN_TIMEOUT + start_shutdown: float = 0 + while connections and (start_shutdown < graceful): + loop.run_until_complete(asyncio.sleep(0.1)) + start_shutdown = start_shutdown + 0.1 + + # Force close non-idle connection after waiting for + # graceful_shutdown_timeout + for conn in connections: + if hasattr(conn, "websocket") and conn.websocket: + conn.websocket.fail_connection(code=1001) + else: + conn.abort() + loop.run_until_complete(app._server_event("shutdown", "after")) + + remove_unix_socket(unix) + + +def serve_single(server_settings): + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + + if not server_settings.get("run_async"): + # create new event_loop after fork + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + server_settings["loop"] = loop + + trigger_events(main_start, server_settings["loop"]) + serve(**server_settings) + trigger_events(main_stop, server_settings["loop"]) + + server_settings["loop"].close() + + +def serve_multiple(server_settings, workers): + """Start multiple server processes simultaneously. Stop on interrupt + and terminate signals, and drain connections when complete. + + :param server_settings: kw arguments to be passed to the serve function + :param workers: number of workers to launch + :param stop_event: if provided, is used as a stop signal + :return: + """ + server_settings["reuse_port"] = True + server_settings["run_multiple"] = True + + main_start = server_settings.pop("main_start", None) + main_stop = server_settings.pop("main_stop", None) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + trigger_events(main_start, loop) + + # Create a listening socket or use the one in settings + sock = server_settings.get("sock") + unix = server_settings["unix"] + backlog = server_settings["backlog"] + if unix: + sock = bind_unix_socket(unix, backlog=backlog) + server_settings["unix"] = unix + if sock is None: + sock = bind_socket( + server_settings["host"], server_settings["port"], backlog=backlog + ) + sock.set_inheritable(True) + server_settings["sock"] = sock + server_settings["host"] = None + server_settings["port"] = None + + processes = [] + + def sig_handler(signal, frame): + logger.info("Received signal %s. Shutting down.", Signals(signal).name) + for process in processes: + os.kill(process.pid, SIGTERM) + + signal_func(SIGINT, lambda s, f: sig_handler(s, f)) + signal_func(SIGTERM, lambda s, f: sig_handler(s, f)) + mp = multiprocessing.get_context("fork") + + for _ in range(workers): + process = mp.Process(target=serve, kwargs=server_settings) + process.daemon = True + process.start() + processes.append(process) + + for process in processes: + process.join() + + # the above processes will block this until they're stopped + for process in processes: + process.terminate() + + trigger_events(main_stop, loop) + + sock.close() + loop.close() + remove_unix_socket(unix) + + +def _build_protocol_kwargs( + protocol: Type[asyncio.Protocol], config: Config +) -> Dict[str, Union[int, float]]: + if hasattr(protocol, "websocket_handshake"): + return { + "websocket_max_size": config.WEBSOCKET_MAX_SIZE, + "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, + "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, + } + return {} diff --git a/sanic/server/socket.py b/sanic/server/socket.py new file mode 100644 index 0000000000..3d908306ca --- /dev/null +++ b/sanic/server/socket.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import os +import secrets +import socket +import stat + +from ipaddress import ip_address +from typing import Optional + + +def bind_socket(host: str, port: int, *, backlog=100) -> socket.socket: + """Create TCP server socket. + :param host: IPv4, IPv6 or hostname may be specified + :param port: TCP port number + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + try: # IP address: family must be specified for IPv6 at least + ip = ip_address(host) + host = str(ip) + sock = socket.socket( + socket.AF_INET6 if ip.version == 6 else socket.AF_INET + ) + except ValueError: # Hostname, may become AF_INET or AF_INET6 + sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, port)) + sock.listen(backlog) + return sock + + +def bind_unix_socket(path: str, *, mode=0o666, backlog=100) -> socket.socket: + """Create unix socket. + :param path: filesystem path + :param backlog: Maximum number of connections to queue + :return: socket.socket object + """ + """Open or atomically replace existing socket with zero downtime.""" + # Sanitise and pre-verify socket path + path = os.path.abspath(path) + folder = os.path.dirname(path) + if not os.path.isdir(folder): + raise FileNotFoundError(f"Socket folder does not exist: {folder}") + try: + if not stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + raise FileExistsError(f"Existing file is not a socket: {path}") + except FileNotFoundError: + pass + # Create new socket with a random temporary name + tmp_path = f"{path}.{secrets.token_urlsafe()}" + sock = socket.socket(socket.AF_UNIX) + try: + # Critical section begins (filename races) + sock.bind(tmp_path) + try: + os.chmod(tmp_path, mode) + # Start listening before rename to avoid connection failures + sock.listen(backlog) + os.rename(tmp_path, path) + except: # noqa: E722 + try: + os.unlink(tmp_path) + finally: + raise + except: # noqa: E722 + try: + sock.close() + finally: + raise + return sock + + +def remove_unix_socket(path: Optional[str]) -> None: + """Remove dead unix socket during server exit.""" + if not path: + return + try: + if stat.S_ISSOCK(os.stat(path, follow_symlinks=False).st_mode): + # Is it actually dead (doesn't belong to a new server instance)? + with socket.socket(socket.AF_UNIX) as testsock: + try: + testsock.connect(path) + except ConnectionRefusedError: + os.unlink(path) + except FileNotFoundError: + pass diff --git a/sanic/server/websockets/__init__.py b/sanic/server/websockets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py new file mode 100644 index 0000000000..c53a65a58d --- /dev/null +++ b/sanic/server/websockets/connection.py @@ -0,0 +1,82 @@ +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + MutableMapping, + Optional, + Union, +) + + +ASIMessage = MutableMapping[str, Any] + + +class WebSocketConnection: + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. + """ + + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, + ) -> None: + self._send = send + self._receive = receive + self._subprotocols = subprotocols or [] + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} + + if isinstance(data, bytes): + message.update({"bytes": data}) + else: + message.update({"text": str(data)}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + + return None + + receive = recv + + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + + await self._send( + { + "type": "websocket.accept", + "subprotocol": subprotocol, + } + ) + + async def close(self, code: int = 1000, reason: str = "") -> None: + pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py new file mode 100644 index 0000000000..fef27db11e --- /dev/null +++ b/sanic/server/websockets/frame.py @@ -0,0 +1,294 @@ +import asyncio +import codecs + +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from websockets.frames import Frame, Opcode +from websockets.typing import Data + +from sanic.exceptions import ServerError + + +if TYPE_CHECKING: + from .impl import WebsocketImplProtocol + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + + __slots__ = ( + "protocol", + "read_mutex", + "write_mutex", + "message_complete", + "message_fetched", + "get_in_progress", + "decoder", + "completed_queue", + "chunks", + "chunks_queue", + "paused", + "get_id", + "put_id", + ) + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue( + maxsize=1 + ) # type: asyncio.Queue[Data] + + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder = None + + # Buffer data from frames belonging to the same message. + self.chunks = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue = None + + # Flag to indicate we've paused the protocol + self.paused = False + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one task can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + try: + await asyncio.wait_for( + self.message_complete.wait(), timeout=timeout + ) + except asyncio.TimeoutError: + ... + finally: + completed = self.message_complete.is_set() + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + if not self.message_complete.is_set(): + return None + + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is here + # as a failsafe + raise ServerError( + "Websocket get() found a message when " + "state was already fetched." + ) + self.message_fetched.set() + self.chunks = [] + # this should already be None, but set it here for safety + self.chunks_queue = None + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + if self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe + raise ServerError( + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." + ) + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one task can get here + for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + if not self.get_in_progress: + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." + ) + self.get_in_progress = False + if not self.message_complete.is_set(): + # This should be guarded against with the read_mutex, + # exception is here as a failsafe + raise ServerError( + "Websocket frame assembler chunks queue ended before " + "message was complete." + ) + self.message_complete.clear() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is + # here as a failsafe + raise ServerError( + "Websocket get_iter() found a message when state was " + "already fetched." + ) + + self.message_fetched.set() + # this should already be empty, but set it here for safety + self.chunks = [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + # nobody is waiting for this frame, so try to pause subsequent + # frames at the protocol level + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + if self.message_complete.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when a message was " + "already in its chamber." + ) + self.message_complete.set() # Signal to get() it can serve the + if self.message_fetched.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when the previous " + "message was not yet fetched." + ) + + # Allow get() to run and eventually set the event. + await self.message_fetched.wait() + self.message_fetched.clear() + self.decoder = None diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py new file mode 100644 index 0000000000..ed0d7fed81 --- /dev/null +++ b/sanic/server/websockets/impl.py @@ -0,0 +1,834 @@ +import asyncio +import random +import struct + +from typing import ( + AsyncIterator, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +from websockets.connection import CLOSED, CLOSING, OPEN, Event +from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.frames import Frame, Opcode +from websockets.server import ServerConnection +from websockets.typing import Data + +from sanic.log import error_logger, logger +from sanic.server.protocols.base_protocol import SanicProtocol + +from ...exceptions import ServerError, WebsocketClosed +from .frame import WebsocketFrameAssembler + + +class WebsocketImplProtocol: + connection: ServerConnection + io_proto: Optional[SanicProtocol] + loop: Optional[asyncio.AbstractEventLoop] + max_queue: int + close_timeout: float + ping_interval: Optional[float] + ping_timeout: Optional[float] + assembler: WebsocketFrameAssembler + # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] + conn_mutex: asyncio.Lock + recv_lock: asyncio.Lock + recv_cancel: Optional[asyncio.Future] + process_event_mutex: asyncio.Lock + can_pause: bool + # Optional[asyncio.Future[None]] + data_finished_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] + keepalive_ping_task: Optional[asyncio.Task] + auto_closer_task: Optional[asyncio.Task] + + def __init__( + self, + connection, + max_queue=None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: float = 10, + loop=None, + ): + self.connection = connection + self.io_proto = None + self.loop = None + self.max_queue = max_queue + self.close_timeout = close_timeout + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.assembler = WebsocketFrameAssembler(self) + self.pings = {} + self.conn_mutex = asyncio.Lock() + self.recv_lock = asyncio.Lock() + self.recv_cancel = None + self.process_event_mutex = asyncio.Lock() + self.data_finished_fut = None + self.can_pause = True + self.pause_frame_fut = None + self.keepalive_ping_task = None + self.auto_closer_task = None + self.connection_lost_waiter = None + + @property + def subprotocol(self): + return self.connection.subprotocol + + def pause_frames(self): + if not self.can_pause: + return False + if self.pause_frame_fut: + logger.debug("Websocket connection already paused.") + return False + if (not self.loop) or (not self.io_proto): + return False + if self.io_proto.transport: + self.io_proto.transport.pause_reading() + self.pause_frame_fut = self.loop.create_future() + logger.debug("Websocket connection paused.") + return True + + def resume_frames(self): + if not self.pause_frame_fut: + logger.debug("Websocket connection not paused.") + return False + if (not self.loop) or (not self.io_proto): + logger.debug( + "Websocket attempting to resume reading frames, " + "but connection is gone." + ) + return False + if self.io_proto.transport: + self.io_proto.transport.resume_reading() + self.pause_frame_fut.set_result(None) + self.pause_frame_fut = None + logger.debug("Websocket connection unpaused.") + return True + + async def connection_made( + self, + io_proto: SanicProtocol, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if not loop: + try: + loop = getattr(io_proto, "loop") + except AttributeError: + loop = asyncio.get_event_loop() + if not loop: + # This catch is for mypy type checker + # to assert loop is not None here. + raise ServerError("Connection received with no asyncio loop.") + if self.auto_closer_task: + raise ServerError( + "Cannot call connection_made more than once " + "on a websocket connection." + ) + self.loop = loop + self.io_proto = io_proto + self.connection_lost_waiter = self.loop.create_future() + self.data_finished_fut = asyncio.shield(self.loop.create_future()) + + if self.ping_interval: + self.keepalive_ping_task = asyncio.create_task( + self.keepalive_ping() + ) + self.auto_closer_task = asyncio.create_task( + self.auto_close_connection() + ) + + async def wait_for_connection_lost(self, timeout=None) -> bool: + """ + Wait until the TCP connection is closed or ``timeout`` elapses. + If timeout is None, wait forever. + Recommend you should pass in self.close_timeout as timeout + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if not self.connection_lost_waiter: + return False + if self.connection_lost_waiter.done(): + return True + else: + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), timeout + ) + return True + except asyncio.TimeoutError: + # Re-check self.connection_lost_waiter.done() synchronously + # because connection_lost() could run between the moment the + # timeout occurs and the moment this coroutine resumes running + return self.connection_lost_waiter.done() + + async def process_events(self, events: Sequence[Event]) -> None: + """ + Process a list of incoming events. + """ + # Wrapped in a mutex lock, to prevent other incoming events + # from processing at the same time + async with self.process_event_mutex: + for event in events: + if not isinstance(event, Frame): + # Event is not a frame. Ignore it. + continue + if event.opcode == Opcode.PONG: + await self.process_pong(event) + elif event.opcode == Opcode.CLOSE: + if self.recv_cancel: + self.recv_cancel.cancel() + else: + await self.assembler.put(event) + + async def process_pong(self, frame: Frame) -> None: + if frame.data in self.pings: + # Acknowledge all pings up to the one matching this pong. + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # noqa + raise ServerError("ping_id is not in self.pings") + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when auto_close_connection() cancels keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + ping_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for(ping_waiter, self.ping_timeout) + except asyncio.TimeoutError: + error_logger.warning( + "Websocket timed out waiting for pong" + ) + self.fail_connection(1011) + break + except asyncio.CancelledError: + # It is expected for this task to be cancelled during during + # normal operation, when the connection is closed. + logger.debug("Websocket keepalive ping task was cancelled.") + except (ConnectionClosed, WebsocketClosed): + logger.debug("Websocket closed. Keepalive ping task exiting.") + except Exception as e: + error_logger.warning( + "Unexpected exception in websocket keepalive ping task." + ) + logger.debug(str(e)) + + def _force_disconnect(self) -> bool: + """ + Internal methdod used by end_connection and fail_connection + only when the graceful auto-closer cannot be used + """ + if self.auto_closer_task and not self.auto_closer_task.done(): + self.auto_closer_task.cancel() + if self.data_finished_fut and not self.data_finished_fut.done(): + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.keepalive_ping_task and not self.keepalive_ping_task.done(): + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + if self.loop and self.io_proto and self.io_proto.transport: + self.io_proto.transport.close() + self.loop.call_later( + self.close_timeout, self.io_proto.transport.abort + ) + # We were never open, or already closed + return True + + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: + """ + Fail the WebSocket Connection + This requires: + 1. Stopping all processing of incoming data, which means cancelling + pausing the underlying io protocol. The close code will be 1006 + unless a close frame was received earlier. + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + 3. Closing the connection. :meth:`auto_close_connection` takes care + of this. + (The specification describes these steps in the opposite order.) + """ + if self.io_proto and self.io_proto.transport: + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # ut can be called when the transport is already paused or closed + self.io_proto.transport.pause_reading() + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not draining the write buffer is acceptable in this context. + + # clear the send buffer + _ = self.connection.data_to_send() + # If we're not already CLOSED or CLOSING, then send the close. + if self.connection.state is OPEN: + if code in (1000, 1001): + self.connection.send_close(code, reason) + else: + self.connection.fail(code, reason) + try: + data_to_send = self.connection.data_to_send() + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... + if code == 1006: + # Special case: 1006 consider the transport already closed + self.connection.state = CLOSED + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + return self._force_disconnect() + return False + + def end_connection(self, code=1000, reason=""): + # This is like slightly more graceful form of fail_connection + # Use this instead of close() when you need an immediate + # close and cannot await websocket.close() handshake. + + if code == 1006 or not self.io_proto or not self.io_proto.transport: + return self.fail_connection(code, reason) + + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + if self.connection.state == OPEN: + data_to_send = self.connection.data_to_send() + self.connection.send_close(code, reason) + data_to_send.extend(self.connection.data_to_send()) + try: + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + # But that doesn't matter at this point + ... + if self.data_finished_fut and not self.data_finished_fut.done(): + # We have the ability to signal the auto-closer + # try to trigger it to auto-close the connection + self.data_finished_fut.cancel() + self.data_finished_fut = None + if (not self.auto_closer_task) or self.auto_closer_task.done(): + # Auto-closer is not running, do force disconnect + return self._force_disconnect() + return False + + async def auto_close_connection(self) -> None: + """ + Close the WebSocket Connection + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + """ + try: + # Wait for the data transfer phase to complete. + if self.data_finished_fut: + try: + await self.data_finished_fut + logger.debug( + "Websocket task finished. Closing the connection." + ) + except asyncio.CancelledError: + # Cancelled error is called when data phase is cancelled + # if an error occurred or the client closed the connection + logger.debug( + "Websocket handler cancelled. Closing the connection." + ) + + # Cancel the keepalive ping task. + if self.keepalive_ping_task: + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + + # Half-close the TCP connection if possible (when there's no TLS). + if ( + self.io_proto + and self.io_proto.transport + and self.io_proto.transport.can_write_eof() + ): + logger.debug("Websocket half-closing TCP connection") + self.io_proto.transport.write_eof() + if self.connection_lost_waiter: + if await self.wait_for_connection_lost(timeout=0): + return + except asyncio.CancelledError: + ... + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + if (not self.io_proto) or (not self.io_proto.transport): + # we were never open, or done. Can't do any finalization. + return + elif ( + self.connection_lost_waiter + and self.connection_lost_waiter.done() + ): + # connection confirmed closed already, proceed to abort waiter + ... + elif self.io_proto.transport.is_closing(): + # Connection is already closing (due to half-close above) + # proceed to abort waiter + ... + else: + self.io_proto.transport.close() + if not self.connection_lost_waiter: + # Our connection monitor task isn't running. + try: + await asyncio.sleep(self.close_timeout) + except asyncio.CancelledError: + ... + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + else: + if await self.wait_for_connection_lost( + timeout=self.close_timeout + ): + # Connection aborted before the timeout expired. + return + error_logger.warning( + "Timeout waiting for TCP connection to close. Aborting" + ) + if self.io_proto and self.io_proto.transport: + self.io_proto.transport.abort() + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + They'll never receive a pong once the connection is closed. + """ + if self.connection.state is not CLOSED: + raise ServerError( + "Webscoket about_pings should only be called " + "after connection state is changed to CLOSED" + ) + + for ping in self.pings.values(): + ping.set_exception(ConnectionClosedError(None, None)) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + This is a websocket-protocol level close. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ + if code == 1006: + self.fail_connection(code, reason) + return + async with self.conn_mutex: + if self.connection.state is OPEN: + self.connection.send_close(code, reason) + data_to_send = self.connection.data_to_send() + await self.send_data(data_to_send) + + async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Receive the next message. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + If ``timeout`` is ``None``, block until a message is received. Else, + if no message is received within ``timeout`` seconds, return ``None``. + Set ``timeout`` to ``0`` to check if a message was already received. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises asyncio.CancelledError: if the websocket closes while waiting + :raises ServerError: if two tasks call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv while another task is " + "already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + self.recv_cancel = asyncio.Future() + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + else: + self.recv_cancel.cancel() + return done_task.result() + finally: + self.recv_cancel = None + self.recv_lock.release() + + async def recv_burst(self, max_recv=256) -> Sequence[Data]: + """ + Receive the messages which have arrived since last checking. + Return a :class:`list` containing :class:`str` for a text frame + and :class:`bytes` for a binary frame. + When the end of the message stream is reached, :meth:`recv_burst` + raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a + normal connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises ServerError: if two tasks call :meth:`recv_burst` or + :meth:`recv_streaming` concurrently + """ + + if self.recv_lock.locked(): + raise ServerError( + "cannot call recv_burst while another task is already waiting " + "for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + messages = [] + try: + # Prevent pausing the transport when we're + # receiving a burst of messages + self.can_pause = False + self.recv_cancel = asyncio.Future() + while True: + done, pending = await asyncio.wait( + (self.recv_cancel, self.assembler.get(timeout=0)), + return_when=asyncio.FIRST_COMPLETED, + ) + done_task = next(iter(done)) + if done_task is self.recv_cancel: + # recv_burst was cancelled + for p in pending: + p.cancel() + raise asyncio.CancelledError() + m = done_task.result() + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) + self.recv_cancel.cancel() + finally: + self.recv_cancel = None + self.can_pause = True + self.recv_lock.release() + return messages + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + Return an iterator of :class:`str` for a text frame and :class:`bytes` + for a binary frame. The iterator should be exhausted, or else the + connection will become unusable. + With the exception of the return value, :meth:`recv_streaming` behaves + like :meth:`recv`. + """ + if self.recv_lock.locked(): + raise ServerError( + "Cannot call recv_streaming while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + if self.connection.state is CLOSED: + self.recv_lock.release() + raise WebsocketClosed( + "Cannot receive from websocket interface after it is closed." + ) + try: + cancelled = False + self.recv_cancel = asyncio.Future() + self.can_pause = False + async for m in self.assembler.get_iter(): + if self.recv_cancel.done(): + cancelled = True + break + yield m + if cancelled: + raise asyncio.CancelledError() + finally: + self.can_pause = True + self.recv_cancel = None + self.recv_lock.release() + + async def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects. In that case the message is fragmented. Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + :raises TypeError: for unsupported inputs + """ + async with self.conn_mutex: + + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot write to websocket interface after it is closed." + ) + if (not self.data_finished_fut) or self.data_finished_fut.done(): + raise ServerError( + "Cannot write to websocket interface after it is finished." + ) + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + self.connection.send_text(message.encode("utf-8")) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, (bytes, bytearray, memoryview)): + self.connection.send_binary(message) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, Mapping): + # Catch a common mistake -- passing a dict to send(). + raise TypeError("data is a dict-like object") + + elif isinstance(message, Iterable): + # Fragmented message -- regular iterator. + raise NotImplementedError( + "Fragmented websocket messages are not supported." + ) + else: + raise TypeError("Websocket data must be bytes, str.") + + async def ping(self, data: Optional[Data] = None) -> asyncio.Future: + """ + Send a ping. + Return an :class:`~asyncio.Future` that will be resolved when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + await pong_event = ws.ping() + await pong_event # only if you want to wait for the pong + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + raise WebsocketClosed( + "Cannot send a ping when the websocket interface " + "is closed." + ) + if (not self.io_proto) or (not self.io_proto.loop): + raise ServerError( + "Cannot send a ping when the websocket has no I/O " + "protocol attached." + ) + if data is not None: + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError( + "already waiting for a pong with the same data" + ) + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.io_proto.loop.create_future() + + self.connection.send_ping(data) + await self.send_data(self.connection.data_to_send()) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + An unsolicited pong may serve as a unidirectional heartbeat. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + if self.connection.state in (CLOSED, CLOSING): + # Cannot send pong after transport is shutting down + return + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) + self.connection.send_pong(data) + await self.send_data(self.connection.data_to_send()) + + async def send_data(self, data_to_send): + for data in data_to_send: + if data: + await self.io_proto.send(data) + else: + # Send an EOF - We don't actually send it, + # just trigger to autoclose the connection + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + async def async_data_received(self, data_to_send, events_to_process): + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + # receiving data can generate data to send (eg, pong for a ping) + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + def data_received(self, data): + self.connection.receive_data(data) + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task( + self.async_data_received(data_to_send, events_to_process) + ) + + async def async_eof_received(self, data_to_send, events_to_process): + # receiving EOF can generate data to send + # send connection.data_to_send() + if self.connection.state in (OPEN, CLOSING) and len(data_to_send) > 0: + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + if self.recv_cancel: + self.recv_cancel.cancel() + if ( + self.auto_closer_task + and not self.auto_closer_task.done() + and self.data_finished_fut + and not self.data_finished_fut.done() + ): + # Auto-close the connection + self.data_finished_fut.set_result(None) + # Cancel the running handler if its waiting + else: + # This will fail the connection appropriately + SanicProtocol.close(self.io_proto, timeout=1.0) + + def eof_received(self) -> Optional[bool]: + self.connection.receive_eof() + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) + return False + + def connection_lost(self, exc): + """ + The WebSocket Connection is Closed. + """ + if not self.connection.state == CLOSED: + # signal to the websocket connection handler + # we've lost the connection + self.connection.fail(code=1006) + self.connection.state = CLOSED + + self.abort_pings() + if self.connection_lost_waiter: + self.connection_lost_waiter.set_result(None) diff --git a/sanic/signals.py b/sanic/signals.py index eec2a43858..2c1a704cc7 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -10,13 +10,39 @@ from sanic_routing.utils import path_to_parts # type: ignore from sanic.exceptions import InvalidSignal +from sanic.log import error_logger, logger from sanic.models.handler_types import SignalHandler -RESERVED_NAMESPACES = ( - "server", - "http", -) +RESERVED_NAMESPACES = { + "server": ( + # "server.main.start", + # "server.main.stop", + "server.init.before", + "server.init.after", + "server.shutdown.before", + "server.shutdown.after", + ), + "http": ( + "http.lifecycle.begin", + "http.lifecycle.complete", + "http.lifecycle.exception", + "http.lifecycle.handle", + "http.lifecycle.read_body", + "http.lifecycle.read_head", + "http.lifecycle.request", + "http.lifecycle.response", + "http.routing.after", + "http.routing.before", + "http.lifecycle.send", + "http.middleware.after", + "http.middleware.before", + ), +} + + +def _blank(): + ... class Signal(Route): @@ -59,8 +85,13 @@ def get( # type: ignore terms.append(extra) raise NotFound(message % tuple(terms)) + # Regex routes evaluate and can extract params directly. They are set + # on param_basket["__params__"] params = param_basket["__params__"] if not params: + # If param_basket["__params__"] does not exist, we might have + # param_basket["__matches__"], which are indexed based matches + # on path segments. They should already be cast types. params = { param.name: param_basket["__matches__"][idx] for idx, param in group.params.items() @@ -73,8 +104,18 @@ async def _dispatch( event: str, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> None: - group, handlers, params = self.get(event, condition=condition) + fail_not_found: bool = True, + reverse: bool = False, + ) -> Any: + try: + group, handlers, params = self.get(event, condition=condition) + except NotFound as e: + if fail_not_found: + raise e + else: + if self.ctx.app.debug: + error_logger.warning(str(e)) + return None events = [signal.ctx.event for signal in group] for signal_event in events: @@ -82,12 +123,19 @@ async def _dispatch( if context: params.update(context) + if not reverse: + handlers = handlers[::-1] try: for handler in handlers: if condition is None or condition == handler.__requirements__: maybe_coroutine = handler(**params) if isawaitable(maybe_coroutine): - await maybe_coroutine + retval = await maybe_coroutine + if retval: + return retval + elif maybe_coroutine: + return maybe_coroutine + return None finally: for signal_event in events: signal_event.clear() @@ -98,14 +146,23 @@ async def dispatch( *, context: Optional[Dict[str, Any]] = None, condition: Optional[Dict[str, str]] = None, - ) -> asyncio.Task: - task = self.ctx.loop.create_task( - self._dispatch( - event, - context=context, - condition=condition, - ) + fail_not_found: bool = True, + inline: bool = False, + reverse: bool = False, + ) -> Union[asyncio.Task, Any]: + dispatch = self._dispatch( + event, + context=context, + condition=condition, + fail_not_found=fail_not_found and inline, + reverse=reverse, ) + logger.debug(f"Dispatching signal: {event}") + + if inline: + return await dispatch + + task = asyncio.get_running_loop().create_task(dispatch) await asyncio.sleep(0) return task @@ -131,7 +188,9 @@ def add( # type: ignore append=True, ) # type: ignore - def finalize(self, do_compile: bool = True): + def finalize(self, do_compile: bool = True, do_optimize: bool = False): + self.add(_blank, "sanic.__signal__.__init__") + try: self.ctx.loop = asyncio.get_running_loop() except RuntimeError: @@ -140,7 +199,7 @@ def finalize(self, do_compile: bool = True): for signal in self.routes: signal.ctx.event = asyncio.Event() - return super().finalize(do_compile=do_compile) + return super().finalize(do_compile=do_compile, do_optimize=do_optimize) def _build_event_parts(self, event: str) -> Tuple[str, str, str]: parts = path_to_parts(event, self.delimiter) @@ -151,7 +210,11 @@ def _build_event_parts(self, event: str) -> Tuple[str, str, str]: ): raise InvalidSignal("Invalid signal event: %s" % event) - if parts[0] in RESERVED_NAMESPACES: + if ( + parts[0] in RESERVED_NAMESPACES + and event not in RESERVED_NAMESPACES[parts[0]] + and not (parts[2].startswith("<") and parts[2].endswith(">")) + ): raise InvalidSignal( "Cannot declare reserved signal event: %s" % event ) diff --git a/sanic/touchup/__init__.py b/sanic/touchup/__init__.py new file mode 100644 index 0000000000..6fe208abb3 --- /dev/null +++ b/sanic/touchup/__init__.py @@ -0,0 +1,8 @@ +from .meta import TouchUpMeta +from .service import TouchUp + + +__all__ = ( + "TouchUp", + "TouchUpMeta", +) diff --git a/sanic/touchup/meta.py b/sanic/touchup/meta.py new file mode 100644 index 0000000000..9f60af387f --- /dev/null +++ b/sanic/touchup/meta.py @@ -0,0 +1,22 @@ +from sanic.exceptions import SanicException + +from .service import TouchUp + + +class TouchUpMeta(type): + def __new__(cls, name, bases, attrs, **kwargs): + gen_class = super().__new__(cls, name, bases, attrs, **kwargs) + + methods = attrs.get("__touchup__") + attrs["__touched__"] = False + if methods: + + for method in methods: + if method not in attrs: + raise SanicException( + "Cannot perform touchup on non-existent method: " + f"{name}.{method}" + ) + TouchUp.register(gen_class, method) + + return gen_class diff --git a/sanic/touchup/schemes/__init__.py b/sanic/touchup/schemes/__init__.py new file mode 100644 index 0000000000..87057a5fce --- /dev/null +++ b/sanic/touchup/schemes/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseScheme +from .ode import OptionalDispatchEvent # noqa + + +__all__ = ("BaseScheme",) diff --git a/sanic/touchup/schemes/base.py b/sanic/touchup/schemes/base.py new file mode 100644 index 0000000000..d16619b2f8 --- /dev/null +++ b/sanic/touchup/schemes/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Set, Type + + +class BaseScheme(ABC): + ident: str + _registry: Set[Type] = set() + + def __init__(self, app) -> None: + self.app = app + + @abstractmethod + def run(self, method, module_globals) -> None: + ... + + def __init_subclass__(cls): + BaseScheme._registry.add(cls) + + def __call__(self, method, module_globals): + return self.run(method, module_globals) diff --git a/sanic/touchup/schemes/ode.py b/sanic/touchup/schemes/ode.py new file mode 100644 index 0000000000..357f748c84 --- /dev/null +++ b/sanic/touchup/schemes/ode.py @@ -0,0 +1,67 @@ +from ast import Attribute, Await, Dict, Expr, NodeTransformer, parse +from inspect import getsource +from textwrap import dedent +from typing import Any + +from sanic.log import logger + +from .base import BaseScheme + + +class OptionalDispatchEvent(BaseScheme): + ident = "ODE" + + def __init__(self, app) -> None: + super().__init__(app) + + self._registered_events = [ + signal.path for signal in app.signal_router.routes + ] + + def run(self, method, module_globals): + raw_source = getsource(method) + src = dedent(raw_source) + tree = parse(src) + node = RemoveDispatch(self._registered_events).visit(tree) + compiled_src = compile(node, method.__name__, "exec") + exec_locals: Dict[str, Any] = {} + exec(compiled_src, module_globals, exec_locals) # nosec + + return exec_locals[method.__name__] + + +class RemoveDispatch(NodeTransformer): + def __init__(self, registered_events) -> None: + self._registered_events = registered_events + + def visit_Expr(self, node: Expr) -> Any: + call = node.value + if isinstance(call, Await): + call = call.value + + func = getattr(call, "func", None) + args = getattr(call, "args", None) + if not func or not args: + return node + + if isinstance(func, Attribute) and func.attr == "dispatch": + event = args[0] + if hasattr(event, "s"): + event_name = getattr(event, "value", event.s) + if self._not_registered(event_name): + logger.debug(f"Disabling event: {event_name}") + return None + return node + + def _not_registered(self, event_name): + dynamic = [] + for event in self._registered_events: + if event.endswith(">"): + namespace_concern, _ = event.rsplit(".", 1) + dynamic.append(namespace_concern) + + namespace_concern, _ = event_name.rsplit(".", 1) + return ( + event_name not in self._registered_events + and namespace_concern not in dynamic + ) diff --git a/sanic/touchup/service.py b/sanic/touchup/service.py new file mode 100644 index 0000000000..95792dca10 --- /dev/null +++ b/sanic/touchup/service.py @@ -0,0 +1,33 @@ +from inspect import getmembers, getmodule +from typing import Set, Tuple, Type + +from .schemes import BaseScheme + + +class TouchUp: + _registry: Set[Tuple[Type, str]] = set() + + @classmethod + def run(cls, app): + for target, method_name in cls._registry: + method = getattr(target, method_name) + + if app.test_mode: + placeholder = f"_{method_name}" + if hasattr(target, placeholder): + method = getattr(target, placeholder) + else: + setattr(target, placeholder, method) + + module = getmodule(target) + module_globals = dict(getmembers(module)) + + for scheme in BaseScheme._registry: + modified = scheme(app)(method, module_globals) + setattr(target, method_name, modified) + + target.__touched__ = True + + @classmethod + def register(cls, target, method_name): + cls._registry.add((target, method_name)) diff --git a/sanic/views.py b/sanic/views.py index 64a872a46c..c983bef750 100644 --- a/sanic/views.py +++ b/sanic/views.py @@ -13,6 +13,7 @@ from sanic.constants import HTTP_METHODS from sanic.exceptions import InvalidUsage +from sanic.models.handler_types import RouteHandler if TYPE_CHECKING: @@ -86,7 +87,7 @@ def dispatch_request(self, request, *args, **kwargs): return handler(request, *args, **kwargs) @classmethod - def as_view(cls, *class_args, **class_kwargs): + def as_view(cls, *class_args: Any, **class_kwargs: Any) -> RouteHandler: """Return view function for use with the routing system, that dispatches request to appropriate handler method. """ @@ -100,7 +101,7 @@ def view(*args, **kwargs): for decorator in cls.decorators: view = decorator(view) - view.view_class = cls + view.view_class = cls # type: ignore view.__doc__ = cls.__doc__ view.__module__ = cls.__module__ view.__name__ = cls.__name__ diff --git a/sanic/websocket.py b/sanic/websocket.py deleted file mode 100644 index b5600ed779..0000000000 --- a/sanic/websocket.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - MutableMapping, - Optional, - Union, -) - -from httptools import HttpParserUpgrade # type: ignore -from websockets import ( # type: ignore - ConnectionClosed, - InvalidHandshake, - WebSocketCommonProtocol, -) - -# Despite the "legacy" namespace, the primary maintainer of websockets -# committed to maintaining backwards-compatibility until 2026 and will -# consider extending it if sanic continues depending on this module. -from websockets.legacy import handshake - -from sanic.exceptions import InvalidUsage -from sanic.server import HttpProtocol - - -__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] - -ASIMessage = MutableMapping[str, Any] - - -class WebSocketProtocol(HttpProtocol): - def __init__( - self, - *args, - websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - websocket_ping_interval=20, - websocket_ping_timeout=20, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.websocket = None - # self.app = None - self.websocket_timeout = websocket_timeout - self.websocket_max_size = websocket_max_size - self.websocket_max_queue = websocket_max_queue - self.websocket_read_limit = websocket_read_limit - self.websocket_write_limit = websocket_write_limit - self.websocket_ping_interval = websocket_ping_interval - self.websocket_ping_timeout = websocket_ping_timeout - - # timeouts make no sense for websocket routes - def request_timeout_callback(self): - if self.websocket is None: - super().request_timeout_callback() - - def response_timeout_callback(self): - if self.websocket is None: - super().response_timeout_callback() - - def keep_alive_timeout_callback(self): - if self.websocket is None: - super().keep_alive_timeout_callback() - - def connection_lost(self, exc): - if self.websocket is not None: - self.websocket.connection_lost(exc) - super().connection_lost(exc) - - def data_received(self, data): - if self.websocket is not None: - # pass the data to the websocket protocol - self.websocket.data_received(data) - else: - try: - super().data_received(data) - except HttpParserUpgrade: - # this is okay, it just indicates we've got an upgrade request - pass - - def write_response(self, response): - if self.websocket is not None: - # websocket requests do not write a response - self.transport.close() - else: - super().write_response(response) - - async def websocket_handshake(self, request, subprotocols=None): - # let the websockets package do the handshake with the client - headers = {} - - try: - key = handshake.check_request(request.headers) - handshake.build_response(headers, key) - except InvalidHandshake: - raise InvalidUsage("Invalid websocket request") - - subprotocol = None - if subprotocols and "Sec-Websocket-Protocol" in request.headers: - # select a subprotocol - client_subprotocols = [ - p.strip() - for p in request.headers["Sec-Websocket-Protocol"].split(",") - ] - for p in client_subprotocols: - if p in subprotocols: - subprotocol = p - headers["Sec-Websocket-Protocol"] = subprotocol - break - - # write the 101 response back to the client - rv = b"HTTP/1.1 101 Switching Protocols\r\n" - for k, v in headers.items(): - rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - rv += b"\r\n" - request.transport.write(rv) - - # hook up the websocket protocol - self.websocket = WebSocketCommonProtocol( - close_timeout=self.websocket_timeout, - max_size=self.websocket_max_size, - max_queue=self.websocket_max_queue, - read_limit=self.websocket_read_limit, - write_limit=self.websocket_write_limit, - ping_interval=self.websocket_ping_interval, - ping_timeout=self.websocket_ping_timeout, - ) - # we use WebSocketCommonProtocol because we don't want the handshake - # logic from WebSocketServerProtocol; however, we must tell it that - # we're running on the server side - self.websocket.is_client = False - self.websocket.side = "server" - self.websocket.subprotocol = subprotocol - self.websocket.connection_made(request.transport) - self.websocket.connection_open() - return self.websocket - - -class WebSocketConnection: - - # TODO - # - Implement ping/pong - - def __init__( - self, - send: Callable[[ASIMessage], Awaitable[None]], - receive: Callable[[], Awaitable[ASIMessage]], - subprotocols: Optional[List[str]] = None, - ) -> None: - self._send = send - self._receive = receive - self._subprotocols = subprotocols or [] - - async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: - message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} - - if isinstance(data, bytes): - message.update({"bytes": data}) - else: - message.update({"text": str(data)}) - - await self._send(message) - - async def recv(self, *args, **kwargs) -> Optional[str]: - message = await self._receive() - - if message["type"] == "websocket.receive": - return message["text"] - elif message["type"] == "websocket.disconnect": - pass - - return None - - receive = recv - - async def accept(self, subprotocols: Optional[List[str]] = None) -> None: - subprotocol = None - if subprotocols: - for subp in subprotocols: - if subp in self.subprotocols: - subprotocol = subp - break - - await self._send( - { - "type": "websocket.accept", - "subprotocol": subprotocol, - } - ) - - async def close(self) -> None: - pass - - @property - def subprotocols(self): - return self._subprotocols - - @subprotocols.setter - def subprotocols(self, subprotocols: Optional[List[str]] = None): - self._subprotocols = subprotocols or [] diff --git a/sanic/worker.py b/sanic/worker.py index 342900e6b1..a3bc29b8b8 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -8,8 +8,8 @@ from gunicorn.workers import base # type: ignore from sanic.log import logger -from sanic.server import HttpProtocol, Signal, serve, trigger_events -from sanic.websocket import WebSocketProtocol +from sanic.server import HttpProtocol, Signal, serve +from sanic.server.protocols.websocket_protocol import WebSocketProtocol try: @@ -68,10 +68,10 @@ def run(self): ) self._server_settings["signal"] = self.signal self._server_settings.pop("sock") - trigger_events( - self._server_settings.get("before_start", []), self.loop + self._await(self.app.callable._startup()) + self._await( + self.app.callable._server_event("init", "before", loop=self.loop) ) - self._server_settings["before_start"] = () main_start = self._server_settings.pop("main_start", None) main_stop = self._server_settings.pop("main_stop", None) @@ -82,24 +82,29 @@ def run(self): "with GunicornWorker" ) - self._runner = asyncio.ensure_future(self._run(), loop=self.loop) try: - self.loop.run_until_complete(self._runner) + self._await(self._run()) self.app.callable.is_running = True - trigger_events( - self._server_settings.get("after_start", []), self.loop + self._await( + self.app.callable._server_event( + "init", "after", loop=self.loop + ) ) self.loop.run_until_complete(self._check_alive()) - trigger_events( - self._server_settings.get("before_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "before", loop=self.loop + ) ) self.loop.run_until_complete(self.close()) except BaseException: traceback.print_exc() finally: try: - trigger_events( - self._server_settings.get("after_stop", []), self.loop + self._await( + self.app.callable._server_event( + "shutdown", "after", loop=self.loop + ) ) except BaseException: traceback.print_exc() @@ -137,14 +142,11 @@ async def close(self): # Force close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + conn.websocket.fail_connection(code=1001) else: - conn.close() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown + conn.abort() async def _run(self): for sock in self.sockets: @@ -238,3 +240,7 @@ def handle_abort(self, sig, frame): self.exit_code = 1 self.cfg.worker_abort(self) sys.exit(1) + + def _await(self, coro): + fut = asyncio.ensure_future(coro, loop=self.loop) + self.loop.run_until_complete(fut) diff --git a/setup.py b/setup.py index af347b3fbc..ecbf1e07c9 100644 --- a/setup.py +++ b/setup.py @@ -81,60 +81,63 @@ def open_local(paths, mode="r", encoding="utf8"): ) ujson = "ujson>=1.35" + env_dependency uvloop = "uvloop>=0.5.3" + env_dependency - +types_ujson = "types-ujson" + env_dependency requirements = [ "sanic-routing~=0.7", "httptools>=0.0.10", uvloop, ujson, "aiofiles>=0.6.0", - "websockets>=9.0", + "websockets>=10.0", "multidict>=5.0,<6.0", ] tests_require = [ - "sanic-testing>=0.7.0b1", + "sanic-testing>=0.7.0", "pytest==5.2.1", - "multidict>=5.0,<6.0", + "coverage==5.3", "gunicorn==20.0.4", "pytest-cov", "beautifulsoup4", - uvloop, - ujson, "pytest-sanic", "pytest-sugar", "pytest-benchmark", + "chardet==3.*", + "flake8", + "black", + "isort>=5.0.0", + "bandit", + "mypy>=0.901", + "docutils", + "pygments", + "uvicorn<0.15.0", + types_ujson, ] docs_require = [ "sphinx>=2.1.2", - "sphinx_rtd_theme", - "recommonmark>=0.5.0", + "sphinx_rtd_theme>=0.4.3", "docutils", "pygments", + "m2r2", ] dev_require = tests_require + [ - "aiofiles", "tox", - "black", - "flake8", - "bandit", "towncrier", ] -all_require = dev_require + docs_require +all_require = list(set(dev_require + docs_require)) if strtobool(os.environ.get("SANIC_NO_UJSON", "no")): print("Installing without uJSON") requirements.remove(ujson) - tests_require.remove(ujson) + tests_require.remove(types_ujson) # 'nt' means windows OS if strtobool(os.environ.get("SANIC_NO_UVLOOP", "no")): print("Installing without uvLoop") requirements.remove(uvloop) - tests_require.remove(uvloop) extras_require = { "test": tests_require, diff --git a/tests/conftest.py b/tests/conftest.py index 65b218cf84..175e967efa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +import asyncio +import logging import random import re import string @@ -9,10 +11,12 @@ import pytest from sanic_routing.exceptions import RouteExists +from sanic_testing.testing import PORT from sanic import Sanic from sanic.constants import HTTP_METHODS from sanic.router import Router +from sanic.touchup.service import TouchUp slugify = re.compile(r"[^a-zA-Z0-9_\-]") @@ -23,11 +27,6 @@ collect_ignore = ["test_worker.py"] -@pytest.fixture -def caplog(caplog): - yield caplog - - async def _handler(request): """ Dummy placeholder method used for route resolver when creating a new @@ -41,33 +40,32 @@ async def _handler(request): TYPE_TO_GENERATOR_MAP = { - "string": lambda: "".join( + "str": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "int": lambda: random.choice(range(1000000)), - "number": lambda: random.random(), + "float": lambda: random.random(), "alpha": lambda: "".join( [random.choice(string.ascii_lowercase) for _ in range(4)] ), "uuid": lambda: str(uuid.uuid1()), } +CACHE = {} + class RouteStringGenerator: ROUTE_COUNT_PER_DEPTH = 100 HTTP_METHODS = HTTP_METHODS - ROUTE_PARAM_TYPES = ["string", "int", "number", "alpha", "uuid"] + ROUTE_PARAM_TYPES = ["str", "int", "float", "alpha", "uuid"] def generate_random_direct_route(self, max_route_depth=4): routes = [] for depth in range(1, max_route_depth + 1): for _ in range(self.ROUTE_COUNT_PER_DEPTH): route = "/".join( - [ - TYPE_TO_GENERATOR_MAP.get("string")() - for _ in range(depth) - ] + [TYPE_TO_GENERATOR_MAP.get("str")() for _ in range(depth)] ) route = route.replace(".", "", -1) route_detail = (random.choice(self.HTTP_METHODS), route) @@ -83,7 +81,7 @@ def add_typed_parameters(self, current_routes, max_route_depth=8): new_route_part = "/".join( [ "<{}:{}>".format( - TYPE_TO_GENERATOR_MAP.get("string")(), + TYPE_TO_GENERATOR_MAP.get("str")(), random.choice(self.ROUTE_PARAM_TYPES), ) for _ in range(max_route_depth - current_length) @@ -98,7 +96,7 @@ def add_typed_parameters(self, current_routes, max_route_depth=8): def generate_url_for_template(template): url = template for pattern, param_type in re.findall( - re.compile(r"((?:<\w+:(string|int|number|alpha|uuid)>)+)"), + re.compile(r"((?:<\w+:(str|int|float|alpha|uuid)>)+)"), template, ): value = TYPE_TO_GENERATOR_MAP.get(param_type)() @@ -111,6 +109,7 @@ def sanic_router(app): # noinspection PyProtectedMember def _setup(route_details: tuple) -> Tuple[Router, tuple]: router = Router() + router.ctx.app = app added_router = [] for method, route in route_details: try: @@ -141,5 +140,33 @@ def url_param_generator(): @pytest.fixture(scope="function") def app(request): + if not CACHE: + for target, method_name in TouchUp._registry: + CACHE[method_name] = getattr(target, method_name) app = Sanic(slugify.sub("-", request.node.name)) - return app + yield app + for target, method_name in TouchUp._registry: + setattr(target, method_name, CACHE[method_name]) + + +@pytest.fixture(scope="function") +def run_startup(caplog): + def run(app): + nonlocal caplog + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + with caplog.at_level(logging.DEBUG): + server = app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + loop._stopping = False + + _server = loop.run_until_complete(server) + + _server.close() + loop.run_until_complete(_server.wait_closed()) + app.stop() + + return caplog.record_tuples + + return run diff --git a/tests/test_app.py b/tests/test_app.py index 9598d54fa1..f222fba1a8 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -178,9 +178,6 @@ async def handler(request, ws): @patch("sanic.app.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 - app.config.WEBSOCKET_MAX_QUEUE = 45 - app.config.WEBSOCKET_READ_LIMIT = 46 - app.config.WEBSOCKET_WRITE_LIMIT = 47 app.config.WEBSOCKET_PING_TIMEOUT = 48 app.config.WEBSOCKET_PING_INTERVAL = 50 @@ -197,11 +194,6 @@ async def handler(request, ws): websocket_protocol_call_args = websocket_protocol_mock.call_args ws_kwargs = websocket_protocol_call_args[1] assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE - assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE - assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT - assert ( - ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT - ) assert ( ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT @@ -396,7 +388,7 @@ def test_app_set_attribute_warning(app): assert len(record) == 1 assert record[0].message.args[0] == ( "Setting variables on Sanic instances is deprecated " - "and will be removed in version 21.9. You should change your " + "and will be removed in version 21.12. You should change your " "Sanic instance to use instance.ctx.foo instead." ) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index c707c12a0e..3d464a4f55 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -10,7 +10,7 @@ from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request from sanic.response import json, text -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection @pytest.fixture @@ -360,6 +360,7 @@ def _request(request): _, response = await app.asgi_client.get("/error-prone") assert response.status_code == 503 + @pytest.mark.asyncio async def test_request_exception_suppressed_by_middleware(app): @app.get("/error-prone") @@ -374,4 +375,4 @@ def forbidden(request): assert response.status_code == 403 _, response = await app.asgi_client.get("/error-prone") - assert response.status_code == 403 \ No newline at end of file + assert response.status_code == 403 diff --git a/tests/test_bad_request.py b/tests/test_bad_request.py index 140fbe8a23..7a87d919b1 100644 --- a/tests/test_bad_request.py +++ b/tests/test_bad_request.py @@ -20,4 +20,4 @@ async def _request(sanic, loop): app.run(host="127.0.0.1", port=42101, debug=False) assert lines[0] == b"HTTP/1.1 400 Bad Request\r\n" - assert b"Bad Request" in lines[-1] + assert b"Bad Request" in lines[-2] diff --git a/tests/test_blueprint_copy.py b/tests/test_blueprint_copy.py new file mode 100644 index 0000000000..033e2e2041 --- /dev/null +++ b/tests/test_blueprint_copy.py @@ -0,0 +1,70 @@ +from copy import deepcopy + +from sanic import Blueprint, Sanic, blueprints, response +from sanic.response import text + + +def test_bp_copy(app: Sanic): + bp1 = Blueprint("test_bp1", version=1) + bp1.ctx.test = 1 + assert hasattr(bp1.ctx, "test") + + @bp1.route("/page") + def handle_request(request): + return text("Hello world!") + + bp2 = bp1.copy(name="test_bp2", version=2) + assert id(bp1) != id(bp2) + assert bp1._apps == bp2._apps == set() + assert not hasattr(bp2.ctx, "test") + assert len(bp2._future_exceptions) == len(bp1._future_exceptions) + assert len(bp2._future_listeners) == len(bp1._future_listeners) + assert len(bp2._future_middleware) == len(bp1._future_middleware) + assert len(bp2._future_routes) == len(bp1._future_routes) + assert len(bp2._future_signals) == len(bp1._future_signals) + + app.blueprint(bp1) + app.blueprint(bp2) + + bp3 = bp1.copy(name="test_bp3", version=3, with_registration=True) + assert id(bp1) != id(bp3) + assert bp1._apps == bp3._apps and bp3._apps + assert not hasattr(bp3.ctx, "test") + + bp4 = bp1.copy(name="test_bp4", version=4, with_ctx=True) + assert id(bp1) != id(bp4) + assert bp4.ctx.test == 1 + + bp5 = bp1.copy(name="test_bp5", version=5, with_registration=False) + assert id(bp1) != id(bp5) + assert not bp5._apps + assert bp1._apps != set() + + app.blueprint(bp5) + + bp6 = bp1.copy( + name="test_bp6", + version=6, + with_registration=True, + version_prefix="/version", + ) + assert bp6._apps + assert bp6.version_prefix == "/version" + + _, response = app.test_client.get("/v1/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v2/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v3/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v4/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/v5/page") + assert "Hello world!" in response.text + + _, response = app.test_client.get("/version6/page") + assert "Hello world!" in response.text diff --git a/tests/test_blueprint_group.py b/tests/test_blueprint_group.py index 77ddf44c5a..09729c15f5 100644 --- a/tests/test_blueprint_group.py +++ b/tests/test_blueprint_group.py @@ -3,6 +3,12 @@ from sanic.app import Sanic from sanic.blueprint_group import BlueprintGroup from sanic.blueprints import Blueprint +from sanic.exceptions import ( + Forbidden, + InvalidUsage, + SanicException, + ServerError, +) from sanic.request import Request from sanic.response import HTTPResponse, text @@ -96,16 +102,28 @@ def test_bp_group(app: Sanic): def blueprint_1_default_route(request): return text("BP1_OK") + @blueprint_1.route("/invalid") + def blueprint_1_error(request: Request): + raise InvalidUsage("Invalid") + @blueprint_2.route("/") def blueprint_2_default_route(request): return text("BP2_OK") + @blueprint_2.route("/error") + def blueprint_2_error(request: Request): + raise ServerError("Error") + blueprint_group_1 = Blueprint.group( blueprint_1, blueprint_2, url_prefix="/bp" ) blueprint_3 = Blueprint("blueprint_3", url_prefix="/bp3") + @blueprint_group_1.exception(InvalidUsage) + def handle_group_exception(request, exception): + return text("BP1_ERR_OK") + @blueprint_group_1.middleware("request") def blueprint_group_1_middleware(request): global MIDDLEWARE_INVOKE_COUNTER @@ -116,19 +134,47 @@ def blueprint_group_1_middleware_not_called(request): global MIDDLEWARE_INVOKE_COUNTER MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + @blueprint_group_1.on_request + def blueprint_group_1_convenience_1(request): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + + @blueprint_group_1.on_request() + def blueprint_group_1_convenience_2(request): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["request"] += 1 + @blueprint_3.route("/") def blueprint_3_default_route(request): return text("BP3_OK") + @blueprint_3.route("/forbidden") + def blueprint_3_forbidden(request: Request): + raise Forbidden("Forbidden") + blueprint_group_2 = Blueprint.group( blueprint_group_1, blueprint_3, url_prefix="/api" ) + @blueprint_group_2.exception(SanicException) + def handle_non_handled_exception(request, exception): + return text("BP2_ERR_OK") + @blueprint_group_2.middleware("response") def blueprint_group_2_middleware(request, response): global MIDDLEWARE_INVOKE_COUNTER MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + @blueprint_group_2.on_response + def blueprint_group_2_middleware_convenience_1(request, response): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + + @blueprint_group_2.on_response() + def blueprint_group_2_middleware_convenience_2(request, response): + global MIDDLEWARE_INVOKE_COUNTER + MIDDLEWARE_INVOKE_COUNTER["response"] += 1 + app.blueprint(blueprint_group_2) @app.route("/") @@ -141,14 +187,23 @@ def app_default_route(request): _, response = app.test_client.get("/api/bp/bp1") assert response.text == "BP1_OK" + _, response = app.test_client.get("/api/bp/bp1/invalid") + assert response.text == "BP1_ERR_OK" + _, response = app.test_client.get("/api/bp/bp2") assert response.text == "BP2_OK" + _, response = app.test_client.get("/api/bp/bp2/error") + assert response.text == "BP2_ERR_OK" + _, response = app.test_client.get("/api/bp3") assert response.text == "BP3_OK" - assert MIDDLEWARE_INVOKE_COUNTER["response"] == 3 - assert MIDDLEWARE_INVOKE_COUNTER["request"] == 4 + _, response = app.test_client.get("/api/bp3/forbidden") + assert response.text == "BP2_ERR_OK" + + assert MIDDLEWARE_INVOKE_COUNTER["response"] == 18 + assert MIDDLEWARE_INVOKE_COUNTER["request"] == 16 def test_bp_group_list_operations(app: Sanic): diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index fec7b50a41..b6a2315177 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -83,7 +83,6 @@ def handler(request): return text("OK") else: - print(func) raise Exception(f"{func} is not callable") app.blueprint(bp) @@ -477,6 +476,58 @@ def handler_exception(request, exception): assert response.status == 200 +def test_bp_exception_handler_applied(app): + class Error(Exception): + pass + + handled = Blueprint("handled") + nothandled = Blueprint("nothandled") + + @handled.exception(Error) + def handle_error(req, e): + return text("handled {}".format(e)) + + @handled.route("/ok") + def ok(request): + raise Error("uh oh") + + @nothandled.route("/notok") + def notok(request): + raise Error("uh oh") + + app.blueprint(handled) + app.blueprint(nothandled) + + _, response = app.test_client.get("/ok") + assert response.status == 200 + assert response.text == "handled uh oh" + + _, response = app.test_client.get("/notok") + assert response.status == 500 + + +def test_bp_exception_handler_not_applied(app): + class Error(Exception): + pass + + handled = Blueprint("handled") + nothandled = Blueprint("nothandled") + + @handled.exception(Error) + def handle_error(req, e): + return text("handled {}".format(e)) + + @nothandled.route("/notok") + def notok(request): + raise Error("uh oh") + + app.blueprint(handled) + app.blueprint(nothandled) + + _, response = app.test_client.get("/notok") + assert response.status == 500 + + def test_bp_listeners(app): app.route("/")(lambda x: x) blueprint = Blueprint("test_middleware") @@ -1034,6 +1085,6 @@ def test_bp_set_attribute_warning(): assert len(record) == 1 assert record[0].message.args[0] == ( "Setting variables on Blueprint instances is deprecated " - "and will be removed in version 21.9. You should change your " + "and will be removed in version 21.12. You should change your " "Blueprint instance to use instance.ctx.foo instead." ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5f69dd9529..908a91a3a5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -89,7 +89,7 @@ def test_debug(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert (b"\n".join(lines[:6])).decode("utf-8") == BASE_LOGO @@ -103,7 +103,7 @@ def test_auto_reload(cmd): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["debug"] is False @@ -118,7 +118,7 @@ def test_access_logs(cmd, expected): out, err, exitcode = capture(command) lines = out.split(b"\n") - app_info = lines[9] + app_info = lines[26] info = json.loads(app_info) assert info["access_log"] is expected diff --git a/tests/test_config.py b/tests/test_config.py index ce790800ce..42a7e3ecdb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,7 @@ @contextmanager def temp_path(): - """ a simple cross platform replacement for NamedTemporaryFile """ + """a simple cross platform replacement for NamedTemporaryFile""" with TemporaryDirectory() as td: yield Path(td, "file") diff --git a/tests/test_constants.py b/tests/test_constants.py index 7ce6e4d722..2f1eb3d0ff 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -1,6 +1,4 @@ -from crypt import methods - -from sanic import text +from sanic import Sanic, text from sanic.constants import HTTP_METHODS, HTTPMethod @@ -14,7 +12,7 @@ def test_string_compat(): assert HTTPMethod.GET.upper() == "GET" -def test_use_in_routes(app): +def test_use_in_routes(app: Sanic): @app.route("/", methods=[HTTPMethod.GET, HTTPMethod.POST]) def handler(_): return text("It works") diff --git a/tests/test_create_task.py b/tests/test_create_task.py index e128263bc7..99f724b55c 100644 --- a/tests/test_create_task.py +++ b/tests/test_create_task.py @@ -1,6 +1,5 @@ import asyncio -from queue import Queue from threading import Event from sanic.response import text @@ -13,8 +12,6 @@ async def coro(): await asyncio.sleep(0.05) e.set() - app.add_task(coro) - @app.route("/early") def not_set(request): return text(str(e.is_set())) @@ -24,24 +21,30 @@ async def set(request): await asyncio.sleep(0.1) return text(str(e.is_set())) + app.add_task(coro) + request, response = app.test_client.get("/early") assert response.body == b"False" + app.signal_router.reset() + app.add_task(coro) request, response = app.test_client.get("/late") assert response.body == b"True" def test_create_task_with_app_arg(app): - q = Queue() + @app.after_server_start + async def setup_q(app, _): + app.ctx.q = asyncio.Queue() @app.route("/") - def not_set(request): - return "hello" + async def not_set(request): + return text(await request.app.ctx.q.get()) async def coro(app): - q.put(app.name) + await app.ctx.q.put(app.name) app.add_task(coro) - request, response = app.test_client.get("/") - assert q.get() == "test_create_task_with_app_arg" + _, response = app.test_client.get("/") + assert response.text == "test_create_task_with_app_arg" diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 495c764fd9..5af4ca5fe0 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,10 +1,10 @@ import pytest from sanic import Sanic -from sanic.errorpages import exception_response -from sanic.exceptions import NotFound +from sanic.errorpages import HTMLRenderer, exception_response +from sanic.exceptions import NotFound, SanicException from sanic.request import Request -from sanic.response import HTTPResponse +from sanic.response import HTTPResponse, html, json, text @pytest.fixture @@ -20,7 +20,7 @@ def err(request): @pytest.fixture def fake_request(app): - return Request(b"/foobar", {}, "1.1", "GET", None, app) + return Request(b"/foobar", {"accept": "*/*"}, "1.1", "GET", None, app) @pytest.mark.parametrize( @@ -47,7 +47,13 @@ def test_should_return_html_valid_setting( try: raise exception("bad stuff") except Exception as e: - response = exception_response(fake_request, e, True) + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback=fake_request.app.config.FALLBACK_ERROR_FORMAT, + ) assert isinstance(response, HTTPResponse) assert response.status == status @@ -74,13 +80,194 @@ def test_auto_fallback_with_content_type(app): app.config.FALLBACK_ERROR_FORMAT = "auto" _, response = app.test_client.get( - "/error", headers={"content-type": "application/json"} + "/error", headers={"content-type": "application/json", "accept": "*/*"} ) assert response.status == 500 assert response.content_type == "application/json" _, response = app.test_client.get( - "/error", headers={"content-type": "text/plain"} + "/error", headers={"content-type": "foo/bar", "accept": "*/*"} + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_format_set_on_auto(app): + @app.get("/text") + def text_response(request): + return text(request.route.ctx.error_format) + + @app.get("/json") + def json_response(request): + return json({"format": request.route.ctx.error_format}) + + @app.get("/html") + def html_response(request): + return html(request.route.ctx.error_format) + + _, response = app.test_client.get("/text") + assert response.text == "text" + + _, response = app.test_client.get("/json") + assert response.json["format"] == "json" + + _, response = app.test_client.get("/html") + assert response.text == "html" + + +def test_route_error_response_from_auto_route(app): + @app.get("/text") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + @app.get("/html") + def html_response(request): + raise Exception("oops") + return html("

Never gonna see this

") + + _, response = app.test_client.get("/text") + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get("/json") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/html") + assert response.content_type == "text/html; charset=utf-8" + + +def test_route_error_response_from_explicit_format(app): + @app.get("/text", error_format="json") + def text_response(request): + raise Exception("oops") + return text("Never gonna see this") + + @app.get("/json", error_format="text") + def json_response(request): + raise Exception("oops") + return json({"message": "Never gonna see this"}) + + _, response = app.test_client.get("/text") + assert response.content_type == "application/json" + + _, response = app.test_client.get("/json") + assert response.content_type == "text/plain; charset=utf-8" + + +def test_unknown_fallback_format(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + app.config.FALLBACK_ERROR_FORMAT = "bad" + + +def test_route_error_format_unknown(app): + with pytest.raises(SanicException, match="Unknown format: bad"): + + @app.get("/text", error_format="bad") + def handler(request): + ... + + +def test_fallback_with_content_type_mismatch_accept(app): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "application/json", "accept": "text/plain"}, ) assert response.status == 500 assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/error", + headers={"content-type": "text/plain", "accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + app.router.reset() + + @app.route("/alt1") + @app.route("/alt2", error_format="text") + @app.route("/alt3", error_format="html") + def handler(_): + raise Exception("problem here") + # Yes, we know this return value is unreachable. This is on purpose. + return json({}) + + app.router.finalize() + + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt1", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "application/json" + + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + _, response = app.test_client.get( + "/alt2", + headers={"accept": "foo/bar,*/*"}, + ) + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + _, response = app.test_client.get( + "/alt3", + headers={"accept": "foo/bar"}, + ) + assert response.status == 500 + assert response.content_type == "text/html; charset=utf-8" + + +@pytest.mark.parametrize( + "accept,content_type,expected", + ( + (None, None, "text/plain; charset=utf-8"), + ("foo/bar", None, "text/html; charset=utf-8"), + ("application/json", None, "application/json"), + ("application/json,text/plain", None, "application/json"), + ("text/plain,application/json", None, "application/json"), + ("text/plain,foo/bar", None, "text/plain; charset=utf-8"), + # Following test is valid after v22.3 + # ("text/plain,text/html", None, "text/plain; charset=utf-8"), + ("*/*", "foo/bar", "text/html; charset=utf-8"), + ("*/*", "application/json", "application/json"), + ), +) +def test_combinations_for_auto(fake_request, accept, content_type, expected): + if accept: + fake_request.headers["accept"] = accept + else: + del fake_request.headers["accept"] + + if content_type: + fake_request.headers["content-type"] = content_type + + try: + raise Exception("bad stuff") + except Exception as e: + response = exception_response( + fake_request, + e, + True, + base=HTMLRenderer, + fallback="auto", + ) + + assert response.content_type == expected diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8487c70bd9..29797e1e1f 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,3 +1,4 @@ +import logging import warnings import pytest @@ -15,6 +16,7 @@ abort, ) from sanic.response import text +from websockets.version import version as websockets_version class SanicExceptionTestException(Exception): @@ -232,3 +234,41 @@ def test_sanic_exception(exception_app): request, response = exception_app.test_client.get("/old_abort") assert response.status == 500 assert len(w) == 1 and "deprecated" in w[0].message.args[0] + + +def test_custom_exception_default_message(exception_app): + class TeaError(SanicException): + message = "Tempest in a teapot" + status_code = 418 + + exception_app.router.reset() + + @exception_app.get("/tempest") + def tempest(_): + raise TeaError + + _, response = exception_app.test_client.get("/tempest", debug=True) + assert response.status == 418 + assert b"Tempest in a teapot" in response.body + + +def test_exception_in_ws_logged(caplog): + app = Sanic(__file__) + + @app.websocket("/feed") + async def feed(request, ws): + raise Exception("...") + + with caplog.at_level(logging.INFO): + app.test_client.websocket("/feed") + # Websockets v10.0 and above output an additional + # INFO message when a ws connection is accepted + ws_version_parts = websockets_version.split(".") + ws_major = int(ws_version_parts[0]) + record_index = 2 if ws_major >= 10 else 1 + assert caplog.record_tuples[record_index][0] == "sanic.error" + assert caplog.record_tuples[record_index][1] == logging.ERROR + assert ( + "Exception occurred while handling uri:" + in caplog.record_tuples[record_index][2] + ) diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index e6fd42eb4f..dbf9fcbb9b 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,5 +1,7 @@ import asyncio +import pytest + from bs4 import BeautifulSoup from sanic import Sanic @@ -8,9 +10,6 @@ from sanic.response import stream, text -exception_handler_app = Sanic("test_exception_handler") - - async def sample_streaming_fn(response): await response.write("foo,") await asyncio.sleep(0.001) @@ -21,113 +20,107 @@ class ErrorWithRequestCtx(ServerError): pass -@exception_handler_app.route("/1") -def handler_1(request): - raise InvalidUsage("OK") - - -@exception_handler_app.route("/2") -def handler_2(request): - raise ServerError("OK") - +@pytest.fixture +def exception_handler_app(): + exception_handler_app = Sanic("test_exception_handler") -@exception_handler_app.route("/3") -def handler_3(request): - raise NotFound("OK") + @exception_handler_app.route("/1", error_format="html") + def handler_1(request): + raise InvalidUsage("OK") + @exception_handler_app.route("/2", error_format="html") + def handler_2(request): + raise ServerError("OK") -@exception_handler_app.route("/4") -def handler_4(request): - foo = bar # noqa -- F821 undefined name 'bar' is done to throw exception - return text(foo) - - -@exception_handler_app.route("/5") -def handler_5(request): - class CustomServerError(ServerError): - pass - - raise CustomServerError("Custom server error") + @exception_handler_app.route("/3", error_format="html") + def handler_3(request): + raise NotFound("OK") + @exception_handler_app.route("/4", error_format="html") + def handler_4(request): + foo = bar # noqa -- F821 + return text(foo) -@exception_handler_app.route("/6/") -def handler_6(request, arg): - try: - foo = 1 / arg - except Exception as e: - raise e from ValueError(f"{arg}") - return text(foo) - - -@exception_handler_app.route("/7") -def handler_7(request): - raise Forbidden("go away!") - + @exception_handler_app.route("/5", error_format="html") + def handler_5(request): + class CustomServerError(ServerError): + pass -@exception_handler_app.route("/8") -def handler_8(request): + raise CustomServerError("Custom server error") - raise ErrorWithRequestCtx("OK") + @exception_handler_app.route("/6/", error_format="html") + def handler_6(request, arg): + try: + foo = 1 / arg + except Exception as e: + raise e from ValueError(f"{arg}") + return text(foo) + @exception_handler_app.route("/7", error_format="html") + def handler_7(request): + raise Forbidden("go away!") -@exception_handler_app.exception(ErrorWithRequestCtx, NotFound) -def handler_exception_with_ctx(request, exception): - return text(request.ctx.middleware_ran) + @exception_handler_app.route("/8", error_format="html") + def handler_8(request): + raise ErrorWithRequestCtx("OK") -@exception_handler_app.exception(ServerError) -def handler_exception(request, exception): - return text("OK") + @exception_handler_app.exception(ErrorWithRequestCtx, NotFound) + def handler_exception_with_ctx(request, exception): + return text(request.ctx.middleware_ran) + @exception_handler_app.exception(ServerError) + def handler_exception(request, exception): + return text("OK") -@exception_handler_app.exception(Forbidden) -async def async_handler_exception(request, exception): - return stream( - sample_streaming_fn, - content_type="text/csv", - ) + @exception_handler_app.exception(Forbidden) + async def async_handler_exception(request, exception): + return stream( + sample_streaming_fn, + content_type="text/csv", + ) + @exception_handler_app.middleware + async def some_request_middleware(request): + request.ctx.middleware_ran = "Done." -@exception_handler_app.middleware -async def some_request_middleware(request): - request.ctx.middleware_ran = "Done." + return exception_handler_app -def test_invalid_usage_exception_handler(): +def test_invalid_usage_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/1") assert response.status == 400 -def test_server_error_exception_handler(): +def test_server_error_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/2") assert response.status == 200 assert response.text == "OK" -def test_not_found_exception_handler(): +def test_not_found_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/3") assert response.status == 200 -def test_text_exception__handler(): +def test_text_exception__handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/random") assert response.status == 200 assert response.text == "Done." -def test_async_exception_handler(): +def test_async_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/7") assert response.status == 200 assert response.text == "foo,bar" -def test_html_traceback_output_in_debug_mode(): +def test_html_traceback_output_in_debug_mode(exception_handler_app): request, response = exception_handler_app.test_client.get("/4", debug=True) assert response.status == 500 soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_4" in html assert "foo = bar" in html @@ -137,12 +130,12 @@ def test_html_traceback_output_in_debug_mode(): ) == summary_text -def test_inherited_exception_handler(): +def test_inherited_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get("/5") assert response.status == 200 -def test_chained_exception_handler(): +def test_chained_exception_handler(exception_handler_app): request, response = exception_handler_app.test_client.get( "/6/0", debug=True ) @@ -151,11 +144,9 @@ def test_chained_exception_handler(): soup = BeautifulSoup(response.body, "html.parser") html = str(soup) - assert "response = handler(request, **kwargs)" in html assert "handler_6" in html assert "foo = 1 / arg" in html assert "ValueError" in html - assert "The above exception was the direct cause" in html summary_text = " ".join(soup.select(".summary")[0].text.split()) assert ( @@ -163,7 +154,7 @@ def test_chained_exception_handler(): ) == summary_text -def test_exception_handler_lookup(): +def test_exception_handler_lookup(exception_handler_app): class CustomError(Exception): pass @@ -186,26 +177,32 @@ def import_error_handler(): class ModuleNotFoundError(ImportError): pass - handler = ErrorHandler() + handler = ErrorHandler("auto") handler.add(ImportError, import_error_handler) handler.add(CustomError, custom_error_handler) handler.add(ServerError, server_error_handler) - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) # once again to ensure there is no caching bug - assert handler.lookup(ImportError()) == import_error_handler - assert handler.lookup(ModuleNotFoundError()) == import_error_handler - assert handler.lookup(CustomError()) == custom_error_handler - assert handler.lookup(ServerError("Error")) == server_error_handler - assert handler.lookup(CustomServerError("Error")) == server_error_handler + assert handler.lookup(ImportError(), None) == import_error_handler + assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler + assert handler.lookup(CustomError(), None) == custom_error_handler + assert handler.lookup(ServerError("Error"), None) == server_error_handler + assert ( + handler.lookup(CustomServerError("Error"), None) + == server_error_handler + ) -def test_exception_handler_processed_request_middleware(): +def test_exception_handler_processed_request_middleware(exception_handler_app): request, response = exception_handler_app.test_client.get("/8") assert response.status == 200 assert response.text == "Done." diff --git a/tests/test_graceful_shutdown.py b/tests/test_graceful_shutdown.py new file mode 100644 index 0000000000..8380ed50d2 --- /dev/null +++ b/tests/test_graceful_shutdown.py @@ -0,0 +1,46 @@ +import asyncio +import logging +import time + +from collections import Counter +from multiprocessing import Process + +import httpx + + +PORT = 42101 + + +def test_no_exceptions_when_cancel_pending_request(app, caplog): + app.config.GRACEFUL_SHUTDOWN_TIMEOUT = 1 + + @app.get("/") + async def handler(request): + await asyncio.sleep(5) + + @app.after_server_start + def shutdown(app, _): + time.sleep(0.2) + app.stop() + + def ping(): + time.sleep(0.1) + response = httpx.get("http://127.0.0.1:8000") + print(response.status_code) + + p = Process(target=ping) + p.start() + + with caplog.at_level(logging.INFO): + app.run() + + p.kill() + + counter = Counter([r[1] for r in caplog.record_tuples]) + + assert counter[logging.INFO] == 5 + assert logging.ERROR not in counter + assert ( + caplog.record_tuples[3][2] + == "Request: GET http://127.0.0.1:8000/ stopped. Transport is closed." + ) diff --git a/tests/test_handler_annotations.py b/tests/test_handler_annotations.py new file mode 100644 index 0000000000..14d1d7b70d --- /dev/null +++ b/tests/test_handler_annotations.py @@ -0,0 +1,39 @@ +from uuid import UUID + +import pytest + +from sanic import json + + +@pytest.mark.parametrize( + "idx,path,expectation", + ( + (0, "/abc", "str"), + (1, "/123", "int"), + (2, "/123.5", "float"), + (3, "/8af729fe-2b94-4a95-a168-c07068568429", "UUID"), + ), +) +def test_annotated_handlers(app, idx, path, expectation): + def build_response(num, foo): + return json({"num": num, "type": type(foo).__name__}) + + @app.get("/") + def handler0(_, foo: str): + return build_response(0, foo) + + @app.get("/") + def handler1(_, foo: int): + return build_response(1, foo) + + @app.get("/") + def handler2(_, foo: float): + return build_response(2, foo) + + @app.get("/") + def handler3(_, foo: UUID): + return build_response(3, foo) + + _, response = app.test_client.get(path) + assert response.json["num"] == idx + assert response.json["type"] == expectation diff --git a/tests/test_headers.py b/tests/test_headers.py index 546a9ef7de..115bed86b3 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -3,8 +3,9 @@ import pytest from sanic import headers, text -from sanic.exceptions import PayloadTooLarge +from sanic.exceptions import InvalidHeader, PayloadTooLarge from sanic.http import Http +from sanic.request import Request @pytest.fixture @@ -182,3 +183,187 @@ def test_request_line(app): ) assert request.request_line == b"GET / HTTP/1.1" + + +@pytest.mark.parametrize( + "raw", + ( + "show/first, show/second", + "show/*, show/first", + "*/*, show/first", + "*/*, show/*", + "other/*; q=0.1, show/*; q=0.2", + "show/first; q=0.5, show/second; q=0.5", + "show/first; foo=bar, show/second; foo=bar", + "show/second, show/first; foo=bar", + "show/second; q=0.5, show/first; foo=bar; q=0.5", + "show/second; q=0.5, show/first; q=1.0", + "show/first, show/second; q=1.0", + ), +) +def test_parse_accept_ordered_okay(raw): + ordered = headers.parse_accept(raw) + expected_subtype = ( + "*" if all(q.subtype.is_wildcard for q in ordered) else "first" + ) + assert ordered[0].type_ == "show" + assert ordered[0].subtype == expected_subtype + + +@pytest.mark.parametrize( + "raw", + ( + "missing", + "missing/", + "/missing", + ), +) +def test_bad_accept(raw): + with pytest.raises(InvalidHeader): + headers.parse_accept(raw) + + +def test_empty_accept(): + assert headers.parse_accept("") == [] + + +def test_wildcard_accept_set_ok(): + accept = headers.parse_accept("*/*")[0] + assert accept.type_.is_wildcard + assert accept.subtype.is_wildcard + + accept = headers.parse_accept("foo/bar")[0] + assert not accept.type_.is_wildcard + assert not accept.subtype.is_wildcard + + +def test_accept_parsed_against_str(): + accept = headers.Accept.parse("foo/bar") + assert accept > "foo/bar; q=0.1" + + +def test_media_type_equality(): + assert headers.MediaType("foo") == headers.MediaType("foo") == "foo" + assert headers.MediaType("foo") == headers.MediaType("*") == "*" + assert headers.MediaType("foo") != headers.MediaType("bar") + assert headers.MediaType("foo") != "bar" + + +def test_media_type_matching(): + assert headers.MediaType("foo").match(headers.MediaType("foo")) + assert headers.MediaType("foo").match("foo") + + assert not headers.MediaType("foo").match(headers.MediaType("*")) + assert not headers.MediaType("foo").match("*") + + assert not headers.MediaType("foo").match(headers.MediaType("bar")) + assert not headers.MediaType("foo").match("bar") + + +@pytest.mark.parametrize( + "value,other,outcome,allow_type,allow_subtype", + ( + # ALLOW BOTH + ("foo/bar", "foo/bar", True, True, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/bar", "foo/*", True, True, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, True, True), + ("foo/bar", "*/*", True, True, True), + ("foo/bar", headers.Accept.parse("*/*"), True, True, True), + ("foo/*", "foo/bar", True, True, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, True, True), + ("foo/*", "foo/*", True, True, True), + ("foo/*", headers.Accept.parse("foo/*"), True, True, True), + ("foo/*", "*/*", True, True, True), + ("foo/*", headers.Accept.parse("*/*"), True, True, True), + ("*/*", "foo/bar", True, True, True), + ("*/*", headers.Accept.parse("foo/bar"), True, True, True), + ("*/*", "foo/*", True, True, True), + ("*/*", headers.Accept.parse("foo/*"), True, True, True), + ("*/*", "*/*", True, True, True), + ("*/*", headers.Accept.parse("*/*"), True, True, True), + # ALLOW TYPE + ("foo/bar", "foo/bar", True, True, False), + ("foo/bar", headers.Accept.parse("foo/bar"), True, True, False), + ("foo/bar", "foo/*", False, True, False), + ("foo/bar", headers.Accept.parse("foo/*"), False, True, False), + ("foo/bar", "*/*", False, True, False), + ("foo/bar", headers.Accept.parse("*/*"), False, True, False), + ("foo/*", "foo/bar", False, True, False), + ("foo/*", headers.Accept.parse("foo/bar"), False, True, False), + ("foo/*", "foo/*", False, True, False), + ("foo/*", headers.Accept.parse("foo/*"), False, True, False), + ("foo/*", "*/*", False, True, False), + ("foo/*", headers.Accept.parse("*/*"), False, True, False), + ("*/*", "foo/bar", False, True, False), + ("*/*", headers.Accept.parse("foo/bar"), False, True, False), + ("*/*", "foo/*", False, True, False), + ("*/*", headers.Accept.parse("foo/*"), False, True, False), + ("*/*", "*/*", False, True, False), + ("*/*", headers.Accept.parse("*/*"), False, True, False), + # ALLOW SUBTYPE + ("foo/bar", "foo/bar", True, False, True), + ("foo/bar", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/bar", "foo/*", True, False, True), + ("foo/bar", headers.Accept.parse("foo/*"), True, False, True), + ("foo/bar", "*/*", False, False, True), + ("foo/bar", headers.Accept.parse("*/*"), False, False, True), + ("foo/*", "foo/bar", True, False, True), + ("foo/*", headers.Accept.parse("foo/bar"), True, False, True), + ("foo/*", "foo/*", True, False, True), + ("foo/*", headers.Accept.parse("foo/*"), True, False, True), + ("foo/*", "*/*", False, False, True), + ("foo/*", headers.Accept.parse("*/*"), False, False, True), + ("*/*", "foo/bar", False, False, True), + ("*/*", headers.Accept.parse("foo/bar"), False, False, True), + ("*/*", "foo/*", False, False, True), + ("*/*", headers.Accept.parse("foo/*"), False, False, True), + ("*/*", "*/*", False, False, True), + ("*/*", headers.Accept.parse("*/*"), False, False, True), + ), +) +def test_accept_matching(value, other, outcome, allow_type, allow_subtype): + assert ( + headers.Accept.parse(value).match( + other, + allow_type_wildcard=allow_type, + allow_subtype_wildcard=allow_subtype, + ) + is outcome + ) + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*", "*/*")) +def test_value_in_accept(value): + acceptable = headers.parse_accept(value) + assert "foo/bar" in acceptable + assert "foo/*" in acceptable + assert "*/*" in acceptable + + +@pytest.mark.parametrize("value", ("foo/bar", "foo/*")) +def test_value_not_in_accept(value): + acceptable = headers.parse_accept(value) + assert "no/match" not in acceptable + assert "no/*" not in acceptable + + +@pytest.mark.parametrize( + "header,expected", + ( + ( + "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", # noqa: E501 + [ + "text/html", + "application/xhtml+xml", + "image/avif", + "image/webp", + "application/xml;q=0.9", + "*/*;q=0.8", + ], + ), + ), +) +def test_browser_headers(header, expected): + request = Request(b"/", {"accept": header}, "1.1", "GET", None, None) + assert request.accept == expected diff --git a/tests/test_http.py b/tests/test_http.py new file mode 100644 index 0000000000..653857a12c --- /dev/null +++ b/tests/test_http.py @@ -0,0 +1,137 @@ +import asyncio +import json as stdjson + +from collections import namedtuple +from textwrap import dedent +from typing import AnyStr + +import pytest + +from sanic_testing.reusable import ReusableClient + +from sanic import json, text +from sanic.app import Sanic + + +PORT = 1234 + + +class RawClient: + CRLF = b"\r\n" + + def __init__(self, host: str, port: int): + self.reader = None + self.writer = None + self.host = host + self.port = port + + async def connect(self): + self.reader, self.writer = await asyncio.open_connection( + self.host, self.port + ) + + async def close(self): + self.writer.close() + await self.writer.wait_closed() + + async def send(self, message: AnyStr): + if isinstance(message, str): + msg = self._clean(message).encode("utf-8") + else: + msg = message + await self._send(msg) + + async def _send(self, message: bytes): + if not self.writer: + raise Exception("No open write stream") + self.writer.write(message) + + async def recv(self, nbytes: int = -1) -> bytes: + if not self.reader: + raise Exception("No open read stream") + return await self.reader.read(nbytes) + + def _clean(self, message: str) -> str: + return ( + dedent(message) + .lstrip("\n") + .replace("\n", self.CRLF.decode("utf-8")) + ) + + +@pytest.fixture +def test_app(app: Sanic): + app.config.KEEP_ALIVE_TIMEOUT = 1 + + @app.get("/") + async def base_handler(request): + return text("111122223333444455556666777788889999") + + @app.post("/upload", stream=True) + async def upload_handler(request): + data = [part.decode("utf-8") async for part in request.stream] + return json(data) + + return app + + +@pytest.fixture +def runner(test_app): + client = ReusableClient(test_app, port=PORT) + client.run() + yield client + client.stop() + + +@pytest.fixture +def client(runner): + client = namedtuple("Client", ("raw", "send", "recv")) + + raw = RawClient(runner.host, runner.port) + runner._run(raw.connect()) + + def send(msg): + nonlocal runner + nonlocal raw + runner._run(raw.send(msg)) + + def recv(**kwargs): + nonlocal runner + nonlocal raw + method = raw.recv_until if "until" in kwargs else raw.recv + return runner._run(method(**kwargs)) + + yield client(raw, send, recv) + + runner._run(raw.close()) + + +def test_full_message(client): + client.send( + """ + GET / HTTP/1.1 + host: localhost:7777 + + """ + ) + response = client.recv() + assert len(response) == 140 + assert b"200 OK" in response + + +def test_transfer_chunked(client): + client.send( + """ + POST /upload HTTP/1.1 + transfer-encoding: chunked + + """ + ) + client.send(b"3\r\nfoo\r\n") + client.send(b"3\r\nbar\r\n") + client.send(b"0\r\n\r\n") + response = client.recv() + _, body = response.rsplit(b"\r\n\r\n", 1) + data = stdjson.loads(body) + + assert data == ["foo", "bar"] diff --git a/tests/test_keep_alive_timeout.py b/tests/test_keep_alive_timeout.py index e777de2ead..e30761ed3d 100644 --- a/tests/test_keep_alive_timeout.py +++ b/tests/test_keep_alive_timeout.py @@ -2,16 +2,13 @@ import platform from asyncio import sleep as aio_sleep -from json import JSONDecodeError from os import environ -import httpcore -import httpx import pytest -from sanic_testing.testing import HOST, SanicTestClient +from sanic_testing.reusable import ReusableClient -from sanic import Sanic, server +from sanic import Sanic from sanic.compat import OS_IS_WINDOWS from sanic.response import text @@ -21,164 +18,6 @@ PORT = 42101 # test_keep_alive_timeout_reuse doesn't work with random port -class ReusableSanicConnectionPool(httpcore.AsyncConnectionPool): - last_reused_connection = None - - async def _get_connection_from_pool(self, *args, **kwargs): - conn = await super()._get_connection_from_pool(*args, **kwargs) - self.__class__.last_reused_connection = conn - return conn - - -class ResusableSanicSession(httpx.AsyncClient): - def __init__(self, *args, **kwargs) -> None: - transport = ReusableSanicConnectionPool() - super().__init__(transport=transport, *args, **kwargs) - - -class ReuseableSanicTestClient(SanicTestClient): - def __init__(self, app, loop=None): - super().__init__(app) - if loop is None: - loop = asyncio.get_event_loop() - self._loop = loop - self._server = None - self._tcp_connector = None - self._session = None - - def get_new_session(self): - return ResusableSanicSession() - - # Copied from SanicTestClient, but with some changes to reuse the - # same loop for the same app. - def _sanic_endpoint_test( - self, - method="get", - uri="/", - gather_request=True, - debug=False, - server_kwargs=None, - *request_args, - **request_kwargs, - ): - loop = self._loop - results = [None, None] - exceptions = [] - server_kwargs = server_kwargs or {"return_asyncio_server": True} - if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - - self.app.request_middleware.appendleft(_collect_request) - - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): - url = uri - else: - uri = uri if uri.startswith("/") else f"/{uri}" - scheme = "http" - url = f"{scheme}://{HOST}:{PORT}{uri}" - - @self.app.listener("after_server_start") - async def _collect_response(loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) - results[-1] = response - except Exception as e2: - exceptions.append(e2) - - if self._server is not None: - _server = self._server - else: - _server_co = self.app.create_server( - host=HOST, debug=debug, port=PORT, **server_kwargs - ) - - server.trigger_events( - self.app.listeners["before_server_start"], loop - ) - - try: - loop._stopping = False - _server = loop.run_until_complete(_server_co) - except Exception as e1: - raise e1 - self._server = _server - server.trigger_events(self.app.listeners["after_server_start"], loop) - self.app.listeners["after_server_start"].pop() - - if exceptions: - raise ValueError(f"Exception during request: {exceptions}") - - if gather_request: - self.app.request_middleware.pop() - try: - request, response = results - return request, response - except Exception: - raise ValueError( - f"Request and response object expected, got ({results})" - ) - else: - try: - return results[-1] - except Exception: - raise ValueError(f"Request object expected, got ({results})") - - def kill_server(self): - try: - if self._server: - self._server.close() - self._loop.run_until_complete(self._server.wait_closed()) - self._server = None - - if self._session: - self._loop.run_until_complete(self._session.aclose()) - self._session = None - - except Exception as e3: - raise e3 - - # Copied from SanicTestClient, but with some changes to reuse the - # same TCPConnection and the sane ClientSession more than once. - # Note, you cannot use the same session if you are in a _different_ - # loop, so the changes above are required too. - async def _local_request(self, method, url, *args, **kwargs): - raw_cookies = kwargs.pop("raw_cookies", None) - request_keepalive = kwargs.pop( - "request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"] - ) - if not self._session: - self._session = self.get_new_session() - try: - response = await getattr(self._session, method.lower())( - url, timeout=request_keepalive, *args, **kwargs - ) - except NameError: - raise Exception(response.status_code) - - try: - response.json = response.json() - except (JSONDecodeError, UnicodeDecodeError): - response.json = None - - response.body = await response.aread() - response.status = response.status_code - response.content_type = response.headers.get("content-type") - - if raw_cookies: - response.raw_cookies = {} - for cookie in response.cookies: - response.raw_cookies[cookie.name] = cookie - - return response - - keep_alive_timeout_app_reuse = Sanic("test_ka_timeout_reuse") keep_alive_app_client_timeout = Sanic("test_ka_client_timeout") keep_alive_app_server_timeout = Sanic("test_ka_server_timeout") @@ -224,21 +63,22 @@ def test_keep_alive_timeout_reuse(): """If the server keep-alive timeout and client keep-alive timeout are both longer than the delay, the client _and_ server will successfully reuse the existing connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_timeout_app_reuse, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_timeout_app_reuse, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request, response = client.get("/1", headers=headers) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 + loop.run_until_complete(aio_sleep(1)) + request, response = client.get("/1") assert response.status == 200 assert response.text == "OK" - assert ReusableSanicConnectionPool.last_reused_connection - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 2 @pytest.mark.skipif( @@ -250,22 +90,22 @@ def test_keep_alive_timeout_reuse(): def test_keep_alive_client_timeout(): """If the server keep-alive timeout is longer than the client keep-alive timeout, client will try to create a new connection here.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_client_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_client_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=1) + request, response = client.get("/1", headers=headers, timeout=1) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(2)) - _, response = client.get("/1", request_keepalive=1) - - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + request, response = client.get("/1", timeout=1) + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -277,22 +117,23 @@ def test_keep_alive_server_timeout(): keep-alive timeout, the client will either a 'Connection reset' error _or_ a new connection. Depending on how the event-loop handles the broken server connection.""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_server_timeout, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient( + keep_alive_app_server_timeout, loop=loop, port=PORT + ) + with client: headers = {"Connection": "keep-alive"} - _, response = client.get("/1", headers=headers, request_keepalive=60) + request, response = client.get("/1", headers=headers, timeout=60) assert response.status == 200 assert response.text == "OK" + assert request.protocol.state["requests_count"] == 1 loop.run_until_complete(aio_sleep(3)) - _, response = client.get("/1", request_keepalive=60) + request, response = client.get("/1", timeout=60) - assert ReusableSanicConnectionPool.last_reused_connection is None - finally: - client.kill_server() + assert request.protocol.state["requests_count"] == 1 @pytest.mark.skipif( @@ -300,10 +141,10 @@ def test_keep_alive_server_timeout(): reason="Not testable with current client", ) def test_keep_alive_connection_context(): - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = ReuseableSanicTestClient(keep_alive_app_context, loop) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = ReusableClient(keep_alive_app_context, loop=loop, port=PORT) + with client: headers = {"Connection": "keep-alive"} request1, _ = client.post("/ctx", headers=headers) @@ -315,5 +156,4 @@ def test_keep_alive_connection_context(): assert ( request1.conn_info.ctx.foo == request2.conn_info.ctx.foo == "hello" ) - finally: - client.kill_server() + assert request2.protocol.state["requests_count"] == 2 diff --git a/tests/test_logging.py b/tests/test_logging.py index 5f53167081..639bb2ee6f 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -5,6 +5,7 @@ from importlib import reload from io import StringIO +from unittest.mock import Mock import pytest @@ -51,7 +52,7 @@ def handler(request): def test_logging_defaults(): # reset_logging() - app = Sanic("test_logging") + Sanic("test_logging") for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: assert ( @@ -87,7 +88,7 @@ def test_logging_pass_customer_logconfig(): "format" ] = "%(asctime)s - (%(name)s)[%(levelname)s]: %(message)s" - app = Sanic("test_logging", log_config=modified_config) + Sanic("test_logging", log_config=modified_config) for fmt in [h.formatter for h in logging.getLogger("sanic.root").handlers]: assert fmt._fmt == modified_config["formatters"]["generic"]["format"] @@ -111,11 +112,13 @@ def test_logging_pass_customer_logconfig(): ), ) def test_log_connection_lost(app, debug, monkeypatch): - """ Should not log Connection lost exception on non debug """ + """Should not log Connection lost exception on non debug""" stream = StringIO() error = logging.getLogger("sanic.error") error.addHandler(logging.StreamHandler(stream)) - monkeypatch.setattr(sanic.server, "error_logger", error) + monkeypatch.setattr( + sanic.server.protocols.http_protocol, "error_logger", error + ) @app.route("/conn_lost") async def conn_lost(request): @@ -208,6 +211,56 @@ def test_logging_modified_root_logger_config(): modified_config = LOGGING_CONFIG_DEFAULTS modified_config["loggers"]["sanic.root"]["level"] = "DEBUG" - app = Sanic("test_logging", log_config=modified_config) + Sanic("test_logging", log_config=modified_config) assert logging.getLogger("sanic.root").getEffectiveLevel() == logging.DEBUG + + +def test_access_log_client_ip_remote_addr(monkeypatch): + access = Mock() + monkeypatch.setattr(sanic.http, "access_logger", access) + + app = Sanic("test_logging") + app.config.PROXIES_COUNT = 2 + + @app.route("/") + async def handler(request): + return text(request.remote_addr) + + headers = {"X-Forwarded-For": "1.1.1.1, 2.2.2.2"} + + request, response = app.test_client.get("/", headers=headers) + + assert request.remote_addr == "1.1.1.1" + access.info.assert_called_with( + "", + extra={ + "status": 200, + "byte": len(response.content), + "host": f"{request.remote_addr}:{request.port}", + "request": f"GET {request.scheme}://{request.host}/", + }, + ) + + +def test_access_log_client_ip_reqip(monkeypatch): + access = Mock() + monkeypatch.setattr(sanic.http, "access_logger", access) + + app = Sanic("test_logging") + + @app.route("/") + async def handler(request): + return text(request.ip) + + request, response = app.test_client.get("/") + + access.info.assert_called_with( + "", + extra={ + "status": 200, + "byte": len(response.content), + "host": f"{request.ip}:{request.port}", + "request": f"GET {request.scheme}://{request.host}/", + }, + ) diff --git a/tests/test_logo.py b/tests/test_logo.py index 3fff32db30..e59975c344 100644 --- a/tests/test_logo.py +++ b/tests/test_logo.py @@ -6,85 +6,37 @@ from sanic.config import BASE_LOGO -def test_logo_base(app, caplog): - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False +def test_logo_base(app, run_startup): + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO - - -def test_logo_false(app, caplog): +def test_logo_false(app, caplog, run_startup): app.config.LOGO = False - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False - - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() + logs = run_startup(app) - banner, port = caplog.record_tuples[0][2].rsplit(":", 1) - assert caplog.record_tuples[0][1] == logging.INFO + banner, port = logs[0][2].rsplit(":", 1) + assert logs[0][1] == logging.INFO assert banner == "Goin' Fast @ http://127.0.0.1" assert int(port) > 0 -def test_logo_true(app, caplog): +def test_logo_true(app, run_startup): app.config.LOGO = True - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False + logs = run_startup(app) - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == BASE_LOGO - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == BASE_LOGO - - -def test_logo_custom(app, caplog): +def test_logo_custom(app, run_startup): app.config.LOGO = "My Custom Logo" - server = app.create_server( - debug=True, return_asyncio_server=True, port=PORT - ) - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop._stopping = False - - with caplog.at_level(logging.DEBUG): - _server = loop.run_until_complete(server) - - _server.close() - loop.run_until_complete(_server.wait_closed()) - app.stop() + logs = run_startup(app) - assert caplog.record_tuples[0][1] == logging.DEBUG - assert caplog.record_tuples[0][2] == "My Custom Logo" + assert logs[0][1] == logging.DEBUG + assert logs[0][2] == "My Custom Logo" diff --git a/tests/test_middleware.py b/tests/test_middleware.py index cc7edae2f2..c19386e7e4 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -5,7 +5,7 @@ from sanic.exceptions import NotFound from sanic.request import Request -from sanic.response import HTTPResponse, text +from sanic.response import HTTPResponse, json, text # ------------------------------------------------------------ # @@ -37,14 +37,19 @@ def test_middleware_request_as_convenience(app): async def handler1(request): results.append(request) - @app.route("/") + @app.on_request() async def handler2(request): + results.append(request) + + @app.route("/") + async def handler3(request): return text("OK") request, response = app.test_client.get("/") assert response.text == "OK" assert type(results[0]) is Request + assert type(results[1]) is Request def test_middleware_response(app): @@ -79,7 +84,12 @@ async def process_request(request): results.append(request) @app.on_response - async def process_response(request, response): + async def process_response_1(request, response): + results.append(request) + results.append(response) + + @app.on_response() + async def process_response_2(request, response): results.append(request) results.append(response) @@ -93,6 +103,8 @@ async def handler(request): assert type(results[0]) is Request assert type(results[1]) is Request assert isinstance(results[2], HTTPResponse) + assert type(results[3]) is Request + assert isinstance(results[4], HTTPResponse) def test_middleware_response_as_convenience_called(app): @@ -271,3 +283,17 @@ async def handler(request): request, response = app.test_client.get("/") assert next(i) == 3 + + +def test_middleware_added_response(app): + @app.on_response + def display(_, response): + response["foo"] = "bar" + return json(response) + + @app.get("/") + async def handler(request): + return {} + + _, response = app.test_client.get("/") + assert response.json["foo"] == "bar" diff --git a/tests/test_request.py b/tests/test_request.py index e4b21f661b..ca2c1e4a75 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -140,3 +140,39 @@ async def get(request): assert resp.json["client"] == "[::1]" assert resp.json["client_ip"] == "::1" assert request.ip == "::1" + + +def test_request_accept(): + app = Sanic("req-generator") + + @app.get("/") + async def get(request): + return response.empty() + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": "text/*, text/plain, text/plain;format=flowed, */*" + }, + ) + assert request.accept == [ + "text/plain;format=flowed", + "text/plain", + "text/*", + "*/*", + ] + + request, _ = app.test_client.get( + "/", + headers={ + "Accept": ( + "text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c" + ) + }, + ) + assert request.accept == [ + "text/html", + "text/x-c", + "text/x-dvi; q=0.8", + "text/plain; q=0.5", + ] diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py index 89cb46dfc3..48e23f1d63 100644 --- a/tests/test_request_timeout.py +++ b/tests/test_request_timeout.py @@ -2,6 +2,7 @@ import httpcore import httpx +import pytest from sanic_testing.testing import SanicTestClient @@ -48,42 +49,51 @@ def get_new_session(self): return DelayableSanicSession(request_delay=self._request_delay) -request_timeout_default_app = Sanic("test_request_timeout_default") -request_no_timeout_app = Sanic("test_request_no_timeout") -request_timeout_default_app.config.REQUEST_TIMEOUT = 0.6 -request_no_timeout_app.config.REQUEST_TIMEOUT = 0.6 +@pytest.fixture +def request_no_timeout_app(): + app = Sanic("test_request_no_timeout") + app.config.REQUEST_TIMEOUT = 0.6 + @app.route("/1") + async def handler2(request): + return text("OK") -@request_timeout_default_app.route("/1") -async def handler1(request): - return text("OK") + return app -@request_no_timeout_app.route("/1") -async def handler2(request): - return text("OK") +@pytest.fixture +def request_timeout_default_app(): + app = Sanic("test_request_timeout_default") + app.config.REQUEST_TIMEOUT = 0.6 + @app.route("/1") + async def handler1(request): + return text("OK") -@request_timeout_default_app.websocket("/ws1") -async def ws_handler1(request, ws): - await ws.send("OK") + @app.websocket("/ws1") + async def ws_handler1(request, ws): + await ws.send("OK") + return app -def test_default_server_error_request_timeout(): + +def test_default_server_error_request_timeout(request_timeout_default_app): client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 408 assert "Request Timeout" in response.text -def test_default_server_error_request_dont_timeout(): +def test_default_server_error_request_dont_timeout(request_no_timeout_app): client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - request, response = client.get("/1") + _, response = client.get("/1") assert response.status == 200 assert response.text == "OK" -def test_default_server_error_websocket_request_timeout(): +def test_default_server_error_websocket_request_timeout( + request_timeout_default_app, +): headers = { "Upgrade": "websocket", @@ -93,7 +103,7 @@ def test_default_server_error_websocket_request_timeout(): } client = DelayableSanicTestClient(request_timeout_default_app, 2) - request, response = client.get("/ws1", headers=headers) + _, response = client.get("/ws1", headers=headers) assert response.status == 408 assert "Request Timeout" in response.text diff --git a/tests/test_routes.py b/tests/test_routes.py index 0f4980f646..520ab5be1f 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -654,41 +654,46 @@ async def handler(): @pytest.mark.asyncio @pytest.mark.parametrize("url", ["/ws", "ws"]) async def test_websocket_route_asgi(app, url): - ev = asyncio.Event() + @app.after_server_start + async def setup_ev(app, _): + app.ctx.ev = asyncio.Event() @app.websocket(url) async def handler(request, ws): - ev.set() + request.app.ctx.ev.set() - request, response = await app.asgi_client.websocket(url) - assert ev.is_set() + @app.get("/ev") + async def check(request): + return json({"set": request.app.ctx.ev.is_set()}) + + _, response = await app.asgi_client.websocket(url) + _, response = await app.asgi_client.get("/") + assert response.json["set"] -def test_websocket_route_with_subprotocols(app): +@pytest.mark.parametrize( + "subprotocols,expected", + ( + (["one"], "one"), + (["three", "one"], "one"), + (["tree"], None), + (None, None), + ), +) +def test_websocket_route_with_subprotocols(app, subprotocols, expected): results = [] - @app.websocket("/ws", subprotocols=["foo", "bar"]) + @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) async def handler(request, ws): - results.append(ws.subprotocol) + nonlocal results + results = ws.subprotocol assert ws.subprotocol is not None - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) - assert response.opened is True - assert results == ["bar"] - _, response = SanicTestClient(app).websocket( - "/ws", subprotocols=["bar", "foo"] + "/ws", subprotocols=subprotocols ) assert response.opened is True - assert results == ["bar", "bar"] - - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) - assert response.opened is True - assert results == ["bar", "bar", None] - - _, response = SanicTestClient(app).websocket("/ws") - assert response.opened is True - assert results == ["bar", "bar", None, None] + assert results == expected @pytest.mark.parametrize("strict_slashes", [True, False, None]) diff --git a/tests/test_server_events.py b/tests/test_server_events.py index 2e48f4082c..7ce1859ca3 100644 --- a/tests/test_server_events.py +++ b/tests/test_server_events.py @@ -8,7 +8,7 @@ from sanic_testing.testing import HOST, PORT -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, SanicException AVAILABLE_LISTENERS = [ @@ -103,7 +103,11 @@ class MySanicDb: async def init_db(app, loop): app.db = MySanicDb() - await app.create_server(debug=True, return_asyncio_server=True, port=PORT) + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + await srv.startup() + await srv.before_start() assert hasattr(app, "db") assert isinstance(app.db, MySanicDb) @@ -157,14 +161,15 @@ async def after_stop(app, loop): serv_coro = app.create_server(return_asyncio_server=True, sock=sock) serv_task = asyncio.ensure_future(serv_coro, loop=loop) server = loop.run_until_complete(serv_task) - server.after_start() + loop.run_until_complete(server.startup()) + loop.run_until_complete(server.after_start()) try: loop.run_forever() - except KeyboardInterrupt as e: + except KeyboardInterrupt: loop.stop() finally: # Run the on_stop function if provided - server.before_stop() + loop.run_until_complete(server.before_stop()) # Wait for server to close close_task = server.close() @@ -174,5 +179,19 @@ async def after_stop(app, loop): signal.stopped = True for connection in server.connections: connection.close_if_idle() - server.after_stop() + loop.run_until_complete(server.after_stop()) assert flag1 and flag2 and flag3 + + +@pytest.mark.asyncio +async def test_missing_startup_raises_exception(app): + @app.listener("before_server_start") + async def init_db(app, loop): + ... + + srv = await app.create_server( + debug=True, return_asyncio_server=True, port=PORT + ) + + with pytest.raises(SanicException): + await srv.before_start() diff --git a/tests/test_signal_handlers.py b/tests/test_signal_handlers.py index 857b528348..f7657ad64b 100644 --- a/tests/test_signal_handlers.py +++ b/tests/test_signal_handlers.py @@ -95,7 +95,7 @@ async def atest(stop_first): os.kill(os.getpid(), signal.SIGINT) await asyncio.sleep(0.2) assert app.is_stopping - assert app.stay_active_task.result() == None + assert app.stay_active_task.result() is None # Second Ctrl+C should raise with pytest.raises(KeyboardInterrupt): os.kill(os.getpid(), signal.SIGINT) diff --git a/tests/test_signals.py b/tests/test_signals.py index 5d116f90b7..9b8a94953a 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -68,6 +68,7 @@ async def async_signal(*_): app.signal_router.finalize() + assert len(app.signal_router.routes) == 3 await app.dispatch("foo.bar.baz") assert counter == 2 @@ -331,7 +332,8 @@ def bp_signal(): "event,expected", ( ("foo.bar.baz", True), - ("server.init.before", False), + ("server.init.before", True), + ("server.init.somethingelse", False), ("http.request.start", False), ("sanic.notice.anything", True), ), diff --git a/tests/test_static.py b/tests/test_static.py index 00e5611d80..7d62d2d34d 100644 --- a/tests/test_static.py +++ b/tests/test_static.py @@ -461,6 +461,22 @@ def test_nested_dir(app, static_file_directory): assert response.text == "foo\n" +def test_handle_is_a_directory_error(app, static_file_directory): + error_text = "Is a directory. Access denied" + app.static("/static", static_file_directory) + + @app.exception(Exception) + async def handleStaticDirError(request, exception): + if isinstance(exception, IsADirectoryError): + return text(error_text, status=403) + raise exception + + request, response = app.test_client.get("/static/") + + assert response.status == 403 + assert response.text == error_text + + def test_stack_trace_on_not_found(app, static_file_directory, caplog): app.static("/static", static_file_directory) @@ -507,3 +523,56 @@ def test_multiple_statics(app, static_file_directory): assert response.body == get_file_content( static_file_directory, "python.png" ) + + +def test_resource_type_default(app, static_file_directory): + app.static("/static", static_file_directory) + app.static("/file", get_file_path(static_file_directory, "test.file")) + + _, response = app.test_client.get("/static") + assert response.status == 404 + + _, response = app.test_client.get("/file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + +def test_resource_type_file(app, static_file_directory): + app.static( + "/file", + get_file_path(static_file_directory, "test.file"), + resource_type="file", + ) + + _, response = app.test_client.get("/file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + with pytest.raises(TypeError): + app.static("/static", static_file_directory, resource_type="file") + + +def test_resource_type_dir(app, static_file_directory): + app.static("/static", static_file_directory, resource_type="dir") + + _, response = app.test_client.get("/static/test.file") + assert response.status == 200 + assert response.body == get_file_content( + static_file_directory, "test.file" + ) + + with pytest.raises(TypeError): + app.static( + "/file", + get_file_path(static_file_directory, "test.file"), + resource_type="dir", + ) + + +def test_resource_type_unknown(app, static_file_directory, caplog): + with pytest.raises(ValueError): + app.static("/static", static_file_directory, resource_type="unknown") diff --git a/tests/test_touchup.py b/tests/test_touchup.py new file mode 100644 index 0000000000..3079aa1ba7 --- /dev/null +++ b/tests/test_touchup.py @@ -0,0 +1,21 @@ +import logging + +from sanic.signals import RESERVED_NAMESPACES +from sanic.touchup import TouchUp + + +def test_touchup_methods(app): + assert len(TouchUp._registry) == 9 + + +async def test_ode_removes_dispatch_events(app, caplog): + with caplog.at_level(logging.DEBUG, logger="sanic.root"): + await app._startup() + logs = caplog.record_tuples + + for signal in RESERVED_NAMESPACES["http"]: + assert ( + "sanic.root", + logging.DEBUG, + f"Disabling event: {signal}", + ) in logs diff --git a/tests/test_url_for.py b/tests/test_url_for.py index d623cc4ad3..6ec6a93f7b 100644 --- a/tests/test_url_for.py +++ b/tests/test_url_for.py @@ -43,7 +43,15 @@ def index(request): ) -def test_websocket_bp_route_name(app): +@pytest.mark.parametrize( + "name,expected", + ( + ("test_route", "/bp/route"), + ("test_route2", "/bp/route2"), + ("foobar_3", "/bp/route3"), + ), +) +def test_websocket_bp_route_name(app, name, expected): """Tests that blueprint websocket route is named.""" event = asyncio.Event() bp = Blueprint("test_bp", url_prefix="/bp") @@ -69,22 +77,12 @@ async def test_route3(request, ws): uri = app.url_for("test_bp.main") assert uri == "/bp/main" - uri = app.url_for("test_bp.test_route") - assert uri == "/bp/route" + uri = app.url_for(f"test_bp.{name}") + assert uri == expected request, response = SanicTestClient(app).websocket(uri) assert response.opened is True assert event.is_set() - event.clear() - uri = app.url_for("test_bp.test_route2") - assert uri == "/bp/route2" - request, response = SanicTestClient(app).websocket(uri) - assert response.opened is True - assert event.is_set() - - uri = app.url_for("test_bp.foobar_3") - assert uri == "/bp/route3" - # TODO: add test with a route with multiple hosts # TODO: add test with a route with _host in url_for diff --git a/tests/test_worker.py b/tests/test_worker.py index 252bdb3662..3850b8a691 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -175,7 +175,7 @@ def test_worker_close(worker): worker.wsgi = mock.Mock() conn = mock.Mock() conn.websocket = mock.Mock() - conn.websocket.close_connection = mock.Mock(wraps=_a_noop) + conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -190,5 +190,5 @@ def test_worker_close(worker): loop.run_until_complete(_close) assert worker.signal.stopped - assert conn.websocket.close_connection.called + assert conn.websocket.fail_connection.called assert len(worker.servers) == 0 diff --git a/tox.ini b/tox.ini index 590dc25aba..5612f6ded7 100644 --- a/tox.ini +++ b/tox.ini @@ -2,53 +2,28 @@ envlist = py37, py38, py39, pyNightly, pypy37, {py37,py38,py39,pyNightly,pypy37}-no-ext, lint, check, security, docs, type-checking [testenv] -usedevelop = True +usedevelop = true setenv = {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UJSON=1 {py37,py38,py39,pyNightly}-no-ext: SANIC_NO_UVLOOP=1 -deps = - sanic-testing>=0.6.0 - coverage==5.3 - pytest==5.2.1 - pytest-cov - pytest-sanic - pytest-sugar - pytest-benchmark - chardet==3.* - beautifulsoup4 - gunicorn==20.0.4 - uvicorn - websockets>=9.0 +extras = test commands = pytest {posargs:tests --cov sanic} - coverage combine --append - coverage report -m + coverage report -m -i coverage html -i [testenv:lint] -deps = - flake8 - black - isort>=5.0.0 - bandit - commands = flake8 sanic black --config ./.black.toml --check --verbose sanic/ isort --check-only sanic --profile=black [testenv:type-checking] -deps = - mypy>=0.901 - types-ujson - commands = mypy sanic [testenv:check] -deps = - docutils - pygments commands = python setup.py check -r -s @@ -60,8 +35,6 @@ markers = asyncio [testenv:security] -deps = - bandit commands = bandit --recursive sanic --skip B404,B101 --exclude sanic/reloader_helpers.py @@ -69,30 +42,10 @@ commands = [testenv:docs] platform = linux|linux2|darwin whitelist_externals = make -deps = - sphinx>=2.1.2 - sphinx_rtd_theme>=0.4.3 - recommonmark>=0.5.0 - docutils - pygments - gunicorn==20.0.4 +extras = docs commands = make docs-test [testenv:coverage] -usedevelop = True -deps = - sanic-testing>=0.6.0 - coverage==5.3 - pytest==5.2.1 - pytest-cov - pytest-sanic - pytest-sugar - pytest-benchmark - chardet==3.* - beautifulsoup4 - gunicorn==20.0.4 - uvicorn - websockets>=9.0 commands = pytest tests --cov=./sanic --cov-report=xml