Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better typing (response class and issue with .get, .post, ...) #490

Merged
merged 8 commits into from
Feb 5, 2025
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Union,
cast,
)
from typing_extensions import TypeVar, Generic
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make the typing_extensions an optional dependency.

Copy link
Contributor Author

@novitae novitae Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TypeVar from typing only supports the default argument since python 3.12. Without this argument, the response class hint is broken, that's why I was importing it via typing_extension. I changed that.

from urllib.parse import urlparse

from ..aio import AsyncCurl
Expand All @@ -40,6 +41,8 @@
with suppress(ImportError):
import eventlet.tpool

R = TypeVar('R', bound=Response, default=Response)

if TYPE_CHECKING:
from typing_extensions import Unpack

Expand All @@ -50,7 +53,7 @@ class ProxySpec(TypedDict, total=False):
ws: str
wss: str

class BaseSessionParams(TypedDict, total=False):
class BaseSessionParams(Generic[R], TypedDict, total=False):
headers: Optional[HeaderTypes]
cookies: Optional[CookieTypes]
auth: Optional[Tuple[str, str]]
Expand All @@ -76,7 +79,7 @@ class BaseSessionParams(TypedDict, total=False):
debug: bool
interface: Optional[str]
cert: Optional[Union[str, Tuple[str, str]]]
response_class: Optional[Type[Response]]
response_class: Optional[Type[R]]

class StreamRequestParams(TypedDict, total=False):
params: Optional[Union[Dict, List, Tuple]]
Expand Down Expand Up @@ -109,7 +112,7 @@ class StreamRequestParams(TypedDict, total=False):
max_recv_speed: int
multipart: Optional[CurlMime]

class RequestParams(StreamRequestParams):
class RequestParams(StreamRequestParams, total=False):
stream: Optional[bool]

else:
Expand Down Expand Up @@ -146,7 +149,7 @@ def _peek_aio_queue(q: asyncio.Queue, default=None):
return default


class BaseSession:
class BaseSession(Generic[R]):
"""Provide common methods for setting curl options and reading info in sessions."""

def __init__(
Expand Down Expand Up @@ -177,7 +180,7 @@ def __init__(
debug: bool = False,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
response_class: Optional[Type[Response]] = None,
response_class: Optional[Type[R]] = None,
):
self.headers = Headers(headers)
self._cookies = Cookies(cookies) # guarded by @property
Expand All @@ -202,14 +205,12 @@ def __init__(
self.interface = interface
self.cert = cert

if response_class is None:
response_class = Response
elif not issubclass(response_class, Response):
if response_class is not None and issubclass(response_class, Response) is False:
raise TypeError(
"`response_class` must be a subclass of `curl_cffi.requests.models.Response`"
f" not of type `{response_class}`"
)
self.response_class = response_class
self.response_class = response_class or Response

if proxy and proxies:
raise TypeError("Cannot specify both 'proxy' and 'proxies'")
Expand All @@ -223,9 +224,9 @@ def __init__(

self._closed = False

def _parse_response(self, curl, buffer, header_buffer, default_encoding):
def _parse_response(self, curl, buffer, header_buffer, default_encoding) -> R:
c = curl
rsp = self.response_class(c)
rsp = cast(R, self.response_class(c))
rsp.url = cast(bytes, c.getinfo(CurlInfo.EFFECTIVE_URL)).decode()
if buffer:
rsp.content = buffer.getvalue()
Expand All @@ -235,7 +236,7 @@ def _parse_response(self, curl, buffer, header_buffer, default_encoding):
header_lines = header_buffer.getvalue().splitlines()

# TODO: history urls
header_list = []
header_list: List[bytes] = []
for header_line in header_lines:
if not header_line.strip():
continue
Expand Down Expand Up @@ -285,8 +286,7 @@ def cookies(self, cookies: CookieTypes) -> None:
# This ensures that the cookies property is always converted to Cookies.
self._cookies = Cookies(cookies)


class Session(BaseSession):
class Session(BaseSession[R]):
"""A request session, cookies and connections will be reused. This object is thread-safe,
but it's recommended to use a separate session for each thread."""

Expand All @@ -295,7 +295,7 @@ def __init__(
curl: Optional[Curl] = None,
thread: Optional[ThreadType] = None,
use_thread_local_curl: bool = True,
**kwargs: Unpack[BaseSessionParams],
**kwargs: Unpack[BaseSessionParams[R]],
):
"""
Parameters set in the init method will be overriden by the same parameter in request method.
Expand Down Expand Up @@ -469,7 +469,7 @@ def request(
stream: Optional[bool] = None,
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
) -> Response:
):
"""Send the request, see ``requests.request`` for details on parameters."""

self._check_session_closed()
Expand Down Expand Up @@ -609,7 +609,7 @@ def query(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="QUERY", url=url, **kwargs)


class AsyncSession(BaseSession):
class AsyncSession(BaseSession[R]):
"""An async request session, cookies and connections will be reused."""

def __init__(
Expand All @@ -618,7 +618,7 @@ def __init__(
loop=None,
async_curl: Optional[AsyncCurl] = None,
max_clients: int = 10,
**kwargs: Unpack[BaseSessionParams],
**kwargs: Unpack[BaseSessionParams[R]],
):
"""
Parameters set in the init method will be override by the same parameter in request method.
Expand Down Expand Up @@ -848,7 +848,11 @@ async def ws_connect(
curl.setopt(CurlOpt.CONNECT_ONLY, 2) # https://curl.se/docs/websocket.html

await self.loop.run_in_executor(None, curl.perform)
return AsyncWebSocket(self, curl, autoclose=autoclose)
return AsyncWebSocket(
cast(AsyncSession[Response], self),
curl,
autoclose=autoclose,
)

async def request(
self,
Expand Down
Loading