Skip to content

Commit

Permalink
fix: re-raise RuntimeError when uvloop raises RuntimeError during con…
Browse files Browse the repository at this point in the history
…nect (#105)
  • Loading branch information
bdraco authored Sep 30, 2024
1 parent b075f25 commit c8f1fa9
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
import socket
import sys
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

from . import staggered
from .types import AddrInfoType
Expand Down Expand Up @@ -73,7 +73,8 @@ async def start_connection(
addr_infos = _interleave_addrinfos(addr_infos, interleave)

sock: Optional[socket.socket] = None
exceptions: List[List[OSError]] = []
# uvloop can raise RuntimeError instead of OSError
exceptions: List[List[Union[OSError, RuntimeError]]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
Expand All @@ -82,7 +83,7 @@ async def start_connection(
current_loop, exceptions, addrinfo, local_addr_infos
)
break
except OSError:
except (RuntimeError, OSError):
continue
else: # using happy eyeballs
sock, _, _ = await staggered.staggered_race(
Expand Down Expand Up @@ -113,12 +114,20 @@ async def start_connection(
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
if isinstance(first_exception, OSError):
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
elif isinstance(first_exception, RuntimeError) and all(
isinstance(exc, RuntimeError) for exc in all_exceptions
):
raise OSError(first_errno, msg)
raise RuntimeError(msg)
# We have a mix of OSError and RuntimeError
# so we have to pick which one to raise.
# and we raise OSError for compatibility
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
Expand All @@ -129,12 +138,12 @@ async def start_connection(

async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[OSError]],
exceptions: List[List[Union[OSError, RuntimeError]]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: list[OSError] = []
my_exceptions: List[Union[OSError, RuntimeError]] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
Expand Down Expand Up @@ -164,7 +173,7 @@ async def _connect_sock(
raise OSError(f"no matching local address with {family=} found")
await loop.sock_connect(sock, address)
return sock
except OSError as exc:
except (RuntimeError, OSError) as exc:
my_exceptions.append(exc)
if sock is not None:
sock.close()
Expand Down
283 changes: 283 additions & 0 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,289 @@ async def _sock_connect(
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_runtime_error(
m_socket: ModuleType,
) -> None:
"""
Test RuntimeError is handled when connecting a socket with uvloop.
Connecting a socket can raise a RuntimeError, OSError or ValueError.
- OSError: If the address is invalid or the connection fails.
- ValueError: if a non-sock it passed (this should never happen).
https://github.com/python/cpython/blob/e44eebfc1eccdaaebc219accbfc705c9a9de068d/Lib/asyncio/selector_events.py#L271
- RuntimeError: If the file descriptor is already in use by a transport.
We should never get ValueError since we are using the correct types.
selector_events.py never seems to raise a RuntimeError, but it is possible
with uvloop. This test is to ensure that we handle it correctly.
"""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise RuntimeError("all fail")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
RuntimeError, match="all fail"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_different_runtime_error(
m_socket: ModuleType,
) -> None:
"""Test different RuntimeErrors are handled when connecting a socket with uvloop."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []
counter = 0

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
nonlocal counter
counter += 1
raise RuntimeError(counter)

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
RuntimeError, match="Multiple exceptions: 1, 2, 3"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_uvloop_mixing_os_and_runtime_error(
m_socket: ModuleType,
) -> None:
"""Test uvloop raising OSError and RuntimeError."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []
counter = 0

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
nonlocal counter
counter += 1
if counter == 1:
raise RuntimeError(counter)
raise OSError(counter, f"all fail {counter}")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="Multiple exceptions: 1"
):
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
@pytest.mark.xfail(reason="raises RuntimeError: coroutine ignored GeneratorExit")
Expand Down

0 comments on commit c8f1fa9

Please sign in to comment.