Skip to content

Commit

Permalink
Async writer (#2774)
Browse files Browse the repository at this point in the history
* Await all writer.write() calls

* Convert tests

* Convert web functional tests

* Make StreamWriter.write() a native coroutine
  • Loading branch information
asvetlov authored Feb 27, 2018
1 parent 3e73618 commit fbfa53c
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 49 deletions.
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ async def write_bytes(self, writer, conn):
self.body = (self.body,)

for chunk in self.body:
writer.write(chunk)
await writer.write(chunk)

await writer.write_eof()
except OSError as exc:
Expand Down
11 changes: 4 additions & 7 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import zlib

from .abc import AbstractStreamWriter
from .helpers import noop


__all__ = ('StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
Expand Down Expand Up @@ -56,7 +55,7 @@ def _write(self, chunk):
raise asyncio.CancelledError('Cannot write to closing transport')
self._transport.write(chunk)

def write(self, chunk, *, drain=True, LIMIT=64*1024):
async def write(self, chunk, *, drain=True, LIMIT=64*1024):
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
Expand All @@ -66,7 +65,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):
if self._compress is not None:
chunk = self._compress.compress(chunk)
if not chunk:
return noop()
return

if self.length is not None:
chunk_len = len(chunk)
Expand All @@ -76,7 +75,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):
chunk = chunk[:self.length]
self.length = 0
if not chunk:
return noop()
return

if chunk:
if self.chunked:
Expand All @@ -87,9 +86,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):

if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
return self.drain()

return noop()
await self.drain()

async def write_headers(self, status_line, headers, SEP=': ', END='\r\n'):
"""Write request/response status and headers."""
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ async def write(self, writer):
field = self._value
chunk = await field.read_chunk(size=2**16)
while chunk:
writer.write(field.decode(chunk))
await writer.write(field.decode(chunk))
chunk = await field.read_chunk(size=2**16)


Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ async def _default_expect_handler(request):
expect = request.headers.get(hdrs.EXPECT)
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n", drain=False)
await request.writer.write(
b"HTTP/1.1 100 Continue\r\n\r\n", drain=False)
else:
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)

Expand Down
14 changes: 7 additions & 7 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,9 @@ async def test_expect_100_continue_header(loop, conn):

async def test_data_stream(loop, buf, conn):
@aiohttp.streamer
def gen(writer):
writer.write(b'binary data')
writer.write(b' result')
async def gen(writer):
await writer.write(b'binary data')
await writer.write(b' result')

req = ClientRequest(
'POST', URL('http://python.org/'), data=gen(), loop=loop)
Expand Down Expand Up @@ -876,7 +876,7 @@ async def test_data_stream_exc(loop, conn):

@aiohttp.streamer
async def gen(writer):
writer.write(b'binary data')
await writer.write(b'binary data')
await fut

req = ClientRequest(
Expand Down Expand Up @@ -929,8 +929,8 @@ async def throw_exc():
async def test_data_stream_continue(loop, buf, conn):
@aiohttp.streamer
async def gen(writer):
writer.write(b'binary data')
writer.write(b' result')
await writer.write(b'binary data')
await writer.write(b' result')
await writer.write_eof()

req = ClientRequest(
Expand Down Expand Up @@ -975,7 +975,7 @@ async def test_close(loop, buf, conn):
@aiohttp.streamer
async def gen(writer):
await asyncio.sleep(0.00001, loop=loop)
writer.write(b'result')
await writer.write(b'result')

req = ClientRequest(
'POST', URL('http://python.org/'), data=gen(), loop=loop)
Expand Down
52 changes: 26 additions & 26 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from aiohttp import http
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
Expand All @@ -28,8 +29,7 @@ def write(chunk):
@pytest.fixture
def protocol(loop, transport):
protocol = mock.Mock(transport=transport)
protocol._drain_helper.return_value = loop.create_future()
protocol._drain_helper.return_value.set_result(None)
protocol._drain_helper = make_mocked_coro()
return protocol


Expand All @@ -43,8 +43,8 @@ async def test_write_payload_eof(transport, protocol, loop):
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)

msg.write(b'data1')
msg.write(b'data2')
await msg.write(b'data1')
await msg.write(b'data2')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -54,7 +54,7 @@ async def test_write_payload_eof(transport, protocol, loop):
async def test_write_payload_chunked(buf, protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'data')
await msg.write(b'data')
await msg.write_eof()

assert b'4\r\ndata\r\n0\r\n\r\n' == buf
Expand All @@ -63,8 +63,8 @@ async def test_write_payload_chunked(buf, protocol, transport, loop):
async def test_write_payload_chunked_multiple(buf, protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'data1')
msg.write(b'data2')
await msg.write(b'data1')
await msg.write(b'data2')
await msg.write_eof()

assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf
Expand All @@ -75,8 +75,8 @@ async def test_write_payload_length(protocol, transport, loop):

msg = http.StreamWriter(protocol, transport, loop)
msg.length = 2
msg.write(b'd')
msg.write(b'ata')
await msg.write(b'd')
await msg.write(b'ata')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -88,8 +88,8 @@ async def test_write_payload_chunked_filter(protocol, transport, loop):

msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'da')
msg.write(b'ta')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -103,11 +103,11 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'da')
msg.write(b'ta')
msg.write(b'1d')
msg.write(b'at')
msg.write(b'a2')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write(b'1d')
await msg.write(b'at')
await msg.write(b'a2')
await msg.write_eof()
content = b''.join([c[1][0] for c in list(write.mock_calls)])
assert content.endswith(
Expand All @@ -123,7 +123,7 @@ async def test_write_payload_deflate_compression(protocol, transport, loop):
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_compression('deflate')
msg.write(b'data')
await msg.write(b'data')
await msg.write_eof()

chunks = [c[1][0] for c in list(write.mock_calls)]
Expand All @@ -141,32 +141,32 @@ async def test_write_payload_deflate_and_chunked(
msg.enable_compression('deflate')
msg.enable_chunking()

msg.write(b'da')
msg.write(b'ta')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write_eof()

assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf


def test_write_drain(protocol, transport, loop):
async def test_write_drain(protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.drain = mock.Mock()
msg.write(b'1' * (64 * 1024 * 2), drain=False)
msg.drain = make_mocked_coro()
await msg.write(b'1' * (64 * 1024 * 2), drain=False)
assert not msg.drain.called

msg.write(b'1', drain=True)
await msg.write(b'1', drain=True)
assert msg.drain.called
assert msg.buffer_size == 0


def test_write_to_closing_transport(protocol, transport, loop):
async def test_write_to_closing_transport(protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)

msg.write(b'Before closing')
await msg.write(b'Before closing')
transport.is_closing.return_value = True

with pytest.raises(asyncio.CancelledError):
msg.write(b'After closing')
await msg.write(b'After closing')


async def test_drain(protocol, transport, loop):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ async def expect_handler(request):
nonlocal expect_received
expect_received = True
if request.version == HttpVersion11:
request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")

form = FormData()
form.add_field('name', b'123',
Expand All @@ -487,7 +487,7 @@ async def expect_handler(request):
if auth_err:
raise web.HTTPForbidden()

request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")

form = FormData()
form.add_field('name', b'123',
Expand Down Expand Up @@ -737,11 +737,11 @@ async def test_response_with_streamer(aiohttp_client, fname):
data_size = len(data)

@aiohttp.streamer
def stream(writer, f_name):
async def stream(writer, f_name):
with f_name.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
await writer.write(data)
data = f.read(100)

async def handler(request):
Expand All @@ -767,11 +767,11 @@ async def test_response_with_streamer_no_params(aiohttp_client, fname):
data_size = len(data)

@aiohttp.streamer
def stream(writer):
async def stream(writer):
with fname.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
await writer.write(data)
data = f.read(100)

async def handler(request):
Expand Down

0 comments on commit fbfa53c

Please sign in to comment.