diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1b76c0cd5a..e9dde1cf79 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,4 +32,4 @@ repos:
     -   id: mypy
         args: [--check-untyped-defs]
         exclude: 'tests/'
-        additional_dependencies: ['charset_normalizer', 'urllib3.future>=2.1.902', 'wassima>=1.0.1', 'idna', 'kiss_headers']
+        additional_dependencies: ['charset_normalizer', 'urllib3.future>=2.2.901', 'wassima>=1.0.1', 'idna', 'kiss_headers']
diff --git a/HISTORY.md b/HISTORY.md
index e7e583f7b6..d9dbde6b74 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,6 +1,141 @@
 Release History
 ===============
 
+3.2.0 (2023-11-05)
+------------------
+
+**Changed**
+- Changed method `raise_for_status` in class `Response` to return **self** in order to make the call chainable.
+  Idea taken from upstream https://github.com/psf/requests/issues/6215
+- Bumped the minimal supported version of `urllib3.future` to 2.2.901 for the recently introduced features (below).
+
+**Added**
+- Support for multiplexed connections in HTTP/2 and HTTP/3. Concurrent requests per connection are now a thing, in synchronous code.
+  This feature is the real advantage of using binary HTTP protocols.
+  It is disabled by default and can be enabled through `Session(multiplexed=True)`; each `Response` object will
+  be 'lazy' loaded. Accessing anything from a returned `Response` will block the code until the target response is retrieved.
+  Use `Session.gather()` to efficiently receive responses. You may also give it the list of responses that you want to load.
+
+  **Example A)** Emitting concurrent requests and loading them via `Session.gather()`
+  ```python
+  from niquests import Session
+  from time import time
+
+  s = Session(multiplexed=True)
+
+  before = time()
+  responses = []
+
+  responses.append(
+    s.get("https://pie.dev/delay/3")
+  )
+
+  responses.append(
+    s.get("https://pie.dev/delay/1")
+  )
+
+  s.gather()
+
+  print(f"waited {time() - before} second(s)")  # will print 3s
+  ```
+
+  **Example B)** Emitting concurrent requests and loading them via direct access
+  ```python
+  from niquests import Session
+  from time import time
+
+  s = Session(multiplexed=True)
+
+  before = time()
+  responses = []
+
+  responses.append(
+    s.get("https://pie.dev/delay/3")
+  )
+
+  responses.append(
+    s.get("https://pie.dev/delay/1")
+  )
+
+  # internally calls gather with self (Response)
+  print(responses[0].status_code)  # 200! Hidden call to s.gather(responses[0])
+  print(responses[1].status_code)  # 200!
+
+  print(f"waited {time() - before} second(s)")  # will print 3s
+  ```
+  You have nothing else to do; everything from streams to connection pooling is handled automagically!
+- Support for in-memory intermediary/client certificates (mTLS), thanks to support within `urllib3.future`.
+  Unfortunately, this feature may not be available depending on your platform.
+  Passing `cert=(a, b, c)` where **a** and/or **b** directly contain the certificate is supported.
+  See https://urllib3future.readthedocs.io/en/latest/advanced-usage.html#in-memory-client-mtls-certificate for more information.
+  It is proposed as a way to circumvent the recent complete removal of pyOpenSSL.
+- Detect if a new (stable) version is available when invoking `python -m niquests.help` and propose it for installation.
+- Added the possibility to disable a specific protocol (e.g. HTTP/2 and/or HTTP/3) when constructing a `Session`,
+  like so: `s = Session(disable_http2=..., disable_http3=...)`. Both options default to `False`, thus leaving the protocols enabled.
+  urllib3.future does not permit disabling HTTP/1.1 for now.
+- Support for passing a single `str` to `auth=...` in addition to the already supported types. It will be treated as a
+  **Bearer** token targeting the `Authorization` header, as a shortcut. You may keep your own token prefix in the given
+  string (e.g. if it is not Bearer).
+- Added the `MultiplexingError` exception for anything related to a failure with a multiplexed connection.
+- Added **async** support through `AsyncSession`, which utilizes an underlying thread pool.
+  ```python
+  from niquests import AsyncSession
+  import asyncio
+  from time import time
+
+  async def emit() -> None:
+      responses = []
+
+      async with AsyncSession(multiplexed=True) as s:
+          responses.append(await s.get("https://pie.dev/get"))
+          responses.append(await s.get("https://pie.dev/head"))
+
+          await s.gather()
+
+      print(responses)
+
+  async def main() -> None:
+      foo = asyncio.create_task(emit())
+      bar = asyncio.create_task(emit())
+      await foo
+      await bar
+
+  if __name__ == "__main__":
+      before = time()
+      asyncio.run(main())
+      print(time() - before)
+  ```
+  Or without `multiplexing` if you would rather keep multiple connections open to the host (one per concurrent request).
+  ```python
+  from niquests import AsyncSession
+  import asyncio
+  from time import time
+
+  async def emit() -> None:
+      responses = []
+
+      async with AsyncSession() as s:
+          responses.append(await s.get("https://pie.dev/get"))
+          responses.append(await s.get("https://pie.dev/head"))
+
+      print(responses)
+
+  async def main() -> None:
+      foo = asyncio.create_task(emit())
+      bar = asyncio.create_task(emit())
+      await foo
+      await bar
+
+  if __name__ == "__main__":
+      before = time()
+      asyncio.run(main())
+      print(time() - before)
+  ```
+  You may disable the use of concurrent threads by setting `AsyncSession.disable_thread = True`.
+
+**Security**
+- Certificate revocation verification may not be triggered for subsequent requests in a specific condition (redirection).
+
 3.1.4 (2023-10-23)
 ------------------
 
diff --git a/README.md b/README.md
index fc6722cb08..1710db2c16 100644
--- a/README.md
+++ b/README.md
@@ -3,9 +3,7 @@
 **Niquests** is a simple, yet elegant, HTTP library. It is a drop-in replacement for **Requests** that is no longer under
 feature freeze.
 
-Why did we pursue this? We don't have to reinvent the wheel all over again, HTTP client **Requests** is well established and
-really plaisant in its usage. We believe that **Requests** have the most inclusive, and developer friendly interfaces. We
-intend to keep it that way.
+Niquests is the “**Safest**, **Fastest**, **Easiest**, and **Most advanced**” Python HTTP Client.
 
 ```python
 >>> import niquests
@@ -28,8 +26,6 @@ True
 Niquests allows you to send HTTP requests extremely easily. There’s no need to manually add query strings to your URLs,
 or to form-encode your `PUT` & `POST` data — but nowadays, just use the `json` method!
 
-Niquests is one of the least downloaded Python packages today, pulling in around `100+ download / week`— according to
-GitHub, Niquests is currently depended upon by `1+` repositories. But, that may change..! Starting with you.
-
 [![Downloads](https://static.pepy.tech/badge/niquests/month)](https://pepy.tech/project/niquests)
 [![Supported Versions](https://img.shields.io/pypi/pyversions/niquests.svg)](https://pypi.org/project/niquests)
 
@@ -41,7 +37,7 @@ Niquests is available on PyPI:
 $ python -m pip install niquests
 ```
 
-Niquests officially supports Python 3.7+.
+Niquests officially supports Python or PyPy 3.7+.
 ## Supported Features & Best–Practices
 
@@ -66,6 +62,14 @@ Niquests is ready for the demands of building robust and reliable HTTP–speakin
 - Streaming Downloads
 - HTTP/2 by default
 - HTTP/3 over QUIC
+- Multiplexed!
+- Async!
+
+## Why did we pursue this?
+
+We don't have to reinvent the wheel all over again; the HTTP client **Requests** is well established and
+really pleasant to use. We believe that **Requests** has the most inclusive, developer-friendly interface.
+We intend to keep it that way. As long as we can, long live Niquests!
 
 ---
 
diff --git a/docs/community/recommended.rst b/docs/community/recommended.rst
index 1dafbe9a36..0cb8fefd0b 100644
--- a/docs/community/recommended.rst
+++ b/docs/community/recommended.rst
@@ -18,7 +18,7 @@ whenever you're making a lot of web niquests.
 Requests-Toolbelt
 -----------------
 
-`Requests-Toolbelt`_ is a collection of utilities that some users of Requests may desire,
+`Requests-Toolbelt`_ is a collection of utilities that some users of Niquests may desire,
 but do not belong in Niquests proper. This library is actively maintained
 by members of the Requests core team, and reflects the functionality most
 requested by users within the community.
 
diff --git a/docs/community/vulnerabilities.rst b/docs/community/vulnerabilities.rst
index 61b0f55ab6..a9335b53a6 100644
--- a/docs/community/vulnerabilities.rst
+++ b/docs/community/vulnerabilities.rst
@@ -86,16 +86,4 @@ if upgrading is not an option.
 Previous CVEs
 -------------
 
-- Fixed in 2.20.0
-  - `CVE 2018-18074 `_
-
-- Fixed in 2.6.0
-
-  - `CVE 2015-2296 `_,
-    reported by Matthew Daley of `BugFuzz `_.
-
-- Fixed in 2.3.0
-
-  - `CVE 2014-1829 `_
-
-  - `CVE 2014-1830 `_
+None to date.
diff --git a/docs/index.rst b/docs/index.rst
index 94cff6f6ff..1f08f67224 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -76,6 +76,8 @@ Niquests is ready for today's web.
 - Streaming Downloads
 - HTTP/2 by default
 - HTTP/3 over QUIC
+- Multiplexed!
+- Async!
 
 Niquests officially supports Python 3.7+, and runs great on PyPy.
 
diff --git a/docs/user/advanced.rst b/docs/user/advanced.rst
index a82d4c453d..a43076c05c 100644
--- a/docs/user/advanced.rst
+++ b/docs/user/advanced.rst
@@ -1228,3 +1228,26 @@ by passing a custom ``QuicSharedCache`` instance like so::
 
 .. note:: Passing ``None`` to max size actually permit the cache to grow indefinitely. This is unwise and can lead to significant RAM usage. When the cache is full, the oldest entry is removed.
+
+Disable HTTP/2 and/or HTTP/3
+----------------------------
+
+You can, at your own discretion, disable a protocol by passing ``disable_http2=True`` or
+``disable_http3=True`` to your ``Session`` constructor.
+
+.. warning:: Disabling HTTP/1.1 is not possible, as the underlying library (urllib3.future) does not permit it for now.
+
+Creating a session without HTTP/2 should be done this way::
+
+    import niquests
+
+    session = niquests.Session(disable_http2=True)
+
+
+Passing a bearer token
+----------------------
+
+You may use ``auth=my_token`` as a shortcut for passing ``headers={"Authorization": f"Bearer {my_token}"}`` to
+``get``, ``post``, ``request``, and so on.
+
+.. note:: If you pass a token with a custom prefix, it will be sent as-is, e.g. ``auth="NotBearer eyDdx.."``.
diff --git a/docs/user/quickstart.rst b/docs/user/quickstart.rst
index 8fb4353b26..b816e8002e 100644
--- a/docs/user/quickstart.rst
+++ b/docs/user/quickstart.rst
@@ -165,6 +165,7 @@ failed response (e.g. error details with HTTP 500). Such JSON will be decoded
 and returned.
 To check that a request is successful, use
 ``r.raise_for_status()`` or check ``r.status_code`` is what you expect.
 
+.. note:: Since Niquests 3.2, ``r.raise_for_status()`` is chainable: it returns the response itself when everything went fine.
 
 Raw Response Content
 --------------------
 
@@ -621,6 +622,114 @@ It is saved in-memory by Niquests.
 You may also run the following command ``python -m niquests.help`` to find out if you support HTTP/3.
 In 95 percents of the case, the answer is yes!
 
+Multiplexed Connection
+----------------------
+
+Starting with Niquests 3.2, you can issue concurrent requests without opening multiple connections.
+Niquests leverages multiplexing when your remote peer supports either HTTP/2 or HTTP/3.
+
+The only thing you will ever have to do to get started is to pass ``multiplexed=True`` to
+your ``Session`` constructor.
+
+Any ``Response`` returned by ``get``, ``post``, ``put``, etc., will be a lazy instance of ``Response``.
+
+.. note::
+
+    An important note about using ``Session(multiplexed=True)`` is that, in order to be efficient
+    and actually leverage its perks, you will have to issue multiple concurrent requests before
+    trying to access any ``Response`` methods or attributes.
+
+**Example A)** Emitting concurrent requests and loading them via ``Session.gather()``::
+
+    from niquests import Session
+    from time import time
+
+    s = Session(multiplexed=True)
+
+    before = time()
+    responses = []
+
+    responses.append(
+      s.get("https://pie.dev/delay/3")
+    )
+
+    responses.append(
+      s.get("https://pie.dev/delay/1")
+    )
+
+    s.gather()
+
+    print(f"waited {time() - before} second(s)")  # will print 3s
+
+
+**Example B)** Emitting concurrent requests and loading them via direct access::
+
+    from niquests import Session
+    from time import time
+
+    s = Session(multiplexed=True)
+
+    before = time()
+    responses = []
+
+    responses.append(
+      s.get("https://pie.dev/delay/3")
+    )
+
+    responses.append(
+      s.get("https://pie.dev/delay/1")
+    )
+
+    # internally calls gather with self (Response)
+    print(responses[0].status_code)  # 200! Hidden call to s.gather(responses[0])
+    print(responses[1].status_code)  # 200!
+
+    print(f"waited {time() - before} second(s)")  # will print 3s
+
+The possible arrangements are nearly limitless, so you may write your own scheduling techniques!
+
+Async session
+-------------
+
+You may have a program that requires ``awaitable`` HTTP requests. You are in luck as **Niquests** ships with
+an implementation of ``Session`` that supports **async**.
+
+All known methods remain the same, with the sole difference that they return a coroutine.
+
+.. note:: The underlying library, **urllib3.future**, does not support native async but is thread safe. This is why we chose to backport `sync_to_async` from Django, which uses a thread pool under the hood.
+
+Here is an example::
+
+    from niquests import AsyncSession
+    import asyncio
+    from time import time
+
+    async def emit() -> None:
+        responses = []
+
+        async with AsyncSession() as s:  # it also works well with multiplexed=True
+            responses.append(await s.get("https://pie.dev/get"))
+            responses.append(await s.get("https://pie.dev/delay/3"))
+
+            await s.gather()
+
+        print(responses)
+
+    async def main() -> None:
+        foo = asyncio.create_task(emit())
+        bar = asyncio.create_task(emit())
+        await foo
+        await bar
+
+    if __name__ == "__main__":
+        before = time()
+        asyncio.run(main())
+        print(time() - before)  # 3s!
+
+.. warning:: For the time being, **Niquests** only supports **asyncio** as the backend library for async.
+   Contributions are welcome if you want it to be compatible with **anyio**, for example.
+
+.. note:: The shortcut functions ``get``, ``post``, ..., from the top-level package do not support async.
+
 
 -----------------------
 
 Ready for more? Check out the :ref:`advanced ` section.
diff --git a/pyproject.toml b/pyproject.toml
index 06fe5fb40b..301b92c196 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ description = "Python HTTP for Humans."
 readme = "README.md"
 license-files = { paths = ["LICENSE"] }
 license = "Apache-2.0"
-keywords = ["requests", "http/2", "http/3", "QUIC", "http", "https", "http client", "http/1.1", "ocsp", "revocation", "tls"]
+keywords = ["requests", "http/2", "http/3", "QUIC", "http", "https", "http client", "http/1.1", "ocsp", "revocation", "tls", "multiplexed"]
 authors = [
   {name = "Kenneth Reitz", email = "me@kennethreitz.org"}
 ]
@@ -41,7 +41,7 @@ dynamic = ["version"]
 dependencies = [
     "charset_normalizer>=2,<4",
     "idna>=2.5,<4",
-    "urllib3.future>=2.1.900,<3",
+    "urllib3.future>=2.2.901,<3",
     "wassima>=1.0.1,<2",
     "kiss_headers>=2,<4",
 ]
@@ -51,7 +51,7 @@ socks = [
     "PySocks>=1.5.6, !=1.5.7",
 ]
 http3 = [
-    "qh3<1.0.0,>=0.11.3"
+    "qh3<1.0.0,>=0.13.0"
 ]
 ocsp = [
     "cryptography<42.0.0,>=41.0.0"
 ]
@@ -104,4 +104,5 @@ filterwarnings = [
     '''ignore:Passing bytes as a header value is deprecated and will:DeprecationWarning''',
     '''ignore:The 'JSONIFY_PRETTYPRINT_REGULAR' config key is deprecated and will:DeprecationWarning''',
     '''ignore:unclosed .*:ResourceWarning''',
+    '''ignore:Parsed a negative serial number:cryptography.utils.CryptographyDeprecationWarning''',
 ]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index bd5849b65e..6f676c7228 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,6 +2,7 @@
 pytest>=2.8.0,<=7.4.2
 pytest-cov
 pytest-httpbin==2.0.0
+pytest-asyncio>=0.21.1,<1.0
 httpbin==0.10.1
 trustme
 wheel
diff --git a/src/niquests/__init__.py b/src/niquests/__init__.py
index 313a2d4d53..ebcc66a0f0 100644
--- a/src/niquests/__init__.py
+++ b/src/niquests/__init__.py
@@ -91,6 +91,7 @@ def check_compatibility(urllib3_version: str) -> None:
     __url__,
     __version__,
 )
+from ._async import AsyncSession
 from .api import delete, get, head, options, patch, post, put, request
 from .exceptions import (
     ConnectionError,
@@ -146,4 +147,5 @@ def check_compatibility(urllib3_version: str) -> None:
     "Response",
     "Session",
     "codes",
+    "AsyncSession",
 )
diff --git a/src/niquests/__version__.py b/src/niquests/__version__.py
index d58fcf5782..ff16bf1c1f 100644
--- a/src/niquests/__version__.py
+++ b/src/niquests/__version__.py
@@ -9,9 +9,9 @@
 __url__: str = "https://niquests.readthedocs.io"
 __version__: str
 
-__version__ = "3.1.4"
+__version__ = "3.2.0"
 
-__build__: int = 0x030104
+__build__: int = 0x030200
 __author__: str = "Kenneth Reitz"
 __author_email__: str = "me@kennethreitz.org"
 __license__: str = "Apache-2.0"
diff --git a/src/niquests/_async.py b/src/niquests/_async.py
new file mode 100644
index 0000000000..0c546478ec
--- /dev/null
+++ b/src/niquests/_async.py
@@ -0,0 +1,350 @@
+from __future__ import annotations
+
+import typing
+
+from ._constant import READ_DEFAULT_TIMEOUT, WRITE_DEFAULT_TIMEOUT
+from ._typing import (
+    BodyType,
+    CookiesType,
+    HeadersType,
+    HookType,
+    HttpAuthenticationType,
+    HttpMethodType,
+    MultiPartFilesAltType,
+    MultiPartFilesType,
+    ProxyType,
+    QueryParameterType,
+    TimeoutType,
+    TLSClientCertType,
+    TLSVerifyType,
+)
+from .extensions._sync_to_async import sync_to_async
+from .hooks import dispatch_hook
+from .models import PreparedRequest, Request, Response
+from .sessions import Session
+
+
+class AsyncSession(Session):
+    """
+    "It ain't much, but it's honest work" kind of class.
+    Uses a thread pool under the hood. It is not true native async.
+    """
+
+    disable_thread: bool = False
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc, value, tb):
+        super().__exit__()
+
+    async def send(self, request: PreparedRequest, **kwargs: typing.Any) -> Response:  # type: ignore[override]
+        return await sync_to_async(
+            super().send,
+            thread_sensitive=AsyncSession.disable_thread,
+        )(request=request, **kwargs)
+
+    async def request(  # type: ignore[override]
+        self,
+        method: HttpMethodType,
+        url: str,
+        params: QueryParameterType | None = None,
+        data: BodyType | None = None,
+        headers: HeadersType | None = None,
+        cookies: CookiesType | None = None,
+        files: MultiPartFilesType | MultiPartFilesAltType | None = None,
+        auth: HttpAuthenticationType | None = None,
+        timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
+        allow_redirects: bool = True,
+        proxies: ProxyType | None = None,
+        hooks: HookType[PreparedRequest | Response] | None = None,
+        stream: bool | None = None,
+        verify: TLSVerifyType | None = None,
+        cert: TLSClientCertType | None = None,
+        json: typing.Any | None = None,
+    ) -> Response:
+        if method.isupper() is False:
+            method = method.upper()
+
+        # Create the Request.
+        req = Request(
+            method=method,
+            url=url,
+            headers=headers,
+            files=files,
+            data=data or {},
+            json=json,
+            params=params or {},
+            auth=auth,
+            cookies=cookies,
+            hooks=hooks,
+        )
+
+        prep: PreparedRequest = dispatch_hook(
+            "pre_request", hooks, self.prepare_request(req)  # type: ignore[arg-type]
+        )
+
+        assert prep.url is not None
+
+        proxies = proxies or {}
+
+        settings = self.merge_environment_settings(
+            prep.url, proxies, stream, verify, cert
+        )
+
+        # Send the request.
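+        # Merge per-call options (timeout, redirects) with the environment-derived
+        # settings (proxies, stream, verify, cert) computed just above.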
+ send_kwargs = { + "timeout": timeout, + "allow_redirects": allow_redirects, + } + send_kwargs.update(settings) + + return await self.send(prep, **send_kwargs) + + async def get( # type: ignore[override] + self, + url: str, + *, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "GET", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def options( # type: ignore[override] + self, + url: str, + *, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "OPTIONS", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def head( # type: ignore[override] + self, + url: str, + *, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "HEAD", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def post( # type: ignore[override] + self, + url: str, + data: BodyType | None = None, + json: typing.Any | None = None, + *, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + files: MultiPartFilesType | MultiPartFilesAltType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "POST", + url, + data=data, + json=json, + params=params, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def put( # type: ignore[override] + self, + url: str, + 
data: BodyType | None = None, + *, + json: typing.Any | None = None, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + files: MultiPartFilesType | MultiPartFilesAltType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "PUT", + url, + data=data, + json=json, + params=params, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def patch( # type: ignore[override] + self, + url: str, + data: BodyType | None = None, + *, + json: typing.Any | None = None, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + files: MultiPartFilesType | MultiPartFilesAltType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "PATCH", + url, + data=data, + json=json, + params=params, + headers=headers, + cookies=cookies, + files=files, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def delete( # type: ignore[override] + self, + url: str, + *, + params: QueryParameterType | None = None, + headers: HeadersType | None = None, + cookies: CookiesType | None = None, + auth: HttpAuthenticationType | None = None, + timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT, + allow_redirects: bool = True, + proxies: ProxyType | None = None, + hooks: HookType[PreparedRequest | Response] | None = None, + verify: TLSVerifyType = True, + stream: bool = False, + cert: TLSClientCertType | None = None, + ) -> Response: + return await self.request( + "DELETE", + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + proxies=proxies, + hooks=hooks, + verify=verify, + stream=stream, + cert=cert, + ) + + async def gather(self, *responses: Response) -> None: # type: ignore[override] + return await sync_to_async( + super().gather, + thread_sensitive=AsyncSession.disable_thread, + )(*responses) diff --git a/src/niquests/_typing.py b/src/niquests/_typing.py index 4207ad463b..c77648af6a 100644 --- a/src/niquests/_typing.py +++ b/src/niquests/_typing.py @@ -69,6 +69,7 @@ #: Can be a custom authentication mechanism that derive from AuthBase. HttpAuthenticationType: typing.TypeAlias = typing.Union[ typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]], + str, AuthBase, ] #: Map for each protocol (http, https) associated proxy to be used. 
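The `str` member added to `HttpAuthenticationType` above powers the bearer-token shortcut described in the docs; the wrapping itself is done by `BearerTokenAuth` in `src/niquests/auth.py` further below. A minimal sketch of the intended usage, where `pie.dev/bearer` is only an illustrative endpoint:

```python
import niquests

# A plain string passed to `auth=` is sent as "Authorization: Bearer <token>".
r = niquests.get("https://pie.dev/bearer", auth="my-secret-token")

# A string that already carries a prefix is sent as-is.
r = niquests.get("https://pie.dev/bearer", auth="NotBearer my-secret-token")
```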
diff --git a/src/niquests/adapters.py b/src/niquests/adapters.py index de0594404c..bcaa16ded2 100644 --- a/src/niquests/adapters.py +++ b/src/niquests/adapters.py @@ -9,10 +9,21 @@ import os.path import socket # noqa: F401 +import sys +import time import typing +from datetime import timedelta +from http.cookiejar import CookieJar from urllib.parse import urlparse -from urllib3 import HTTPConnectionPool, HTTPSConnectionPool +# Preferred clock, based on which one is more accurate on a given system. +if sys.platform == "win32": + preferred_clock = time.perf_counter +else: + preferred_clock = time.time + +from urllib3 import ConnectionInfo, HTTPConnectionPool, HTTPSConnectionPool +from urllib3.backend import HttpVersion, ResponsePromise from urllib3.exceptions import ClosedPoolError, ConnectTimeoutError from urllib3.exceptions import HTTPError as _HTTPError from urllib3.exceptions import InvalidHeader as _InvalidHeader @@ -39,6 +50,7 @@ ) from ._typing import ( CacheLayerAltSvcType, + HookType, ProxyType, RetryType, TLSClientCertType, @@ -53,11 +65,14 @@ InvalidProxyURL, InvalidSchema, InvalidURL, + MultiplexingError, ProxyError, ReadTimeout, RetryError, SSLError, + TooManyRedirects, ) +from .hooks import dispatch_hook from .models import PreparedRequest, Response from .structures import CaseInsensitiveDict from .utils import ( @@ -68,6 +83,11 @@ urldefragauth, ) +try: + from .extensions._ocsp import verify as ocsp_verify +except ImportError: + ocsp_verify = None # type: ignore[assignment] + try: from urllib3.contrib.socks import SOCKSProxyManager except ImportError: @@ -91,6 +111,7 @@ def send( cert: TLSClientCertType | None = None, proxies: ProxyType | None = None, on_post_connection: typing.Callable[[typing.Any], None] | None = None, + multiplexed: bool = False, ) -> Response: """Sends PreparedRequest object. Returns Response object. @@ -106,6 +127,7 @@ def send( :param proxies: (optional) The proxies dictionary to apply to the request. :param on_post_connection: (optional) A callable that should be invoked just after the pool mgr picked up a live connection. The function is expected to takes one positional argument and return nothing. + :param multiplexed: Determine if we should leverage multiplexed connection. """ raise NotImplementedError @@ -113,6 +135,13 @@ def close(self) -> None: """Cleans up adapter specific items.""" raise NotImplementedError + def gather(self, *responses: Response) -> None: + """ + Load responses that are still 'lazy'. This method is meant for a multiplexed connection. + Implementation is not mandatory. + """ + pass + class HTTPAdapter(BaseAdapter): """The built-in HTTP Adapter for urllib3.future. 
@@ -148,6 +177,8 @@ class HTTPAdapter(BaseAdapter): "_pool_maxsize", "_pool_block", "_quic_cache_layer", + "_disable_http2", + "_disable_http3", ] def __init__( @@ -157,6 +188,8 @@ def __init__( max_retries: RetryType = DEFAULT_RETRIES, pool_block: bool = DEFAULT_POOLBLOCK, quic_cache_layer: CacheLayerAltSvcType | None = None, + disable_http2: bool = False, + disable_http3: bool = False, ): if isinstance(max_retries, bool): self.max_retries: RetryType = False @@ -180,12 +213,25 @@ def __init__( self._pool_maxsize = pool_maxsize self._pool_block = pool_block self._quic_cache_layer = quic_cache_layer + self._disable_http2 = disable_http2 + self._disable_http3 = disable_http3 + + #: we keep a list of pending (lazy) response + self._promises: list[Response] = [] + + disabled_svn = set() + + if disable_http2: + disabled_svn.add(HttpVersion.h2) + if disable_http3: + disabled_svn.add(HttpVersion.h3) self.init_poolmanager( pool_connections, pool_maxsize, block=pool_block, quic_cache_layer=quic_cache_layer, + disabled_svn=disabled_svn, ) def __getstate__(self) -> dict[str, typing.Any | None]: @@ -200,11 +246,19 @@ def __setstate__(self, state): for attr, value in state.items(): setattr(self, attr, value) + disabled_svn = set() + + if self._disable_http2: + disabled_svn.add(HttpVersion.h2) + if self._disable_http3: + disabled_svn.add(HttpVersion.h3) + self.init_poolmanager( self._pool_connections, self._pool_maxsize, block=self._pool_block, quic_cache_layer=self._quic_cache_layer, + disabled_svn=disabled_svn, ) def init_poolmanager( @@ -252,6 +306,11 @@ def proxy_manager_for(self, proxy: str, **proxy_kwargs: typing.Any) -> ProxyMana :param proxy_kwargs: Extra keyword arguments used to configure the Proxy Manager. :returns: ProxyManager """ + disabled_svn = set() + + if self._disable_http2: + disabled_svn.add(HttpVersion.h2) + if proxy in self.proxy_manager: manager = self.proxy_manager[proxy] elif proxy.lower().startswith("socks"): @@ -263,6 +322,7 @@ def proxy_manager_for(self, proxy: str, **proxy_kwargs: typing.Any) -> ProxyMana num_pools=self._pool_connections, maxsize=self._pool_maxsize, block=self._pool_block, + disabled_svn=disabled_svn, **proxy_kwargs, ) else: @@ -273,6 +333,7 @@ def proxy_manager_for(self, proxy: str, **proxy_kwargs: typing.Any) -> ProxyMana num_pools=self._pool_connections, maxsize=self._pool_maxsize, block=self._pool_block, + disabled_svn=disabled_svn, **proxy_kwargs, ) @@ -336,11 +397,18 @@ def cert_verify( if cert: if not isinstance(cert, str): - conn.cert_file = cert[0] - conn.key_file = cert[1] + if "-----BEGIN CERTIFICATE-----" in cert[0]: + conn.cert_data = cert[0] + conn.key_data = cert[1] + else: + conn.cert_file = cert[0] + conn.key_file = cert[1] conn.key_password = cert[2] if len(cert) == 3 else None # type: ignore[misc] else: - conn.cert_file = cert + if "-----BEGIN CERTIFICATE-----" in cert: + conn.cert_data = cert + else: + conn.cert_file = cert conn.key_file = None conn.key_password = None if conn.cert_file and not os.path.exists(conn.cert_file): @@ -353,40 +421,48 @@ def cert_verify( f"Could not find the TLS key file, invalid path: {conn.key_file}" ) - def build_response(self, req: PreparedRequest, resp: BaseHTTPResponse) -> Response: + def build_response( + self, req: PreparedRequest, resp: BaseHTTPResponse | ResponsePromise + ) -> Response: """Builds a :class:`Response ` object from a urllib3 response. 
This should not be called from user code, and is only exposed for use when subclassing the :class:`HTTPAdapter ` :param req: The :class:`PreparedRequest ` used to generate the response. - :param resp: The urllib3 response object. + :param resp: The urllib3 response or promise object. """ response = Response() - # Fallback to None if there's no status_code, for whatever reason. - response.status_code = getattr(resp, "status", None) + if isinstance(resp, BaseHTTPResponse): + # Fallback to None if there's no status_code, for whatever reason. + response.status_code = getattr(resp, "status", None) - # Make headers case-insensitive. - response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) + # Make headers case-insensitive. + response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) - # Set encoding. - response.encoding = get_encoding_from_headers(response.headers) - response.raw = resp - response.reason = response.raw.reason + # Set encoding. + response.encoding = get_encoding_from_headers(response.headers) + response.raw = resp + response.reason = response.raw.reason - if isinstance(req.url, bytes): - response.url = req.url.decode("utf-8") - else: - response.url = req.url + if isinstance(req.url, bytes): + response.url = req.url.decode("utf-8") + else: + response.url = req.url - # Add new cookies from the server. - extract_cookies_to_jar(response.cookies, req, resp) + # Add new cookies from the server. + extract_cookies_to_jar(response.cookies, req, resp) + else: + self._promises.append(response) # Give the Response some context. response.request = req response.connection = self # type: ignore[attr-defined] + if isinstance(resp, ResponsePromise): + response._promise = resp + return response def get_connection( @@ -502,6 +578,7 @@ def send( cert: TLSClientCertType | None = None, proxies: ProxyType | None = None, on_post_connection: typing.Callable[[typing.Any], None] | None = None, + multiplexed: bool = False, ) -> Response: """Sends PreparedRequest object. Returns Response object. @@ -516,6 +593,10 @@ def send( (directly) in a string or bytes. :param cert: (optional) Any user-provided SSL certificate to be trusted. :param proxies: (optional) The proxies dictionary to apply to the request. + :param on_post_connection: (optional) A callable that contain a single positional argument for newly acquired + connection. Useful to check acquired connection information. + :param multiplexed: Determine if request shall be transmitted by leveraging the multiplexed aspect of the protocol + if available. Return a lazy instance of Response pending its retrieval. 
""" assert ( @@ -568,7 +649,7 @@ def send( raise ValueError("Body contains unprepared native list or dict.") try: - resp = conn.urlopen( + resp_or_promise = conn.urlopen( # type: ignore[call-overload] method=request.method, url=url, body=request.body, @@ -581,6 +662,7 @@ def send( timeout=timeout, chunked=chunked, on_post_connection=on_post_connection, + multiplexed=multiplexed, ) except (ProtocolError, OSError) as err: @@ -621,4 +703,235 @@ def send( else: raise - return self.build_response(request, resp) + return self.build_response(request, resp_or_promise) + + def _future_handler(self, response: Response, low_resp: BaseHTTPResponse) -> None: + stream = typing.cast( + bool, response._promise.get_parameter("niquests_is_stream") + ) + start = typing.cast(float, response._promise.get_parameter("niquests_start")) + hooks = typing.cast(HookType, response._promise.get_parameter("niquests_hooks")) + session_cookies = typing.cast( + CookieJar, response._promise.get_parameter("niquests_cookies") + ) + allow_redirects = typing.cast( + bool, response._promise.get_parameter("niquests_allow_redirect") + ) + max_redirect = typing.cast( + int, response._promise.get_parameter("niquests_max_redirects") + ) + redirect_count = typing.cast( + int, response._promise.get_parameter("niquests_redirect_count") + ) + kwargs = typing.cast( + typing.MutableMapping[str, typing.Any], + response._promise.get_parameter("niquests_kwargs"), + ) + + # mark response as "not lazy" anymore by removing ref to "this"/gather. + del response._gather + + req = response.request + assert req is not None + + # Total elapsed time of the request (approximately) + elapsed = preferred_clock() - start + response.elapsed = timedelta(seconds=elapsed) + + # Fallback to None if there's no status_code, for whatever reason. + response.status_code = getattr(low_resp, "status", None) + + # Make headers case-insensitive. + response.headers = CaseInsensitiveDict(getattr(low_resp, "headers", {})) + + # Set encoding. + response.encoding = get_encoding_from_headers(response.headers) + response.raw = low_resp + response.reason = response.raw.reason + + if isinstance(req.url, bytes): + response.url = req.url.decode("utf-8") + else: + response.url = req.url + + # Add new cookies from the server. 
+ extract_cookies_to_jar(response.cookies, req, low_resp) + extract_cookies_to_jar(session_cookies, req, low_resp) + + promise_ctx_backup = { + k: v + for k, v in response._promise._parameters.items() + if k.startswith("niquests_") + } + del response._promise + + if allow_redirects: + next_request = response._resolve_redirect(response, req) + redirect_count += 1 + + if redirect_count > max_redirect + 1: + raise TooManyRedirects( + f"Exceeded {max_redirect} redirects", request=next_request + ) + + if next_request: + session_constructor = promise_ctx_backup["niquests_session_constructor"] + + def on_post_connection(conn_info: ConnectionInfo) -> None: + """This function will be called by urllib3.future just after establishing the connection.""" + nonlocal session_constructor, next_request, kwargs + + assert next_request is not None + next_request.conn_info = conn_info + + if ( + next_request.url + and next_request.url.startswith("https://") + and ocsp_verify is not None + and kwargs["verify"] + ): + strict_ocsp_enabled: bool = ( + os.environ.get("NIQUESTS_STRICT_OCSP", "0") != "0" + ) + + with session_constructor() as ocsp_session: + ocsp_session.trust_env = False + + if not strict_ocsp_enabled: + ocsp_session.proxies = kwargs["proxies"] + + ocsp_verify( + ocsp_session, + next_request, + strict_ocsp_enabled, + 0.2 if not strict_ocsp_enabled else 1.0, + ) + + kwargs["on_post_connection"] = on_post_connection + + next_promise = self.send(next_request, **kwargs) + + next_promise._gather = lambda: self.gather(response) # type: ignore[arg-type] + next_promise._resolve_redirect = response._resolve_redirect + + if "niquests_origin_response" not in promise_ctx_backup: + promise_ctx_backup["niquests_origin_response"] = response + + promise_ctx_backup["niquests_origin_response"].history.append( + next_promise + ) + + promise_ctx_backup["niquests_start"] = preferred_clock() + promise_ctx_backup["niquests_redirect_count"] = redirect_count + + for k, v in promise_ctx_backup.items(): + next_promise._promise.set_parameter(k, v) + + self._promises.remove(response) + + return + else: + response._next = response._resolve_redirect(response, req) # type: ignore[assignment] + + del response._resolve_redirect + + # In case we handled redirects in a multiplexed connection, we shall reorder history + # and do a swap. 
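+        # The caller holds a reference to the *origin* (first) lazy Response, so the
+        # final hop's fields are moved onto it, and the origin's own first-hop data
+        # is re-filed at the front of its history.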
+ if "niquests_origin_response" in promise_ctx_backup: + origin_response: Response = promise_ctx_backup["niquests_origin_response"] + leaf_response: Response = origin_response.history[-1] + + origin_response.history.pop() + + origin_response.status_code, leaf_response.status_code = ( + leaf_response.status_code, + origin_response.status_code, + ) + origin_response.headers, leaf_response.headers = ( + leaf_response.headers, + origin_response.headers, + ) + origin_response.encoding, leaf_response.encoding = ( + leaf_response.encoding, + origin_response.encoding, + ) + origin_response.raw, leaf_response.raw = ( + leaf_response.raw, + origin_response.raw, + ) + origin_response.reason, leaf_response.reason = ( + leaf_response.reason, + origin_response.reason, + ) + origin_response.url, leaf_response.url = ( + leaf_response.url, + origin_response.url, + ) + origin_response.elapsed, leaf_response.elapsed = ( + leaf_response.elapsed, + origin_response.elapsed, + ) + origin_response.request, leaf_response.request = ( + leaf_response.request, + origin_response.request, + ) + + origin_response.history = [leaf_response] + origin_response.history + + # Response manipulation hooks + response = dispatch_hook("response", hooks, response, **kwargs) # type: ignore[arg-type] + + if response.history: + # If the hooks create history then we want those cookies too + for sub_resp in response.history: + extract_cookies_to_jar(session_cookies, sub_resp.request, sub_resp.raw) + + if not stream: + response.content + + self._promises.remove(response) + + def gather(self, *responses: Response) -> None: + if not self._promises: + return + + # Either we did not have a list of promises to fulfill... + if not responses: + while True: + response = None + low_resp = self.poolmanager.get_response() + + if low_resp is None: + break + + for response in self._promises: + if low_resp.is_from_promise(response._promise): + break + + if response is None: + raise MultiplexingError( + "Underlying library yield an unexpected response that did not match any of sent request by us" + ) + + self._future_handler(response, low_resp) + else: + # ...Or we have a list on which we should focus. 
+ for response in responses: + req = response.request + + assert req is not None + + if not hasattr(response, "_promise"): + continue + + low_resp = self.poolmanager.get_response(promise=response._promise) + + if low_resp is None: + raise MultiplexingError( + "Underlying library did not recognize our promise when asked to retrieve it" + ) + + self._future_handler(response, low_resp) + + if self._promises: + self.gather() diff --git a/src/niquests/auth.py b/src/niquests/auth.py index f1c1932cdc..66aa1123d3 100644 --- a/src/niquests/auth.py +++ b/src/niquests/auth.py @@ -46,6 +46,29 @@ def __call__(self, r): raise NotImplementedError("Auth hooks must be callable.") +class BearerTokenAuth(AuthBase): + """Simple token injection in Authorization header""" + + def __init__(self, token: str): + self.token = token + + def __eq__(self, other) -> bool: + return self.token == getattr(other, "token", None) + + def __ne__(self, other) -> bool: + return not self == other + + def __call__(self, r): + detect_token_type: list[str] = self.token.split(" ", maxsplit=1) + + if len(detect_token_type) == 1: + r.headers["Authorization"] = f"Bearer {self.token}" + else: + r.headers["Authorization"] = self.token + + return r + + class HTTPBasicAuth(AuthBase): """Attaches HTTP Basic Authentication to the given Request object.""" diff --git a/src/niquests/exceptions.py b/src/niquests/exceptions.py index 5a6a8dc9a1..d55e2d63ad 100644 --- a/src/niquests/exceptions.py +++ b/src/niquests/exceptions.py @@ -139,6 +139,10 @@ class UnrewindableBodyError(RequestException): """Requests encountered an error when trying to rewind a body.""" +class MultiplexingError(RequestException): + """Requests encountered an unresolvable error in multiplexed mode.""" + + # Warnings diff --git a/src/niquests/extensions/_ocsp.py b/src/niquests/extensions/_ocsp.py index 121736b3da..712485a011 100644 --- a/src/niquests/extensions/_ocsp.py +++ b/src/niquests/extensions/_ocsp.py @@ -61,6 +61,10 @@ def _infer_issuer_from(certificate: Certificate) -> Certificate | None: else: possible_issuer = load_der_x509_certificate(der_cert) + # detect cryptography old build + if not hasattr(certificate, "verify_directly_issued_by"): + break + try: certificate.verify_directly_issued_by(possible_issuer) except ValueError: @@ -402,7 +406,7 @@ def verify( if issuer_certificate is not None: peer_certificate.verify_directly_issued_by(issuer_certificate) - except (socket.gaierror, TimeoutError, ConnectionError): + except (socket.gaierror, TimeoutError, ConnectionError, AttributeError): pass except ValueError: issuer_certificate = None diff --git a/src/niquests/extensions/_sync_to_async.py b/src/niquests/extensions/_sync_to_async.py new file mode 100644 index 0000000000..404f8f911b --- /dev/null +++ b/src/niquests/extensions/_sync_to_async.py @@ -0,0 +1,534 @@ +""" +Copyright (c) Django Software Foundation and individual contributors. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + 3. 
Neither the name of Django nor the names of its contributors may be used + to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from __future__ import annotations + +import asyncio +import asyncio.coroutines +import contextvars +import functools +import inspect +import os +import sys +import threading +import weakref +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Coroutine, Generic, TypeVar, overload + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing import _type_check + + class _Immutable: + """Mixin to indicate that object should not be copied.""" + + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + class ParamSpecArgs(_Immutable): + """The args for a ParamSpec object. + + Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. + + ParamSpecArgs objects have a reference back to their ParamSpec: + + P.args.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.args" + + def __eq__(self, other): + if not isinstance(other, ParamSpecArgs): + return NotImplemented + return self.__origin__ == other.__origin__ + + class ParamSpecKwargs(_Immutable): + """The kwargs for a ParamSpec object. + + Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. + + ParamSpecKwargs objects have a reference back to their ParamSpec: + + P.kwargs.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. 
+ """ + + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.kwargs" + + def __eq__(self, other): + if not isinstance(other, ParamSpecKwargs): + return NotImplemented + return self.__origin__ == other.__origin__ + + def _set_default(type_param, default): + if isinstance(default, (tuple, list)): + type_param.__default__ = tuple( + _type_check(d, "Default must be a type") for d in default + ) + elif default != _marker: + type_param.__default__ = _type_check(default, "Default must be a type") + else: + type_param.__default__ = None + + class _DefaultMixin: + """Mixin for TypeVarLike defaults.""" + + __slots__ = () + __init__ = _set_default + + def _caller(depth=2): + try: + return sys._getframe(depth).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): # For platforms without _getframe() + return None + + class _Sentinel: + def __repr__(self): + return "" + + _marker = _Sentinel() + + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class ParamSpec(list, _DefaultMixin): + """Parameter specification variable. + + Usage:: + + P = ParamSpec('P') + + Parameter specification variables exist primarily for the benefit of static + type checkers. They are used to forward the parameter types of one + callable to another callable, a pattern commonly found in higher order + functions and decorators. They are only valid when used in ``Concatenate``, + or s the first argument to ``Callable``. In Python 3.10 and higher, + they are also supported in user-defined Generics at runtime. + See class Generic for more information on generic types. An + example for annotating a decorator:: + + T = TypeVar('T') + P = ParamSpec('P') + + def add_logging(f: Callable[P, T]) -> Callable[P, T]: + '''A type-safe decorator to add logging to a function.''' + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + logging.info(f'{f.__name__} was called') + return f(*args, **kwargs) + return inner + + @add_logging + def add_two(x: float, y: float) -> float: + '''Add two numbers together.''' + return x + y + + Parameter specification variables defined with covariant=True or + contravariant=True can be used to declare covariant or contravariant + generic types. These keyword arguments are valid, but their actual semantics + are yet to be decided. See PEP 612 for details. + + Parameter specification variables can be introspected. e.g.: + + P.__name__ == 'T' + P.__bound__ == None + P.__covariant__ == False + P.__contravariant__ == False + + Note that only parameter specification variables defined in global scope can + be pickled. + """ + + # Trick Generic __parameters__. 
+ __class__ = TypeVar + + @property + def args(self): + return ParamSpecArgs(self) + + @property + def kwargs(self): + return ParamSpecKwargs(self) + + def __init__( + self, + name, + *, + bound=None, + covariant=False, + contravariant=False, + infer_variance=False, + default=_marker, + ): + super().__init__([self]) + self.__name__ = name + self.__covariant__ = bool(covariant) + self.__contravariant__ = bool(contravariant) + self.__infer_variance__ = bool(infer_variance) + if bound: + self.__bound__ = _type_check(bound, "Bound must be a type.") + else: + self.__bound__ = None + _DefaultMixin.__init__(self, default) + + # for pickling: + def_mod = _caller() + if def_mod != "typing_extensions": + self.__module__ = def_mod + + def __repr__(self): + if self.__infer_variance__: + prefix = "" + elif self.__covariant__: + prefix = "+" + elif self.__contravariant__: + prefix = "-" + else: + prefix = "~" + return prefix + self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + # Hack to get typing._type_check to pass. + def __call__(self, *args, **kwargs): + pass + + +_F = TypeVar("_F", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def _restore_context(context: contextvars.Context) -> None: + # Check for changes in contextvars, and set them to the current + # context for downstream consumers + for cvar in context: + cvalue = context.get(cvar) + try: + if cvar.get() != cvalue: + cvar.set(cvalue) + except LookupError: + cvar.set(cvalue) + + +# Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for +# inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker. +# The latter is replaced with the inspect.markcoroutinefunction decorator. +# Until 3.12 is the minimum supported Python version, provide a shim. +# Django 4.0 only supports 3.8+, so don't concern with the _or_partial backport. + +if hasattr(inspect, "markcoroutinefunction"): + iscoroutinefunction = inspect.iscoroutinefunction + markcoroutinefunction: Callable[[_F], _F] = inspect.markcoroutinefunction +else: + iscoroutinefunction = asyncio.iscoroutinefunction # type: ignore[assignment] + + def markcoroutinefunction(func: _F) -> _F: + func._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore + return func + + +if sys.version_info >= (3, 8): + _iscoroutinefunction_or_partial = iscoroutinefunction +else: + + def _iscoroutinefunction_or_partial(func: Any) -> bool: + # Python < 3.8 does not correctly determine partially wrapped + # coroutine functions are coroutine functions, hence the need for + # this to exist. Code taken from CPython. + while inspect.ismethod(func): + func = func.__func__ + while isinstance(func, functools.partial): + func = func.func + + return iscoroutinefunction(func) + + +class ThreadSensitiveContext: + """Async context manager to manage context for thread sensitive mode + + This context manager controls which thread pool executor is used when in + thread sensitive mode. By default, a single thread pool executor is shared + within a process. + + In Python 3.7+, the ThreadSensitiveContext() context manager may be used to + specify a thread pool per context. + + This context manager is re-entrant, so only the outer-most call to + ThreadSensitiveContext will set the context. + + Usage: + + >>> import time + >>> async with ThreadSensitiveContext(): + ... 
await sync_to_async(time.sleep, 1)() + """ + + def __init__(self): + self.token = None + + async def __aenter__(self): + try: + SyncToAsync.thread_sensitive_context.get() + except LookupError: + self.token = SyncToAsync.thread_sensitive_context.set(self) + + return self + + async def __aexit__(self, exc, value, tb): + if not self.token: + return + + executor = SyncToAsync.context_to_thread_executor.pop(self, None) + if executor: + executor.shutdown() + SyncToAsync.thread_sensitive_context.reset(self.token) + + +class SyncToAsync(Generic[_P, _R]): + """ + Utility class which turns a synchronous callable into an awaitable that + runs in a threadpool. It also sets a threadlocal inside the thread so + calls to AsyncToSync can escape it. + + If thread_sensitive is passed, the code will run in the same thread as any + outer code. This is needed for underlying Python code that is not + threadsafe (for example, code which handles SQLite database connections). + + If the outermost program is async (i.e. SyncToAsync is outermost), then + this will be a dedicated single sub-thread that all sync code runs in, + one after the other. If the outermost program is sync (i.e. AsyncToSync is + outermost), this will just be the main thread. This is achieved by idling + with a CurrentThreadExecutor while AsyncToSync is blocking its sync parent, + rather than just blocking. + + If executor is passed in, that will be used instead of the loop's default executor. + In order to pass in an executor, thread_sensitive must be set to False, otherwise + a TypeError will be raised. + """ + + # Storage for main event loop references + threadlocal = threading.local() + + # Single-thread executor for thread-sensitive code + single_thread_executor = ThreadPoolExecutor(max_workers=1) + + # Maintain a contextvar for the current execution context. Optionally used + # for thread sensitive mode. + thread_sensitive_context: contextvars.ContextVar[ + ThreadSensitiveContext + ] = contextvars.ContextVar("thread_sensitive_context") + + # Contextvar that is used to detect if the single thread executor + # would be awaited on while already being used in the same context + deadlock_context: contextvars.ContextVar[bool] = contextvars.ContextVar( + "deadlock_context" + ) + + # Maintaining a weak reference to the context ensures that thread pools are + # erased once the context goes out of scope. This terminates the thread pool. 
+ context_to_thread_executor: weakref.WeakKeyDictionary[ + ThreadSensitiveContext, ThreadPoolExecutor + ] = weakref.WeakKeyDictionary() + + def __init__( + self, + func: Callable[_P, _R], + thread_sensitive: bool = False, + executor: ThreadPoolExecutor | None = None, + ) -> None: + if ( + not callable(func) + or _iscoroutinefunction_or_partial(func) + or _iscoroutinefunction_or_partial(getattr(func, "__call__", func)) + ): + raise TypeError("sync_to_async can only be applied to sync functions.") + self.func = func + functools.update_wrapper(self, func) + self._thread_sensitive = thread_sensitive + markcoroutinefunction(self) + if thread_sensitive and executor is not None: + raise TypeError("executor must not be set when thread_sensitive is True") + self._executor = executor + try: + self.__self__ = func.__self__ # type: ignore + except AttributeError: + pass + + async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + __traceback_hide__ = True # noqa: F841 + loop = asyncio.get_running_loop() + + # Work out what thread to run the code in + if self._thread_sensitive: + if self.thread_sensitive_context.get(None): + # If we have a way of retrieving the current context, attempt + # to use a per-context thread pool executor + thread_sensitive_context = self.thread_sensitive_context.get() + + if thread_sensitive_context in self.context_to_thread_executor: + # Re-use thread executor in current context + executor = self.context_to_thread_executor[thread_sensitive_context] + else: + # Create new thread executor in current context + executor = ThreadPoolExecutor(max_workers=1) + self.context_to_thread_executor[thread_sensitive_context] = executor + elif self.deadlock_context.get(False): + raise RuntimeError( + "Single thread executor already being used, would deadlock" + ) + else: + # Otherwise, we run it in a fixed single thread + executor = self.single_thread_executor + self.deadlock_context.set(True) + else: + # Use the passed in executor, or the loop's default if it is None + executor = self._executor # type: ignore[assignment] + + context = contextvars.copy_context() + child = functools.partial(self.func, *args, **kwargs) + func = context.run + + try: + # Run the code in the right thread + ret: _R = await loop.run_in_executor( + executor, + functools.partial( + self.thread_handler, + loop, + sys.exc_info(), + func, + child, + ), + ) + + finally: + _restore_context(context) + self.deadlock_context.set(False) + + return ret + + def __get__( + self, parent: Any, objtype: Any + ) -> Callable[_P, Coroutine[Any, Any, _R]]: + """ + Include self for methods + """ + func = functools.partial(self.__call__, parent) + return functools.update_wrapper(func, self.func) + + def thread_handler(self, loop, exc_info, func, *args, **kwargs): + """ + Wraps the sync application with exception handling. + """ + + __traceback_hide__ = True # noqa: F841 + + # Set the threadlocal for AsyncToSync + self.threadlocal.main_event_loop = loop + self.threadlocal.main_event_loop_pid = os.getpid() + + # Run the function + # If we have an exception, run the function inside the except block + # after raising it so exc_info is correctly populated. + if exc_info[1]: + try: + raise exc_info[1] + except BaseException: + return func(*args, **kwargs) + else: + return func(*args, **kwargs) + + +@overload +def sync_to_async( + *, + thread_sensitive: bool = True, + executor: ThreadPoolExecutor | None = None, +) -> Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]: + ... 
+
+
+@overload
+def sync_to_async(
+    func: Callable[_P, _R],
+    *,
+    thread_sensitive: bool = True,
+    executor: ThreadPoolExecutor | None = None,
+) -> Callable[_P, Coroutine[Any, Any, _R]]:
+    ...
+
+
+def sync_to_async(
+    func: Callable[_P, _R] | None = None,
+    *,
+    thread_sensitive: bool = False,
+    executor: ThreadPoolExecutor | None = None,
+) -> (
+    Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]
+    | Callable[_P, Coroutine[Any, Any, _R]]
+):
+    if func is None:
+        return lambda f: SyncToAsync(
+            f,
+            thread_sensitive=thread_sensitive,
+            executor=executor,
+        )
+    return SyncToAsync(
+        func,
+        thread_sensitive=thread_sensitive,
+        executor=executor,
+    )
+
+
+__all__ = ("sync_to_async",)
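The two overloads above allow `sync_to_async` to be used both as a decorator factory and as a direct wrapper. A minimal sketch of both call shapes follows; the import path is an assumption, since the vendored module's final location inside niquests is not visible in this hunk:

```python
import asyncio
import time

# Assumption: the vendored helper's import path. Adjust to wherever
# niquests actually exposes it.
from niquests._async import sync_to_async


@sync_to_async(thread_sensitive=False)
def blocking_sleep(seconds: float) -> str:
    time.sleep(seconds)  # deliberately blocking work, runs in a worker thread
    return "done"


async def main() -> None:
    # Decorator-factory form: the wrapped function becomes awaitable.
    print(await blocking_sleep(0.1))

    # Direct-wrap form: wrap an existing callable on the fly.
    await sync_to_async(time.sleep)(0.1)


asyncio.run(main())
```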
diff --git a/src/niquests/help.py b/src/niquests/help.py
index 226ecabe59..77469c4874 100644
--- a/src/niquests/help.py
+++ b/src/niquests/help.py
@@ -5,6 +5,8 @@
 import platform
 import ssl
 import sys
+import warnings
+from json import JSONDecodeError
 
 import charset_normalizer
 import h2  # type: ignore
@@ -13,7 +15,9 @@
 import urllib3
 import wassima
 
-from . import __version__ as requests_version
+from . import RequestException
+from . import __version__ as niquests_version
+from . import get
 
 try:
     import qh3  # type: ignore
@@ -97,7 +101,9 @@ def info():
     }
 
     system_ssl = ssl.OPENSSL_VERSION_NUMBER
-    system_ssl_info = {"version": f"{system_ssl:x}" if system_ssl is not None else ""}
+    system_ssl_info = {
+        "version": f"{system_ssl:x}" if system_ssl is not None else "N/A"
+    }
 
     return {
         "platform": platform_info,
@@ -108,7 +114,7 @@ def info():
         "cryptography": cryptography_info,
         "idna": idna_info,
         "requests": {
-            "version": requests_version,
+            "version": niquests_version,
         },
         "http3": {
             "enabled": qh3 is not None,
@@ -131,6 +137,24 @@ def info():
 
 
 def main() -> None:
     """Pretty-print the bug information as JSON."""
+    try:
+        response = get("https://pypi.org/pypi/niquests/json")
+        package_info = response.json()
+
+        if (
+            isinstance(package_info, dict)
+            and "info" in package_info
+            and "version" in package_info["info"]
+        ):
+            if package_info["info"]["version"] != niquests_version:
+                warnings.warn(
+                    f"You are using Niquests {niquests_version}; PyPI reports version {package_info['info']['version']} as the latest stable release. "
+                    "We invite you to upgrade as soon as possible. Run `python -m pip install niquests -U`.",
+                    UserWarning,
+                )
+    except (RequestException, JSONDecodeError):
+        pass
+
     print(json.dumps(info(), sort_keys=True, indent=2))
diff --git a/src/niquests/models.py b/src/niquests/models.py
index 50b9c5f794..d0020d7e40 100644
--- a/src/niquests/models.py
+++ b/src/niquests/models.py
@@ -23,7 +23,7 @@
 from charset_normalizer import from_bytes
 from kiss_headers import Headers, parse_it
-from urllib3 import BaseHTTPResponse, ConnectionInfo
+from urllib3 import BaseHTTPResponse, ConnectionInfo, ResponsePromise
 from urllib3.exceptions import (
     DecodeError,
     LocationParseError,
@@ -49,7 +49,7 @@
     MultiPartFilesType,
     QueryParameterType,
 )
-from .auth import HTTPBasicAuth
+from .auth import BearerTokenAuth, HTTPBasicAuth
 from .cookies import (
     RequestsCookieJar,
     _copy_cookie_jar,
@@ -576,6 +576,8 @@ def prepare_auth(self, auth: HttpAuthenticationType | None, url: str = "") -> No
         if isinstance(auth, tuple) and len(auth) == 2:
             # special-case basic HTTP auth
             auth = HTTPBasicAuth(*auth)
+        elif isinstance(auth, str):
+            auth = BearerTokenAuth(auth)
 
         if not callable(auth):
             raise ValueError(
@@ -847,6 +849,14 @@ class Response:
         "request",
     ]
 
+    #: Internals used for lazy responses. Do not access these unless you know
+    #: what you are doing; they may not exist.
+    _promise: ResponsePromise
+    _gather: typing.Callable[[], None]
+    _resolve_redirect: typing.Callable[
+        [Response, PreparedRequest], PreparedRequest | None
+    ]
+
     def __init__(self) -> None:
         self._content: typing.Literal[False] | bytes | None = False
         self._content_consumed: bool = False
@@ -894,6 +904,54 @@ def __init__(self) -> None:
         #: is a response.
         self.request: PreparedRequest | None = None
 
+    @property
+    def lazy(self) -> bool:
+        """
+        Determine whether this response is still an unreceived placeholder.
+        Only meaningful when the request was sent through a multiplexed connection.
+        """
+        return self.raw is None and hasattr(self, "_gather")
+
+    def __getattribute__(self, item):
+        if (
+            item
+            not in [
+                "_gather",
+                "lazy",
+                "request",
+                "_promise",
+                "_resolve_redirect",
+                "__getstate__",
+                "__setstate__",
+                "__enter__",
+                "__exit__",
+            ]
+            and item
+            in Response.__attrs__
+            + [
+                "json",
+                "ok",
+                "links",
+                "content",
+                "text",
+                "iter_lines",
+                "iter_content",
+                "next",
+                "is_redirect",
+                "is_permanent_redirect",
+                "status_code",
+                "cookies",
+                "reason",
+                "encoding",
+                "url",
+                "headers",
+            ]
+            and self.lazy
+        ):
+            self._gather()
+        return super().__getattribute__(item)
+
     def __enter__(self):
         return self
 
@@ -903,6 +961,8 @@ def __exit__(self, *args):
     def __getstate__(self):
         # Consume everything; accessing the content attribute makes
         # sure the content has been fully read.
+        if self.lazy:
+            self._gather()
         if not self._content_consumed:
             self.content
 
@@ -917,16 +977,20 @@ def __setstate__(self, state):
         setattr(self, "raw", None)
 
     def __repr__(self) -> str:
-        if self.http_version is None:
-            return "<Response dormant>"
-
-        http_revision = self.http_version / 10
+        if (
+            self.request is None
+            or self.request.conn_info is None
+            or self.request.conn_info.http_version is None
+        ):
+            return "<Response dormant>"
 
         # HTTP/2.0 is not preferred, cast it to HTTP/2 instead.
-        if http_revision.is_integer():
-            http_revision = int(http_revision)
+        http_revision = self.request.conn_info.http_version.value.replace(".0", "")
 
-        return f"<Response HTTP/{http_revision} [{self.status_code}]>"
+        if self.lazy:
+            return f"<ResponsePromise HTTP/{http_revision}>"
+
+        return f"<Response HTTP/{http_revision} [{self.status_code}]>"
 
     def __bool__(self) -> bool:
         """Returns True if :attr:`status_code` is less than 400.
@@ -1261,11 +1325,11 @@ def http_version(self) -> int | None:
         """
         return self.raw.version if self.raw else None
 
-    def raise_for_status(self) -> None:
+    def raise_for_status(self) -> Response:
         """Raises :class:`HTTPError`, if one occurred."""
 
         if self.status_code is None:
-            return
+            return self
 
         http_error_msg = ""
         if isinstance(self.reason, bytes):
@@ -1293,6 +1357,8 @@
         if http_error_msg:
             raise HTTPError(http_error_msg, response=self)
 
+        return self
+
     def close(self) -> None:
         """Releases the connection back to the pool. Once this method has been
         called the underlying ``raw`` object must not be accessed again.
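Two of the `models.py` changes above are easiest to grasp from the caller's side: `raise_for_status()` now returns the response so calls can be chained, and a plain string passed to `auth=` is promoted to a Bearer token. A minimal sketch against pie.dev, the test host used throughout this patch:

```python
import niquests

# raise_for_status() now returns the Response itself, so the happy path chains.
data = niquests.get("https://pie.dev/get").raise_for_status().json()
print(data["url"])

# A bare string for `auth` is shorthand for a Bearer token. Supplying your own
# prefix (e.g. "Token abc123") keeps it verbatim instead of prepending Bearer.
resp = niquests.get("https://pie.dev/bearer", auth="my-secret-token")
print(resp.status_code)
```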
diff --git a/src/niquests/sessions.py b/src/niquests/sessions.py
index 52c8a7c13d..09c3998de9 100644
--- a/src/niquests/sessions.py
+++ b/src/niquests/sessions.py
@@ -197,6 +197,7 @@ class Session:
         "trust_env",
         "max_redirects",
         "retries",
+        "multiplexed",
     ]
 
     def __init__(
@@ -204,6 +205,9 @@ def __init__(
         *,
         quic_cache_layer: CacheLayerAltSvcType | None = None,
         retries: RetryType = DEFAULT_RETRIES,
+        multiplexed: bool = False,
+        disable_http2: bool = False,
+        disable_http3: bool = False,
     ):
         #: Configured retries for current Session
         self.retries = retries
@@ -233,6 +237,9 @@ def __init__(
         #: Stream response content default.
         self.stream = False
 
+        #: Toggle to enable multiplexed connections (HTTP/2 and HTTP/3 only).
+        self.multiplexed = multiplexed
+
         #: SSL Verification default.
         #: Defaults to `True`, requiring requests to verify the TLS certificate at the
         #: remote end.
@@ -276,7 +283,12 @@ def __init__(
         self.adapters: OrderedDict[str, BaseAdapter] = OrderedDict()
         self.mount(
             "https://",
-            HTTPAdapter(quic_cache_layer=self.quic_cache_layer, max_retries=retries),
+            HTTPAdapter(
+                quic_cache_layer=self.quic_cache_layer,
+                max_retries=retries,
+                disable_http2=disable_http2,
+                disable_http3=disable_http3,
+            ),
         )
         self.mount("http://", HTTPAdapter(max_retries=retries))
 
@@ -958,8 +970,8 @@ def on_post_connection(conn_info: ConnectionInfo) -> None:
             ptr_request.conn_info = conn_info
 
             if (
-                request.url
-                and request.url.startswith("https://")
+                ptr_request.url
+                and ptr_request.url.startswith("https://")
                 and ocsp_verify is not None
                 and kwargs["verify"]
             ):
@@ -985,6 +997,7 @@ def on_post_connection(conn_info: ConnectionInfo) -> None:
         dispatch_hook("pre_send", hooks, ptr_request)  # type: ignore[arg-type]
 
         kwargs.setdefault("on_post_connection", on_post_connection)
+        kwargs.setdefault("multiplexed", self.multiplexed)
 
         assert request.url is not None
 
@@ -997,6 +1010,30 @@ def on_post_connection(conn_info: ConnectionInfo) -> None:
         # Send the request
         r = adapter.send(request, **kwargs)
 
+        # We are leveraging a multiplexed connection
+        if r.raw is None:
+            r._gather = lambda: adapter.gather(r)
+            r._resolve_redirect = lambda x, y: next(self.resolve_redirects(x, y, yield_requests=True, **kwargs), None)  # type: ignore[assignment, arg-type]
+
+            # In multiplexed mode, we cannot safely forward this local function.
+            kwargs["on_post_connection"] = None
+
+            # We intentionally use the 'niquests' prefix; urllib3.future has its own parameters.
+            r._promise.set_parameter("niquests_is_stream", stream)
+            r._promise.set_parameter("niquests_start", start)
+            r._promise.set_parameter("niquests_hooks", hooks)
+            r._promise.set_parameter("niquests_cookies", self.cookies)
+            r._promise.set_parameter("niquests_allow_redirect", allow_redirects)
+            r._promise.set_parameter("niquests_kwargs", kwargs)
+            r._promise.set_parameter("niquests_session_constructor", Session)
+
+            # Redirect bookkeeping lives in the promise context because, in
+            # multiplexed mode, we are not fully aware of the hop/redirect count.
+            r._promise.set_parameter("niquests_redirect_count", 0)
+            r._promise.set_parameter("niquests_max_redirects", self.max_redirects)
+
+            return r
+
         # Total elapsed time of the request (approximately)
         elapsed = preferred_clock() - start
         r.elapsed = timedelta(seconds=elapsed)
@@ -1050,6 +1087,18 @@ def on_post_connection(conn_info: ConnectionInfo) -> None:
 
         return r
 
+    def gather(self, *responses: Response) -> None:
+        """
+        Call this method to make sure in-flight responses are retrieved efficiently. This is a no-op
+        if multiplexed is set to False (the default).
+        Passing a limited set of responses will resolve the given promises and leave the others for later.
+        """
+        if self.multiplexed is False:
+            return
+
+        for adapter in self.adapters.values():
+            adapter.gather(*responses)
+
     def merge_environment_settings(
         self,
         url: str,
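The new `Session.gather()` accepts an optional subset of responses, resolving only those promises. A sketch of that selective form, assuming the semantics stated in the docstring:

```python
from niquests import Session

with Session(multiplexed=True) as s:
    fast = s.get("https://pie.dev/delay/1")
    slow = s.get("https://pie.dev/delay/3")

    # Resolve only `fast`; `slow` stays an in-flight promise for now.
    s.gather(fast)
    assert fast.lazy is False

    # A bare gather() resolves everything still pending.
    s.gather()
    print(fast.status_code, slow.status_code)  # 200 200
```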
+        """
diff --git a/tests/conftest.py b/tests/conftest.py
index 125bd64c32..24ca249d67 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,6 +6,7 @@
     from BaseHTTPServer import HTTPServer
     from SimpleHTTPServer import SimpleHTTPRequestHandler
 
+import socket
 import ssl
 import threading
 from urllib.parse import urljoin
@@ -57,3 +58,25 @@ def nosan_server(tmp_path_factory):
 
     server.shutdown()
     server_thread.join()
+
+
+_WAN_AVAILABLE = None
+
+
+@pytest.fixture(scope="function")
+def requires_wan() -> None:
+    global _WAN_AVAILABLE
+
+    if _WAN_AVAILABLE is not None:
+        if _WAN_AVAILABLE is False:
+            pytest.skip("Test requires WAN access to pie.dev")
+        return
+
+    try:
+        sock = socket.create_connection(("pie.dev", 443), timeout=1)
+    except (ConnectionRefusedError, socket.gaierror, socket.timeout):
+        _WAN_AVAILABLE = False
+        pytest.skip("Test requires WAN access to pie.dev")
+    else:
+        _WAN_AVAILABLE = True
+        sock.close()
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000000..23eba17436
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from niquests import AsyncSession
+
+
+@pytest.mark.usefixtures("requires_wan")
+@pytest.mark.asyncio
+class TestAsyncWithoutMultiplex:
+    async def test_awaitable_get(self):
+        async with AsyncSession() as s:
+            resp = await s.get("https://pie.dev/get")
+
+            assert resp.lazy is False
+            assert resp.status_code == 200
+
+    async def test_concurrent_task_get(self):
+        async def emit():
+            responses = []
+
+            async with AsyncSession() as s:
+                responses.append(await s.get("https://pie.dev/get"))
+                responses.append(await s.get("https://pie.dev/delay/5"))
+
+            return responses
+
+        foo = asyncio.create_task(emit())
+        bar = asyncio.create_task(emit())
+
+        responses_foo = await foo
+        responses_bar = await bar
+
+        assert len(responses_foo) == 2
+        assert len(responses_bar) == 2
+
+        assert all(r.status_code == 200 for r in responses_foo + responses_bar)
+
+
+@pytest.mark.usefixtures("requires_wan")
+@pytest.mark.asyncio
+class TestAsyncWithMultiplex:
+    async def test_awaitable_get(self):
+        async with AsyncSession(multiplexed=True) as s:
+            resp = await s.get("https://pie.dev/get")
+
+            assert resp.lazy is True
+            await s.gather()
+
+            assert resp.status_code == 200
+
+    async def test_awaitable_get_direct_access_lazy(self):
+        async with AsyncSession(multiplexed=True) as s:
+            resp = await s.get("https://pie.dev/get")
+
+            assert resp.lazy is True
+            assert resp.status_code == 200
+
+    async def test_concurrent_task_get(self):
+        async def emit():
+            responses = []
+
+            async with AsyncSession(multiplexed=True) as s:
+                responses.append(await s.get("https://pie.dev/get"))
+                responses.append(await s.get("https://pie.dev/delay/5"))
+
+                await s.gather()
+
+            return responses
+
+        foo = asyncio.create_task(emit())
+        bar = asyncio.create_task(emit())
+
+        responses_foo = await foo
+        responses_bar = await bar
+
+        assert len(responses_foo) == 2
+        assert len(responses_bar) == 2
+
+        assert all(r.status_code == 200 for r in responses_foo + responses_bar)
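These async tests leave `AsyncSession.no_thread` unexercised. A hedged sketch of how that knob would be toggled, assuming it is a plain instance attribute as the changelog wording suggests:

```python
import asyncio

from niquests import AsyncSession


async def main() -> None:
    async with AsyncSession() as s:
        # Assumption: `no_thread` is a plain instance attribute; setting it
        # opts out of the session's underlying thread pool.
        s.no_thread = True
        resp = await s.get("https://pie.dev/get")
        print(resp.status_code)


asyncio.run(main())
```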
diff --git a/tests/test_multiplexed.py b/tests/test_multiplexed.py
new file mode 100644
index 0000000000..2270a794ff
--- /dev/null
+++ b/tests/test_multiplexed.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import json
+
+import pytest
+
+from niquests import Session
+
+
+@pytest.mark.usefixtures("requires_wan")
+class TestMultiplexed:
+    def test_concurrent_request_in_sync(self):
+        responses = []
+
+        with Session(multiplexed=True) as s:
+            responses.append(s.get("https://pie.dev/delay/3"))
+            responses.append(s.get("https://pie.dev/delay/1"))
+            responses.append(s.get("https://pie.dev/delay/1"))
+            responses.append(s.get("https://pie.dev/delay/3"))
+
+            assert all(r.lazy for r in responses)
+
+            s.gather()
+
+        assert all(r.lazy is False for r in responses)
+        assert all(r.status_code == 200 for r in responses)
+
+    def test_redirect_with_multiplexed(self):
+        with Session(multiplexed=True) as s:
+            resp = s.get("https://pie.dev/redirect/3")
+            assert resp.lazy
+            s.gather()
+
+        assert resp.status_code == 200
+        assert resp.url == "https://pie.dev/get"
+        assert len(resp.history) == 3
+
+    def test_lazy_access_sync_mode(self):
+        with Session(multiplexed=True) as s:
+            resp = s.get("https://pie.dev/headers")
+            assert resp.lazy
+
+        assert resp.status_code == 200
+
+    def test_post_data_with_multiplexed(self):
+        responses = []
+
+        with Session(multiplexed=True) as s:
+            for _ in range(5):
+                responses.append(
+                    s.post(
+                        "https://pie.dev/post",
+                        data=b"foo" * 128,
+                    )
+                )
+
+            s.gather()
+
+        assert all(r.lazy is False for r in responses)
+        assert all(r.status_code == 200 for r in responses)
+        assert all(r.json()["data"] == "foo" * 128 for r in responses)
+
+    def test_get_stream_with_multiplexed(self):
+        with Session(multiplexed=True) as s:
+            resp = s.get("https://pie.dev/headers", stream=True)
+            assert resp.lazy
+
+        assert resp.status_code == 200
+        assert resp._content_consumed is False
+
+        payload = b""
+
+        for chunk in resp.iter_content(32):
+            payload += chunk
+
+        assert resp._content_consumed is True
+
+        assert isinstance(json.loads(payload), dict)
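The streaming test above translates directly into application code; note that the first attribute access outside the session block is what resolves the promise. A sketch:

```python
import json

from niquests import Session

with Session(multiplexed=True) as s:
    resp = s.get("https://pie.dev/headers", stream=True)
    assert resp.lazy  # nothing has been received yet

# The first attribute access transparently resolves the promise.
resp.raise_for_status()

body = b"".join(resp.iter_content(32))
print(json.loads(body))
```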