Skip to content

Commit

Permalink
Get rid of legacy class StreamWriter #2109 (#2651)
Browse files Browse the repository at this point in the history
* Get rid of legacy StreamWriter (#2623)

Legacy StreamWriter as a pure proxy of the transport and the protocol is
no longer needed. All of the functionalities that were behind this class
has been moved to the PayloadWriter.

Some changes that have to be considered that impacted during this change
* TCP Operations have been isolated in a module rather than move them
into the PayloadWriter
* WebSocketWriter had a dependency with the StreamWriter, to get rid of
that dependency the constructor has been modified to take the protocol
and the transport.

A next step changing the name PayLoadWriter for the StreamWriter to have
consistency with the reader part, might be considered.

* Add CHANGES

* Fixed invalid import order

* Fix test broken

* Fix tcp_cork issues

* Test PayloadWriter properties

* Avoid return useless values for tcp_<operations>

* Increase coverage http_writer

* Increase coverage web_protocol
  • Loading branch information
pfreixes authored and asvetlov committed Jan 11, 2018
1 parent 74810c2 commit f570fed
Show file tree
Hide file tree
Showing 19 changed files with 426 additions and 483 deletions.
1 change: 1 addition & 0 deletions CHANGES/2651.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Get rid of the legacy class StreamWriter.
9 changes: 6 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .http import WS_KEY, WebSocketReader, WebSocketWriter
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
from .streams import FlowControlDataQueue
from .tcp_helpers import tcp_cork, tcp_nodelay
from .tracing import Trace


Expand Down Expand Up @@ -296,7 +297,8 @@ async def _request(self, method, url, *,
'Connection timeout '
'to host {0}'.format(url)) from exc

conn.writer.set_tcp_nodelay(True)
tcp_nodelay(conn.transport, True)
tcp_cork(conn.transport, False)
try:
resp = req.send(conn)
try:
Expand Down Expand Up @@ -575,12 +577,13 @@ async def _ws_connect(self, url, *,
notakeover = False

proto = resp.connection.protocol
transport = resp.connection.transport
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop)
proto.set_parser(WebSocketReader(reader), reader)
resp.connection.writer.set_tcp_nodelay(True)
tcp_nodelay(transport, True)
writer = WebSocketWriter(
resp.connection.writer, use_mask=True,
proto, transport, use_mask=True,
compress=compress, notakeover=notakeover)
except Exception:
resp.close()
Expand Down
6 changes: 2 additions & 4 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .client_exceptions import (ClientOSError, ClientPayloadError,
ServerDisconnectedError)
from .http import HttpResponseParser, StreamWriter
from .http import HttpResponseParser
from .streams import EMPTY_PAYLOAD, DataQueue


Expand All @@ -17,7 +17,6 @@ def __init__(self, *, loop=None):

self.paused = False
self.transport = None
self.writer = None
self._should_close = False

self._message = None
Expand Down Expand Up @@ -60,7 +59,6 @@ def is_connected(self):

def connection_made(self, transport):
self.transport = transport
self.writer = StreamWriter(self, transport, self._loop)

def connection_lost(self, exc):
if self._payload_parser is not None:
Expand All @@ -82,7 +80,7 @@ def connection_lost(self, exc):
exc = ServerDisconnectedError(uncompleted)
DataQueue.set_exception(self, exc)

self.transport = self.writer = None
self.transport = None
self._should_close = True
self._parser = None
self._message = None
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def send(self, conn):
if self.url.raw_query_string:
path += '?' + self.url.raw_query_string

writer = PayloadWriter(conn.writer, self.loop)
writer = PayloadWriter(conn.protocol, conn.transport, self.loop)

if self.compress:
writer.enable_compression(self.compress)
Expand Down
3 changes: 1 addition & 2 deletions aiohttp/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
WSCloseCode, WSMessage, WSMsgType, ws_ext_gen,
ws_ext_parse)
from .http_writer import (HttpVersion, HttpVersion10, HttpVersion11,
PayloadWriter, StreamWriter)
PayloadWriter)


__all__ = (
'HttpProcessingError', 'RESPONSES', 'SERVER_SOFTWARE',

# .http_writer
'PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'StreamWriter',

# .http_parser
'HttpParser', 'HttpRequestParser', 'HttpResponseParser',
Expand Down
16 changes: 8 additions & 8 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ def parse_frame(self, buf):

class WebSocketWriter:

def __init__(self, stream, *,
def __init__(self, protocol, transport, *,
use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
compress=0, notakeover=False):
self.stream = stream
self.writer = stream.transport
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.randrange = random.randrange
self.compress = compress
Expand Down Expand Up @@ -572,20 +572,20 @@ def _send_frame(self, message, opcode, compress=None):
mask = mask.to_bytes(4, 'big')
message = bytearray(message)
_websocket_mask(mask, message)
self.writer.write(header + mask + message)
self.transport.write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self.writer.write(header)
self.writer.write(message)
self.transport.write(header)
self.transport.write(message)
else:
self.writer.write(header + message)
self.transport.write(header + message)

self._output_size += len(header) + len(message)

if self._output_size > self._limit:
self._output_size = 0
return self.stream.drain()
return self.protocol._drain_helper()

return noop()

Expand Down
104 changes: 18 additions & 86 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,104 +2,24 @@

import asyncio
import collections
import socket
import zlib
from contextlib import suppress

from .abc import AbstractPayloadWriter
from .helpers import noop


__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11',
'StreamWriter')
__all__ = ('PayloadWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')

HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)


if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None


class StreamWriter:
class PayloadWriter(AbstractPayloadWriter):

def __init__(self, protocol, transport, loop):
self._protocol = protocol
self._loop = loop
self._tcp_nodelay = False
self._tcp_cork = False
self._socket = transport.get_extra_info('socket')
self._waiters = []
self.transport = transport

@property
def tcp_nodelay(self):
return self._tcp_nodelay

def set_tcp_nodelay(self, value):
value = bool(value)
if self._tcp_nodelay == value:
return
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return

# socket may be closed already, on windows OSError get raised
with suppress(OSError):
if self._tcp_cork:
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, False)
self._tcp_cork = False

self._socket.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
self._tcp_nodelay = value

@property
def tcp_cork(self):
return self._tcp_cork

def set_tcp_cork(self, value):
value = bool(value)
if self._tcp_cork == value:
return
if self._socket is None:
return
if self._socket.family not in (socket.AF_INET, socket.AF_INET6):
return

with suppress(OSError):
if self._tcp_nodelay:
self._socket.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, False)
self._tcp_nodelay = False
if CORK is not None: # pragma: no branch
self._socket.setsockopt(socket.IPPROTO_TCP, CORK, value)
self._tcp_cork = value

async def drain(self):
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()


class PayloadWriter(AbstractPayloadWriter):

def __init__(self, stream, loop):
self._stream = stream
self._transport = None
self._transport = transport

self.loop = loop
self.length = None
Expand All @@ -110,11 +30,15 @@ def __init__(self, stream, loop):
self._eof = False
self._compress = None
self._drain_waiter = None
self._transport = self._stream.transport

async def get_transport(self):
@property
def transport(self):
return self._transport

@property
def protocol(self):
return self._protocol

def enable_chunking(self):
self.chunked = True

Expand Down Expand Up @@ -204,4 +128,12 @@ async def write_eof(self, chunk=b''):
self._transport = None

async def drain(self):
await self._stream.drain()
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
if self._protocol.transport is not None:
await self._protocol._drain_helper()
61 changes: 61 additions & 0 deletions aiohttp/tcp_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Helper methods to tune a TCP connection"""

import socket
from contextlib import suppress


__all__ = ('tcp_keepalive', 'tcp_nodelay', 'tcp_cork')


if hasattr(socket, 'TCP_CORK'): # pragma: no cover
CORK = socket.TCP_CORK
elif hasattr(socket, 'TCP_NOPUSH'): # pragma: no cover
CORK = socket.TCP_NOPUSH
else: # pragma: no cover
CORK = None


if hasattr(socket, 'SO_KEEPALIVE'):
def tcp_keepalive(transport):
sock = transport.get_extra_info('socket')
if sock is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(transport): # pragma: no cover
pass


def tcp_nodelay(transport, value):
sock = transport.get_extra_info('socket')

if sock is None:
return

if sock.family not in (socket.AF_INET, socket.AF_INET6):
return

value = bool(value)

# socket may be closed already, on windows OSError get raised
with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, value)


def tcp_cork(transport, value):
sock = transport.get_extra_info('socket')

if CORK is None:
return

if sock is None:
return

if sock.family not in (socket.AF_INET, socket.AF_INET6):
return

value = bool(value)

with suppress(OSError):
sock.setsockopt(
socket.IPPROTO_TCP, CORK, value)
9 changes: 4 additions & 5 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def _sendfile_cb(self, fut, out_fd, in_fd,
set_result(fut, None)

async def sendfile(self, fobj, count):
transport = await self.get_transport()

out_socket = transport.get_extra_info('socket').dup()
out_socket = self.transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
Expand All @@ -71,7 +69,7 @@ async def sendfile(self, fobj, count):
await fut
except Exception:
server_logger.debug('Socket error')
transport.close()
self.transport.close()
finally:
out_socket.close()

Expand Down Expand Up @@ -112,7 +110,8 @@ async def _sendfile_system(self, request, fobj, count):
writer = await self._sendfile_fallback(request, fobj, count)
else:
writer = SendfilePayloadWriter(
request._protocol.writer,
request.protocol,
transport,
request.loop
)
request._payload_writer = writer
Expand Down
Loading

0 comments on commit f570fed

Please sign in to comment.