Skip to content

Commit

Permalink
Use a client semaphore (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gallaecio authored Mar 20, 2024
1 parent e14d556 commit 55766c7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 48 deletions.
8 changes: 4 additions & 4 deletions docs/asyncio_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ When using ``iter`` or multiple ``get`` calls, consider using a session:
Sessions improve performance through a pool of reusable connections to the Zyte
API server.

To send many queries with a concurrency limit, set ``n_conn`` in your client:
To send many queries with a concurrency limit, set ``n_conn`` in your client
(default is ``15``):

.. code-block:: python
client = AsyncZyteAPI(n_conn=15)
client = AsyncZyteAPI(n_conn=30)
Then use ``iter`` to send your queries. ``n_conn`` is not enforced when using
``get`` instead of ``iter``.
``n_conn`` will be enforced across all your ``get`` and ``iter`` calls.
23 changes: 23 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
from unittest.mock import AsyncMock

import pytest

from zyte_api import AsyncZyteAPI
Expand Down Expand Up @@ -55,3 +58,23 @@ async def test_iter(mockserver):
assert Exception in expected_results
else:
assert actual_result in expected_results


@pytest.mark.asyncio
async def test_semaphore(mockserver):
    """Check that every request goes through the client-wide semaphore.

    Two ``get`` calls and one ``iter`` call are issued on the same client;
    the wrapped semaphore must be entered and exited once per query.
    """
    api = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
    # Wrap the real semaphore so its use as a context manager is recorded.
    api._semaphore = AsyncMock(wraps=api._semaphore)

    first_query = {"url": "https://a.example", "httpResponseBody": True}
    second_query = {"url": "https://b.example", "httpResponseBody": True}
    third_query = {"url": "https://c.example", "httpResponseBody": True}
    queries = [first_query, second_query, third_query]

    # Mix both entry points: the middle request is sent through ``iter``.
    via_get_first = api.get(first_query)
    via_iter = next(iter(api.iter(queries[1:2])))
    via_get_last = api.get(third_query)
    pending = [via_get_first, via_iter, via_get_last]

    for awaitable in asyncio.as_completed(pending):
        await awaitable

    expected_calls = len(queries)
    assert api._semaphore.__aenter__.call_count == expected_calls
    assert api._semaphore.__aexit__.call_count == expected_calls
62 changes: 31 additions & 31 deletions zyte_api/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
self.agg_stats = AggStats()
self.retrying = retrying or zyte_api_retrying
self.user_agent = user_agent or USER_AGENT
self._semaphore = asyncio.Semaphore(n_conn)

async def get(
self,
Expand Down Expand Up @@ -74,26 +75,27 @@ async def request():
)

try:
async with post(**post_kwargs) as resp:
stats.record_connected(resp.status, self.agg_stats)
if resp.status >= 400:
content = await resp.read()
resp.release()
stats.record_read()
stats.record_request_error(content, self.agg_stats)

raise RequestError(
request_info=resp.request_info,
history=resp.history,
status=resp.status,
message=resp.reason,
headers=resp.headers,
response_content=content,
)

response = await resp.json()
stats.record_read(self.agg_stats)
return response
async with self._semaphore:
async with post(**post_kwargs) as resp:
stats.record_connected(resp.status, self.agg_stats)
if resp.status >= 400:
content = await resp.read()
resp.release()
stats.record_read()
stats.record_request_error(content, self.agg_stats)

raise RequestError(
request_info=resp.request_info,
history=resp.history,
status=resp.status,
message=resp.reason,
headers=resp.headers,
response_content=content,
)

response = await resp.json()
stats.record_read(self.agg_stats)
return response
except Exception as e:
if not isinstance(e, RequestError):
self.agg_stats.n_errors += 1
Expand Down Expand Up @@ -137,16 +139,14 @@ def iter(
Set the session TCPConnector limit to a value greater than
the number of connections.
"""
sem = asyncio.Semaphore(self.n_conn)

async def _request(query):
async with sem:
return await self.get(
query,
endpoint=endpoint,
session=session,
handle_retries=handle_retries,
retrying=retrying,
)

def _request(query):
return self.get(
query,
endpoint=endpoint,
session=session,
handle_retries=handle_retries,
retrying=retrying,
)

return asyncio.as_completed([_request(query) for query in queries])
31 changes: 18 additions & 13 deletions zyte_api/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@
from .constants import API_URL


def _get_loop():
    """Return a usable asyncio event loop for the current thread.

    The loop currently set for the thread is reused when there is one and it
    is still open. Otherwise a new loop is created, installed as the current
    loop, and returned.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current event loop is available in this thread.
        loop = None
    if loop is None or loop.is_closed():
        # A closed loop cannot run anything (``run_until_complete`` raises
        # RuntimeError), so it is replaced just like a missing one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


class ZyteAPI:
"""Synchronous Zyte API client.
Expand Down Expand Up @@ -58,15 +67,15 @@ def get(
result = client.get({"url": "https://toscrape.com", "httpResponseBody": True})
"""
return asyncio.run(
self._async_client.get(
query=query,
endpoint=endpoint,
session=session,
handle_retries=handle_retries,
retrying=retrying,
)
loop = _get_loop()
future = self._async_client.get(
query=query,
endpoint=endpoint,
session=session,
handle_retries=handle_retries,
retrying=retrying,
)
return loop.run_until_complete(future)

def iter(
self,
Expand Down Expand Up @@ -97,11 +106,7 @@ def iter(
When exceptions occur, they are also yielded, not raised.
"""
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = _get_loop()
for future in self._async_client.iter(
queries=queries,
endpoint=endpoint,
Expand Down

0 comments on commit 55766c7

Please sign in to comment.