Skip to content
This repository has been archived by the owner on Apr 14, 2022. It is now read-only.

Commit

Permalink
Merge pull request #47 from pquentin/bleach-spike-defer-backend-import
Browse files Browse the repository at this point in the history
Ensure `import urllib3` can succeed even if trio and twisted aren't installed
  • Loading branch information
pquentin authored Jun 28, 2018
2 parents 9d6137a + b1df307 commit 365f649
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 26 deletions.
7 changes: 3 additions & 4 deletions demo/async-demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# This should work on python 3.6+

import urllib3
# TODO: less janky way of specifying backends
from urllib3._backends import TrioBackend, TwistedBackend
from urllib3.backends import Backend

URL = "http://httpbin.org/uuid"

Expand All @@ -15,11 +14,11 @@ async def main(backend):

print("--- urllib3 using Trio ---")
import trio
trio.run(main, TrioBackend())
trio.run(main, "trio")

print("\n--- urllib3 using Twisted ---")
from twisted.internet.task import react
from twisted.internet.defer import ensureDeferred
def twisted_main(reactor):
return ensureDeferred(main(TwistedBackend(reactor)))
return ensureDeferred(main(Backend("twisted", reactor=reactor)))
react(twisted_main)
58 changes: 58 additions & 0 deletions test/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest

import urllib3
from urllib3.backends import Backend
from urllib3._backends._loader import normalize_backend, load_backend


requires_async_pool_manager = pytest.mark.skipif(
not hasattr(urllib3, "AsyncPoolManager"),
reason="async backends require AsyncPoolManager",
)


class TestNormalizeBackend(object):
"""
Assert that we fail correctly if we attempt to use an unknown or incompatible backend.
"""
def test_unknown(self):
with pytest.raises(ValueError) as excinfo:
normalize_backend("_unknown", async_mode=False)

assert 'unknown backend specifier _unknown' == str(excinfo.value)

def test_sync(self):
assert normalize_backend(Backend("sync"), async_mode=False) == Backend("sync")
assert normalize_backend("sync", async_mode=False) == Backend("sync")
assert normalize_backend(None, async_mode=False) == Backend("sync")

with pytest.raises(ValueError) as excinfo:
normalize_backend(Backend("trio"), async_mode=False)
assert ('trio backend needs to be run in async mode' == str(excinfo.value))

@requires_async_pool_manager
def test_async(self):
assert normalize_backend(Backend("trio"), async_mode=True) == Backend("trio")
assert normalize_backend("twisted", async_mode=True) == Backend("twisted")

with pytest.raises(ValueError) as excinfo:
normalize_backend(Backend("sync"), async_mode=True)
assert ('sync backend needs to be run in sync mode' == str(excinfo.value))

from twisted.internet import reactor
assert (
normalize_backend(Backend("twisted", reactor=reactor), async_mode=True)
== Backend("twisted", reactor=reactor))


class TestLoadBackend(object):
"""
Assert that we can load a normalized backend
"""
def test_sync(self):
load_backend(normalize_backend("sync", async_mode=False))

@requires_async_pool_manager()
def test_async(self):
from twisted.internet import reactor
load_backend(Backend("twisted", reactor=reactor))
3 changes: 0 additions & 3 deletions test/test_no_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* HTTPS requests must fail with an error that points at the ssl module
"""

import pytest

import sys
if sys.version_info >= (2, 7):
import unittest
Expand Down Expand Up @@ -90,6 +88,5 @@ def import_ssl():

self.assertRaises(ImportError, import_ssl)

@pytest.mark.xfail
def test_import_urllib3(self):
import urllib3 # noqa: F401
2 changes: 0 additions & 2 deletions test/with_dummyserver/test_chunked_transfer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-

import pytest

from urllib3 import HTTPConnectionPool
from urllib3.exceptions import InvalidBodyError
from urllib3.packages import six
Expand Down
7 changes: 7 additions & 0 deletions test/with_dummyserver/test_no_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
from dummyserver.testcase import (
HTTPDummyServerTestCase, HTTPSDummyServerTestCase)

import pytest
import urllib3


class TestHTTPWithoutSSL(HTTPDummyServerTestCase, TestWithoutSSL):

@pytest.mark.skip(reason=(
"TestWithoutSSL mutates sys.modules."
"This breaks the backend loading code which imports modules at runtime."
"See discussion at https://github.com/python-trio/urllib3/pull/42"
))
def test_simple(self):
pool = urllib3.HTTPConnectionPool(self.host, self.port)
self.addCleanup(pool.close)
Expand Down
28 changes: 23 additions & 5 deletions urllib3/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,32 @@
)
from urllib3.packages import six
from ..util import ssl_ as ssl_util
from .._backends import SyncBackend
from .._backends._common import LoopAbort
from .._backends._loader import load_backend, normalize_backend

try:
import ssl
except ImportError:
ssl = None


def is_async_mode():
"""Tests if we're in the async part of the code or if we're bleached"""
async def f():
"""bleaching transforms async functions in sync functions"""
return None

obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True


_ASYNC_MODE = is_async_mode()


# When updating RECENT_DATE, move it to within two years of the current date,
# and not less than 6 months ago.
# Example: if Today is 2018-01-01, then RECENT_DATE should be any date on or
Expand Down Expand Up @@ -143,8 +160,10 @@ def all_pieces_iter():

yield state_machine.send(h11.EndOfMessage())

# Try to combine the header bytes + (first set of body bytes or end of message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should never raise StopIteration.
# Try to combine the header bytes + (first set of body bytes or end of
# message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should
# never raise StopIteration.
remaining_pieces = all_pieces_iter()
first_packet_bytes = next(remaining_pieces) + next(remaining_pieces)
all_pieces_combined_iter = itertools.chain([first_packet_bytes], remaining_pieces)
Expand Down Expand Up @@ -311,8 +330,7 @@ def __init__(self, host, port, backend=None,
source_address=None, tunnel_host=None, tunnel_port=None,
tunnel_headers=None):
self.is_verified = False

self._backend = backend or SyncBackend()
self._backend = load_backend(normalize_backend(backend, _ASYNC_MODE))
self._host = host
self._port = port
self._socket_options = (
Expand Down
9 changes: 0 additions & 9 deletions urllib3/_backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from urllib3.packages import six
from .sync_backend import SyncBackend

__all__ = ['SyncBackend']

if six.PY3:
from .trio_backend import TrioBackend
from .twisted_backend import TwistedBackend
__all__ += ['TrioBackend', 'TwistedBackend']
82 changes: 82 additions & 0 deletions urllib3/_backends/_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from ..backends import Backend


class Loader:

def __init__(self, name, loader, is_async):
self.name = name
self.loader = loader
self.is_async = is_async

def __call__(self, *args, **kwargs):
return self.loader(kwargs)


def load_sync_backend(kwargs):
from .sync_backend import SyncBackend
return SyncBackend(**kwargs)


def load_trio_backend(kwargs):
from .trio_backend import TrioBackend
return TrioBackend(**kwargs)


def load_twisted_backend(kwargs):
from .twisted_backend import TwistedBackend
return TwistedBackend(**kwargs)


def backend_directory():
"""
We defer any heavy duty imports until the last minute.
"""
loaders = [
Loader(
name="sync",
loader=load_sync_backend,
is_async=False,
),
Loader(
name="trio",
loader=load_trio_backend,
is_async=True,
),
Loader(
name="twisted",
loader=load_twisted_backend,
is_async=True,
),
]
return {
loader.name: loader for loader in loaders
}


def normalize_backend(backend, async_mode):
if backend is None:
backend = Backend(name="sync") # sync backend is the default
elif not isinstance(backend, Backend):
backend = Backend(name=backend)

loaders_by_name = backend_directory()
if backend.name not in loaders_by_name:
raise ValueError("unknown backend specifier {}".format(backend.name))

loader = loaders_by_name[backend.name]

if async_mode and not loader.is_async:
raise ValueError("{} backend needs to be run in sync mode".format(
loader.name))

if not async_mode and loader.is_async:
raise ValueError("{} backend needs to be run in async mode".format(
loader.name))

return backend


def load_backend(backend):
loaders_by_name = backend_directory()
loader = loaders_by_name[backend.name]
return loader(backend.kwargs)
6 changes: 3 additions & 3 deletions urllib3/_backends/twisted_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import socket
import OpenSSL.crypto
from twisted.internet import protocol, ssl
from twisted.internet import protocol, reactor as default_reactor, ssl
from twisted.internet.interfaces import IHandshakeListener
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
from twisted.internet.defer import (
Expand All @@ -15,8 +15,8 @@


class TwistedBackend:
def __init__(self, reactor):
self._reactor = reactor
def __init__(self, reactor=None):
self._reactor = reactor or default_reactor

async def connect(self, host, port, connect_timeout,
source_address=None, socket_options=None):
Expand Down
12 changes: 12 additions & 0 deletions urllib3/backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Backend:
"""
Specifies the desired backend and any arguments passed to its constructor.
Projects that use urllib3 can subclass this interface to expose it to users.
"""
def __init__(self, name, **kwargs):
self.name = name
self.kwargs = kwargs

def __eq__(self, other):
return self.name == other.name and self.kwargs == other.kwargs

0 comments on commit 365f649

Please sign in to comment.