Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #512: force sendfile fallback when using SSL. #513

Merged
merged 2 commits into from
Sep 19, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,16 @@ def _sendfile_system(self, req, resp, fobj, count):

`count` should be an integer > 0.
"""
transport = req.transport

if transport.get_extra_info("sslcontext"):
yield from self._sendfile_fallback(req, resp, fobj, count)
return

yield from resp.drain()

loop = req.app.loop
out_fd = req.transport.get_extra_info("socket").fileno()
out_fd = transport.get_extra_info("socket").fileno()
in_fd = fobj.fileno()
fut = asyncio.Future(loop=loop)

Expand Down
59 changes: 53 additions & 6 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
import os.path
import socket
import unittest
from aiohttp import log, web, request, FormData, ClientSession
from aiohttp import log, web, request, FormData, ClientSession, TCPConnector
from aiohttp.multidict import MultiDict
from aiohttp.protocol import HttpVersion, HttpVersion10, HttpVersion11
from aiohttp.streams import EOF_MARKER

try:
import ssl
except:
ssl = False


class WebFunctionalSetupMixin:

Expand All @@ -34,7 +39,7 @@ def find_unused_port(self):
return port

@asyncio.coroutine
def create_server(self, method, path, handler=None):
def create_server(self, method, path, handler=None, ssl_ctx=None):
app = web.Application(loop=self.loop)
if handler:
app.router.add_route(method, path, handler)
Expand All @@ -44,8 +49,9 @@ def create_server(self, method, path, handler=None):
debug=True, keep_alive_on=False,
access_log=log.access_logger)
srv = yield from self.loop.create_server(
self.handler, '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port) + path
self.handler, '127.0.0.1', port, ssl=ssl_ctx)
protocol = "https" if ssl_ctx else "http"
url = "{}://127.0.0.1:{}".format(protocol, port) + path
self.addCleanup(srv.close)
return app, srv, url

Expand Down Expand Up @@ -732,8 +738,10 @@ def go():
class StaticFileMixin(WebFunctionalSetupMixin):

@asyncio.coroutine
def create_server(self, method, path):
app, srv, url = yield from super().create_server(method, path)
def create_server(self, method, path, ssl_ctx=None):
app, srv, url = yield from super().create_server(
method, path, ssl_ctx=ssl_ctx
)
app.router.add_static = self.patch_sendfile(app.router.add_static)

return app, srv, url
Expand Down Expand Up @@ -768,6 +776,45 @@ def go(dirname, filename):
filename = 'data.unknown_mime_type'
self.loop.run_until_complete(go(here, filename))

@unittest.skipUnless(ssl, "ssl not supported")
def test_static_file_ssl(self):

@asyncio.coroutine
def go(dirname, filename):
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ssl_ctx.load_cert_chain(
os.path.join(dirname, 'sample.crt'),
os.path.join(dirname, 'sample.key')
)
app, _, url = yield from self.create_server(
'GET', '/static/' + filename, ssl_ctx=ssl_ctx
)
app.router.add_static('/static', dirname)

conn = TCPConnector(verify_ssl=False, loop=self.loop)
session = ClientSession(connector=conn)

resp = yield from session.request('GET', url)
self.assertEqual(200, resp.status)
txt = yield from resp.text()
self.assertEqual('file content', txt.rstrip())
ct = resp.headers['CONTENT-TYPE']
self.assertEqual('application/octet-stream', ct)
self.assertEqual(resp.headers.get('CONTENT-ENCODING'), None)
resp.close()

resp = yield from session.request('GET', url + 'fake')
self.assertEqual(404, resp.status)
resp.close()

resp = yield from session.request('GET', url + '/../../')
self.assertEqual(404, resp.status)
resp.close()

here = os.path.dirname(__file__)
filename = 'data.unknown_mime_type'
self.loop.run_until_complete(go(here, filename))

def test_static_file_with_content_type(self):

@asyncio.coroutine
Expand Down