Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Blueprint specific exception handlers #2208

Merged
merged 10 commits into from
Aug 31, 2021
10 changes: 7 additions & 3 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,11 @@ def register_named_middleware(
self.named_response_middleware[_rn].appendleft(middleware)
return middleware

def _apply_exception_handler(self, handler: FutureException):
def _apply_exception_handler(
self,
handler: FutureException,
route_names: Optional[List[str]] = None,
):
"""Decorate a function to be registered as a handler for exceptions

:param exceptions: exceptions
Expand All @@ -344,9 +348,9 @@ def _apply_exception_handler(self, handler: FutureException):
for exception in handler.exceptions:
if isinstance(exception, (tuple, list)):
for e in exception:
self.error_handler.add(e, handler.handler)
self.error_handler.add(e, handler.handler, route_names)
else:
self.error_handler.add(exception, handler.handler)
self.error_handler.add(exception, handler.handler, route_names)
return handler.handler

def _apply_listener(self, listener: FutureListener):
Expand Down
4 changes: 3 additions & 1 deletion sanic/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ def register(self, app, options):

# Exceptions
for future in self._future_exceptions:
exception_handlers.append(app._apply_exception_handler(future))
exception_handlers.append(
app._apply_exception_handler(future, route_names)
)

# Event listeners
for listener in self._future_listeners:
Expand Down
48 changes: 31 additions & 17 deletions sanic/handlers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional

from sanic.errorpages import exception_response
from sanic.exceptions import (
ContentRangeError,
Expand All @@ -21,15 +23,12 @@ class ErrorHandler:

"""

handlers = None
cached_handlers = None

def __init__(self):
self.handlers = []
self.cached_handlers = {}
self.debug = False

def add(self, exception, handler):
def add(self, exception, handler, route_names: Optional[List[str]] = None):
"""
Add a new exception handler to an already existing handler object.

Expand All @@ -42,11 +41,16 @@ def add(self, exception, handler):

:return: None
"""
# self.handlers to be deprecated and removed in version 21.12
# self.handlers is deprecated and will be removed in version 22.3
self.handlers.append((exception, handler))
self.cached_handlers[exception] = handler

def lookup(self, exception):
if route_names:
for route in route_names:
self.cached_handlers[(exception, route)] = handler
else:
self.cached_handlers[(exception, None)] = handler

def lookup(self, exception, route_name: Optional[str]):
"""
Lookup the existing instance of :class:`ErrorHandler` and fetch the
registered handler for a specific type of exception.
Expand All @@ -61,17 +65,26 @@ def lookup(self, exception):
:return: Registered function if found ``None`` otherwise
"""
exception_class = type(exception)
if exception_class in self.cached_handlers:
return self.cached_handlers[exception_class]

for ancestor in type.mro(exception_class):
if ancestor in self.cached_handlers:
handler = self.cached_handlers[ancestor]
self.cached_handlers[exception_class] = handler
for name in (route_name, None):
exception_key = (exception_class, name)
handler = self.cached_handlers.get(exception_key)
if handler:
return handler
if ancestor is BaseException:
break
self.cached_handlers[exception_class] = None

for name in (route_name, None):
for ancestor in type.mro(exception_class):
exception_key = (ancestor, name)
if exception_key in self.cached_handlers:
handler = self.cached_handlers[exception_key]
self.cached_handlers[
(exception_class, route_name)
] = handler
return handler

if ancestor is BaseException:
break
self.cached_handlers[(exception_class, route_name)] = None
handler = None
return handler

Expand All @@ -89,7 +102,8 @@ def response(self, request, exception):
:return: Wrap the return value obtained from :func:`default`
or registered handler for that type of exception.
"""
handler = self.lookup(exception)
route_name = request.name if request else None
handler = self.lookup(exception, route_name)
response = None
try:
if handler:
Expand Down
26 changes: 16 additions & 10 deletions tests/test_exceptions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,24 @@ class ModuleNotFoundError(ImportError):
handler.add(CustomError, custom_error_handler)
handler.add(ServerError, server_error_handler)

assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler
assert handler.lookup(ImportError(), None) == import_error_handler
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
assert handler.lookup(CustomError(), None) == custom_error_handler
assert handler.lookup(ServerError("Error"), None) == server_error_handler
assert (
handler.lookup(CustomServerError("Error"), None)
== server_error_handler
)

# once again to ensure there is no caching bug
assert handler.lookup(ImportError()) == import_error_handler
assert handler.lookup(ModuleNotFoundError()) == import_error_handler
assert handler.lookup(CustomError()) == custom_error_handler
assert handler.lookup(ServerError("Error")) == server_error_handler
assert handler.lookup(CustomServerError("Error")) == server_error_handler
assert handler.lookup(ImportError(), None) == import_error_handler
assert handler.lookup(ModuleNotFoundError(), None) == import_error_handler
assert handler.lookup(CustomError(), None) == custom_error_handler
assert handler.lookup(ServerError("Error"), None) == server_error_handler
assert (
handler.lookup(CustomServerError("Error"), None)
== server_error_handler
)


def test_exception_handler_processed_request_middleware():
Expand Down