Skip to content

Commit

Permalink
Split HttpProtocol parts into base SanicProtocol and HTTPProtocol sub…
Browse files Browse the repository at this point in the history
…class (sanic-org#2229)

* Split HttpProtocol parts into base SanicProtocol and HTTPProtocol subclass.

* lint fixes

* re-black server.py
  • Loading branch information
ashleysommer authored and ChihweiLHBird committed Jun 1, 2022
1 parent 86e6b8c commit e0ea8a8
Showing 1 changed file with 132 additions and 61 deletions.
193 changes: 132 additions & 61 deletions sanic/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,25 +102,133 @@ def __init__(self, transport: TransportProtocol, unix=None):
self.client_port = addr[1]


class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
class SanicProtocol(asyncio.Protocol):
__slots__ = (
"app",
# event loop, connection
"loop",
"transport",
"connections",
"conn_info",
"signal",
"_can_write",
"_time",
"_task",
"_unix",
"_data_received",
)

def __init__(
self,
*,
loop,
app: Sanic,
signal=None,
connections=None,
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
self.signal = signal or Signal()
self.transport: Optional[Transport] = None
self.connections = connections if connections is not None else set()
self.conn_info: Optional[ConnInfo] = None
self._can_write = asyncio.Event()
self._can_write.set()
self._unix = unix
self._time = 0.0 # type: float
self._task = None # type: Optional[asyncio.Task]
self._data_received = asyncio.Event()

@property
def ctx(self):
if self.conn_info is not None:
return self.conn_info.ctx
else:
return None

async def send(self, data):
"""
Generic data write implementation with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
raise CancelledError
self.transport.write(data)
self._time = current_time()

async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()

def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None

# asyncio.Protocol API Callbacks #
# ------------------------------ #
def connection_made(self, transport):
"""
Generic connection-made, with no connection_task, and no recv_buffer.
Override this for protocol-specific connection implementations.
"""
try:
transport.set_write_buffer_limits(low=16384, high=65536)
self.connections.add(self)
self.transport = transport
self.conn_info = ConnInfo(self.transport, unix=self._unix)
except Exception:
error_logger.exception("protocol.connect_made")

def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except BaseException:
error_logger.exception("protocol.connection_lost")

def pause_writing(self):
self._can_write.clear()

def resume_writing(self):
self._can_write.set()

def data_received(self, data: bytes):
try:
self._time = current_time()
if not data:
return self.close()

if self._data_received:
self._data_received.set()
except BaseException:
error_logger.exception("protocol.data_received")


class HttpProtocol(SanicProtocol, metaclass=TouchUpMeta):
"""
This class provides a basic HTTP implementation of the sanic framework.
This class provides implements the HTTP 1.1 protocol on top of our
Sanic Server transport
"""

__touchup__ = (
"send",
"connection_task",
)
__slots__ = (
# app
"app",
# event loop, connection
"loop",
"transport",
"connections",
"signal",
"conn_info",
"ctx",
# request params
"request",
# request config
Expand All @@ -137,14 +245,9 @@ class HttpProtocol(asyncio.Protocol, metaclass=TouchUpMeta):
"state",
"url",
"_handler_task",
"_can_write",
"_data_received",
"_time",
"_task",
"_http",
"_exception",
"recv_buffer",
"_unix",
)

def __init__(
Expand All @@ -158,16 +261,16 @@ def __init__(
unix=None,
**kwargs,
):
asyncio.set_event_loop(loop)
self.loop = loop
self.app: Sanic = app
super().__init__(
loop=loop,
app=app,
signal=signal,
connections=connections,
unix=unix,
)
self.url = None
self.transport: Optional[Transport] = None
self.conn_info: Optional[ConnInfo] = None
self.request: Optional[Request] = None
self.signal = signal or Signal()
self.access_log = self.app.config.ACCESS_LOG
self.connections = connections if connections is not None else set()
self.request_handler = self.app.handle_request
self.error_handler = self.app.error_handler
self.request_timeout = self.app.config.REQUEST_TIMEOUT
Expand All @@ -178,11 +281,7 @@ def __init__(
self.state = state if state else {}
if "requests_count" not in self.state:
self.state["requests_count"] = 0
self._data_received = asyncio.Event()
self._can_write = asyncio.Event()
self._can_write.set()
self._exception = None
self._unix = unix

def _setup_connection(self):
self._http = Http(self)
Expand Down Expand Up @@ -229,14 +328,6 @@ async def connection_task(self): # no cov
)
...

async def receive_more(self):
"""
Wait until more data is received into the Server protocol's buffer
"""
self.transport.resume_reading()
self._data_received.clear()
await self._data_received.wait()

def check_timeouts(self):
"""
Runs itself periodically to enforce any expired timeouts.
Expand Down Expand Up @@ -277,7 +368,7 @@ def check_timeouts(self):

async def send(self, data): # no cov
"""
Writes data with backpressure control.
Writes HTTP data with backpressure control.
"""
await self._can_write.wait()
if self.transport.is_closing():
Expand All @@ -301,20 +392,14 @@ def close_if_idle(self) -> bool:
return True
return False

def close(self):
"""
Force close the connection.
"""
# Cause a call to connection_lost where further cleanup occurs
if self.transport:
self.transport.close()
self.transport = None

# -------------------------------------------- #
# Only asyncio.Protocol callbacks below this
# -------------------------------------------- #

def connection_made(self, transport):
"""
HTTP-protocol-specific new connection handler
"""
try:
# TODO: Benchmark to find suitable write buffer limits
transport.set_write_buffer_limits(low=16384, high=65536)
Expand All @@ -326,30 +411,16 @@ def connection_made(self, transport):
except Exception:
error_logger.exception("protocol.connect_made")

def connection_lost(self, exc):
try:
self.connections.discard(self)
self.resume_writing()
if self._task:
self._task.cancel()
except Exception:
error_logger.exception("protocol.connection_lost")

def pause_writing(self):
self._can_write.clear()

def resume_writing(self):
self._can_write.set()

def data_received(self, data: bytes):

try:
self._time = current_time()
if not data:
return self.close()
self.recv_buffer += data

if (
len(self.recv_buffer) > self.app.config.REQUEST_BUFFER_SIZE
len(self.recv_buffer) >= self.app.config.REQUEST_BUFFER_SIZE
and self.transport
):
self.transport.pause_reading()
Expand Down

0 comments on commit e0ea8a8

Please sign in to comment.