Skip to content

Commit

Permalink
Fix client connection header not reflecting connector force_close v…
Browse files Browse the repository at this point in the history
…alue (#10003)

(cherry picked from commit 78d1be5)
  • Loading branch information
bdraco committed Nov 20, 2024
1 parent e4bd744 commit 0743be3
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGES/10003.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`.
29 changes: 7 additions & 22 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,15 +634,6 @@ def update_proxy(
proxy_headers = CIMultiDict(proxy_headers)
self.proxy_headers = proxy_headers

def keep_alive(self) -> bool:
if self.version >= HttpVersion11:
return self.headers.get(hdrs.CONNECTION) != "close"
if self.version == HttpVersion10:
# no headers means we close for Http 1.0
return self.headers.get(hdrs.CONNECTION) == "keep-alive"
# keep alive not supported at all
return False

async def write_bytes(
self, writer: AbstractStreamWriter, conn: "Connection"
) -> None:
Expand Down Expand Up @@ -737,21 +728,15 @@ async def send(self, conn: "Connection") -> "ClientResponse":
):
self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

# set the connection header
connection = self.headers.get(hdrs.CONNECTION)
if not connection:
if self.keep_alive():
if self.version == HttpVersion10:
connection = "keep-alive"
else:
if self.version == HttpVersion11:
connection = "close"

if connection is not None:
self.headers[hdrs.CONNECTION] = connection
v = self.version
if hdrs.CONNECTION not in self.headers:
if conn._connector.force_close:
if v == HttpVersion11:
self.headers[hdrs.CONNECTION] = "close"
elif v == HttpVersion10:
self.headers[hdrs.CONNECTION] = "keep-alive"

# status + headers
v = self.version
status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
await writer.write_headers(status_line, self.headers)
task: Optional["asyncio.Task[None]"]
Expand Down
6 changes: 6 additions & 0 deletions tests/test_benchmarks_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,16 @@ async def _drain_helper(self) -> None:
def start_timeout(self) -> None:
"""Swallow start_timeout."""

class MockConnector:

def __init__(self) -> None:
self.force_close = False

class MockConnection:
def __init__(self) -> None:
self.transport = None
self.protocol = MockProtocol()
self._connector = MockConnector()

conn = MockConnection()

Expand Down
58 changes: 21 additions & 37 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_gen_default_accept_encoding,
_merge_ssl_params,
)
from aiohttp.http import HttpVersion
from aiohttp.http import HttpVersion10, HttpVersion11
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -141,30 +141,6 @@ def test_version_err(make_request) -> None:
make_request("get", "http://python.org/", version="1.c")


def test_keep_alive(make_request) -> None:
req = make_request("get", "http://python.org/", version=(0, 9))
assert not req.keep_alive()

req = make_request("get", "http://python.org/", version=(1, 0))
assert not req.keep_alive()

req = make_request(
"get",
"http://python.org/",
version=(1, 0),
headers={"connection": "keep-alive"},
)
assert req.keep_alive()

req = make_request("get", "http://python.org/", version=(1, 1))
assert req.keep_alive()

req = make_request(
"get", "http://python.org/", version=(1, 1), headers={"connection": "close"}
)
assert not req.keep_alive()


def test_host_port_default_http(make_request) -> None:
req = make_request("get", "http://python.org/")
assert req.host == "python.org"
Expand Down Expand Up @@ -628,32 +604,40 @@ def test_gen_netloc_no_port(make_request) -> None:
)


async def test_connection_header(loop, conn) -> None:
async def test_connection_header(
loop: asyncio.AbstractEventLoop, conn: mock.Mock
) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
req.keep_alive = mock.Mock()
req.headers.clear()

req.keep_alive.return_value = True
req.version = HttpVersion(1, 1)
req.version = HttpVersion11
req.headers.clear()
await req.send(conn)
with mock.patch.object(conn._connector, "force_close", False):
await req.send(conn)
assert req.headers.get("CONNECTION") is None

req.version = HttpVersion(1, 0)
req.version = HttpVersion10
req.headers.clear()
await req.send(conn)
with mock.patch.object(conn._connector, "force_close", False):
await req.send(conn)
assert req.headers.get("CONNECTION") == "keep-alive"

req.keep_alive.return_value = False
req.version = HttpVersion(1, 1)
req.version = HttpVersion11
req.headers.clear()
await req.send(conn)
with mock.patch.object(conn._connector, "force_close", True):
await req.send(conn)
assert req.headers.get("CONNECTION") == "close"

await req.close()
req.version = HttpVersion10
req.headers.clear()
with mock.patch.object(conn._connector, "force_close", True):
await req.send(conn)
assert not req.headers.get("CONNECTION")


async def test_no_content_length(loop, conn) -> None:
async def test_no_content_length(
loop: asyncio.AbstractEventLoop, conn: mock.Mock
) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
resp = await req.send(conn)
assert req.headers.get("CONTENT-LENGTH") is None
Expand Down
5 changes: 2 additions & 3 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,8 @@ async def handler(request):
await resp.release()


@pytest.mark.xfail
async def test_http10_keep_alive_default(aiohttp_client) -> None:
async def handler(request):
async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response()

app = web.Application()
Expand Down

0 comments on commit 0743be3

Please sign in to comment.