diff --git a/docs/asyncio_api.rst b/docs/asyncio_api.rst index 2c6e568..ce69e6e 100644 --- a/docs/asyncio_api.rst +++ b/docs/asyncio_api.rst @@ -31,7 +31,7 @@ parallel, using multiple connections: import json import sys - from zyte_api import AsyncZyteAPI, create_session + from zyte_api import AsyncZyteAPI from zyte_api.aio.errors import RequestError async def extract_from(urls, n_conn): @@ -40,8 +40,8 @@ parallel, using multiple connections: {"url": url, "browserHtml": True} for url in urls ] - async with create_session(n_conn) as session: - res_iter = client.iter(requests, session=session) + async with client.session() as session: + res_iter = session.iter(requests) for fut in res_iter: try: res = await fut diff --git a/tests/test_async.py b/tests/test_async.py index 0d15df0..90bad9e 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -2,10 +2,29 @@ from zyte_api import AsyncZyteAPI from zyte_api.apikey import NoApiKey +from zyte_api.utils import USER_AGENT import pytest +@pytest.mark.parametrize( + "user_agent,expected", + ( + ( + None, + USER_AGENT, + ), + ( + f'scrapy-zyte-api/0.11.1 {USER_AGENT}', + f'scrapy-zyte-api/0.11.1 {USER_AGENT}', + ), + ), +) +def test_user_agent(user_agent, expected): + client = AsyncZyteAPI(api_key='123', api_url='http:\\test', user_agent=user_agent) + assert client.user_agent == expected + + def test_api_key(): AsyncZyteAPI(api_key="a") with pytest.raises(NoApiKey): @@ -46,3 +65,45 @@ async def test_iter(mockserver): assert Exception in expected_results else: assert actual_result in expected_results + + +@pytest.mark.asyncio +async def test_session(mockserver): + client = AsyncZyteAPI(api_key="a", api_url=mockserver.urljoin("/")) + queries = [ + {"url": "https://a.example", "httpResponseBody": True}, + {"url": "https://exception.example", "httpResponseBody": True}, + {"url": "https://b.example", "httpResponseBody": True}, + ] + expected_results = [ + {"url": "https://a.example", "httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg=="}, + Exception, + {"url": "https://b.example", "httpResponseBody": "PGh0bWw+PGJvZHk+SGVsbG88aDE+V29ybGQhPC9oMT48L2JvZHk+PC9odG1sPg=="}, + ] + actual_results = [] + async with client.session() as session: + assert session._context.connector.limit == client.n_conn + actual_results.append(await session.get(queries[0])) + for future in session.iter(queries[1:]): + try: + result = await future + except Exception as e: + result = e + actual_results.append(result) + aiohttp_session = session._context + assert not aiohttp_session.closed + assert aiohttp_session.closed + assert session._context is None + + with pytest.raises(RuntimeError): + await session.get(queries[0]) + + with pytest.raises(RuntimeError): + session.iter(queries[1:]) + + assert len(actual_results) == len(expected_results) + for actual_result in actual_results: + if isinstance(actual_result, Exception): + assert Exception in expected_results + else: + assert actual_result in expected_results diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index cd4c214..0000000 --- a/tests/test_client.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from zyte_api.aio.client import AsyncClient -from zyte_api.utils import USER_AGENT - - -@pytest.mark.parametrize( - "user_agent,expected", - ( - ( - None, - USER_AGENT, - ), - ( - f'scrapy-zyte-api/0.11.1 {USER_AGENT}', - f'scrapy-zyte-api/0.11.1 {USER_AGENT}', - ), - ), -) -def test_user_agent(user_agent, expected): - client = AsyncClient(api_key='123', api_url='http:\\test', user_agent=user_agent) - assert client.user_agent == expected diff --git a/zyte_api/__init__.py b/zyte_api/__init__.py index ed9f58f..0ea15f0 100644 --- a/zyte_api/__init__.py +++ b/zyte_api/__init__.py @@ -3,4 +3,4 @@ """ from ._async import AsyncZyteAPI -from ._utils import create_session +from ._utils import deprecated_create_session as create_session diff --git a/zyte_api/_async.py b/zyte_api/_async.py index 8156bde..77997ab 100644 --- a/zyte_api/_async.py +++ b/zyte_api/_async.py @@ -30,6 +30,62 @@ def _post_func(session): else: return session.post +class _AsyncSession: + def __init__(self, client, **session_kwargs): + self._client = client + self._session = create_session(client.n_conn, **session_kwargs) + self._context = None + + async def __aenter__(self): + self._context = await self._session.__aenter__() + return self + + async def __aexit__(self, *exc_info): + result = await self._context.__aexit__(*exc_info) + self._context = None + return result + + def _check_context(self): + if self._context is None: + raise RuntimeError( + "Attempt to use session method on a session either not opened " + "or already closed." + ) + + async def get( + self, + query: dict, + *, + endpoint: str = 'extract', + handle_retries=True, + retrying: Optional[AsyncRetrying] = None, + ): + self._check_context() + return await self._client.get( + query=query, + endpoint=endpoint, + handle_retries=handle_retries, + retrying=retrying, + session=self._context, + ) + + def iter( + self, + queries: List[dict], + *, + endpoint: str = 'extract', + handle_retries=True, + retrying: Optional[AsyncRetrying] = None, + ) -> Iterator[asyncio.Future]: + self._check_context() + return self._client.iter( + queries=queries, + endpoint=endpoint, + session=self._context, + handle_retries=handle_retries, + retrying=retrying, + ) + class AsyncZyteAPI: def __init__( @@ -156,3 +212,6 @@ async def _request(query): ) return asyncio.as_completed([_request(query) for query in queries]) + + def session(self, **kwargs): + return _AsyncSession(client=self, **kwargs) diff --git a/zyte_api/_utils.py b/zyte_api/_utils.py index e547fb5..86d1385 100644 --- a/zyte_api/_utils.py +++ b/zyte_api/_utils.py @@ -1,5 +1,6 @@ import aiohttp from aiohttp import TCPConnector +from warnings import warn from .constants import API_TIMEOUT @@ -10,6 +11,11 @@ _AIO_API_TIMEOUT = aiohttp.ClientTimeout(total=API_TIMEOUT + 120) +def deprecated_create_session(connection_pool_size=100, **kwargs) -> aiohttp.ClientSession: + warn("zyte_api.create_session is deprecated, use AsyncZyteAPI.session instead.", DeprecationWarning) + return create_session(connection_pool_size=connection_pool_size, **kwargs) + + def create_session(connection_pool_size=100, **kwargs) -> aiohttp.ClientSession: """ Create a session with parameters suited for Zyte API """ kwargs.setdefault('timeout', _AIO_API_TIMEOUT)