Skip to content

Commit

Permalink
Implement AsyncZyteAPI.session
Browse files Browse the repository at this point in the history
  • Loading branch information
Gallaecio committed Mar 18, 2024
1 parent 01a5a2b commit 467dc6e
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 26 deletions.
6 changes: 3 additions & 3 deletions docs/asyncio_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ parallel, using multiple connections:
import json
import sys
from zyte_api import AsyncZyteAPI, create_session
from zyte_api import AsyncZyteAPI
from zyte_api.aio.errors import RequestError
async def extract_from(urls, n_conn):
Expand All @@ -40,8 +40,8 @@ parallel, using multiple connections:
{"url": url, "browserHtml": True}
for url in urls
]
async with create_session(n_conn) as session:
res_iter = client.iter(requests, session=session)
async with client.session() as session:
res_iter = session.iter(requests)
for fut in res_iter:
try:
res = await fut
Expand Down
61 changes: 61 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@

from zyte_api import AsyncZyteAPI
from zyte_api.apikey import NoApiKey
from zyte_api.utils import USER_AGENT

import pytest


@pytest.mark.parametrize(
"user_agent,expected",
(
(
None,
USER_AGENT,
),
(
f'scrapy-zyte-api/0.11.1 {USER_AGENT}',
f'scrapy-zyte-api/0.11.1 {USER_AGENT}',
),
),
)
def test_user_agent(user_agent, expected):
client = AsyncZyteAPI(api_key='123', api_url='http:\\test', user_agent=user_agent)
assert client.user_agent == expected


def test_api_key():
AsyncZyteAPI(api_key="a")
with pytest.raises(NoApiKey):
Expand Down Expand Up @@ -46,3 +65,45 @@ async def test_iter(mockserver):
assert Exception in expected_results
else:
assert actual_result in expected_results


@pytest.mark.asyncio
async def test_session(mockserver):
client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/"))
queries = [
{"url": "https://a.example", "httpResponseBody": True},
{"url": "https://exception.example", "httpResponseBody": True},
{"url": "https://b.example", "httpResponseBody": True},
]
expected_results = [
{"url": "https://a.example", "httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg=="},
Exception,
{"url": "https://b.example", "httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg=="},
]
actual_results = []
async with client.session() as session:
assert session._context.connector.limit == client.n_conn
actual_results.append(await session.get(queries[0]))
for future in session.iter(queries[1:]):
try:
result = await future
except Exception as e:
result = e
actual_results.append(result)
aiohttp_session = session._context
assert not aiohttp_session.closed
assert aiohttp_session.closed
assert session._context is None

with pytest.raises(RuntimeError):
await session.get(queries[0])

with pytest.raises(RuntimeError):
session.iter(queries[1:])

assert len(actual_results) == len(expected_results)
for actual_result in actual_results:
if isinstance(actual_result, Exception):
assert Exception in expected_results
else:
assert actual_result in expected_results
22 changes: 0 additions & 22 deletions tests/test_client.py

This file was deleted.

2 changes: 1 addition & 1 deletion zyte_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

from ._async import AsyncZyteAPI
from ._utils import create_session
from ._utils import deprecated_create_session as create_session
59 changes: 59 additions & 0 deletions zyte_api/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,62 @@ def _post_func(session):
else:
return session.post

class _AsyncSession:
def __init__(self, client, **session_kwargs):
self._client = client
self._session = create_session(client.n_conn, **session_kwargs)
self._context = None

async def __aenter__(self):
self._context = await self._session.__aenter__()
return self

async def __aexit__(self, *exc_info):
result = await self._context.__aexit__(*exc_info)
self._context = None
return result

def _check_context(self):
if self._context is None:
raise RuntimeError(
"Attempt to use session method on a session either not opened "
"or already closed."
)

async def get(
self,
query: dict,
*,
endpoint: str = 'extract',
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
):
self._check_context()
return await self._client.get(
query=query,
endpoint=endpoint,
handle_retries=handle_retries,
retrying=retrying,
session=self._context,
)

def iter(
self,
queries: List[dict],
*,
endpoint: str = 'extract',
handle_retries=True,
retrying: Optional[AsyncRetrying] = None,
) -> Iterator[asyncio.Future]:
self._check_context()
return self._client.iter(
queries=queries,
endpoint=endpoint,
session=self._context,
handle_retries=handle_retries,
retrying=retrying,
)


class AsyncZyteAPI:
def __init__(
Expand Down Expand Up @@ -156,3 +212,6 @@ async def _request(query):
)

return asyncio.as_completed([_request(query) for query in queries])

def session(self, **kwargs):
return _AsyncSession(client=self, **kwargs)
6 changes: 6 additions & 0 deletions zyte_api/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import aiohttp
from aiohttp import TCPConnector
from warnings import warn

from .constants import API_TIMEOUT

Expand All @@ -10,6 +11,11 @@
_AIO_API_TIMEOUT = aiohttp.ClientTimeout(total=API_TIMEOUT + 120)


def deprecated_create_session(connection_pool_size=100, **kwargs) -> aiohttp.ClientSession:
warn("zyte_api.create_session is deprecated, use AsyncZyteAPI.session instead.", DeprecationWarning)
return create_session(connection_pool_size=connection_pool_size, **kwargs)


def create_session(connection_pool_size=100, **kwargs) -> aiohttp.ClientSession:
""" Create a session with parameters suited for Zyte API """
kwargs.setdefault('timeout', _AIO_API_TIMEOUT)
Expand Down

0 comments on commit 467dc6e

Please sign in to comment.