
Revert "Start moving to niquests"
It is still quite raw: jawah/niquests#182

This reverts commit d51f5ae.
vrslev committed Nov 21, 2024
1 parent d51f5ae commit 95a1913
Showing 9 changed files with 73 additions and 96 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -152,7 +152,7 @@ import any_llm_client
 
 async with any_llm_client.get_client(
     ...,
-    httpx_client=niquests.AsyncSession(
+    httpx_client=httpx.AsyncClient(
         mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
         timeout=httpx.Timeout(None, connect=5.0),
     ),
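Note: the restored README example relies on httpx's transport mounts. As a minimal self-contained sketch of what that configuration does (using only the values from the snippet above; the GET URL is an illustrative placeholder), requests whose URL matches the mounted prefix are routed through the proxy transport, while everything else uses the default transport:

import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(
        # Only api.openai.com traffic goes through the local proxy.
        mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
        # No total timeout, but fail fast (5s) when connecting.
        timeout=httpx.Timeout(None, connect=5.0),
    ) as client:
        response = await client.get("https://api.openai.com/v1/models")
        print(response.status_code)


asyncio.run(main())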
52 changes: 20 additions & 32 deletions any_llm_client/clients/openai.py
@@ -5,8 +5,8 @@
 from http import HTTPStatus
 
 import annotated_types
+import httpx
 import httpx_sse
-import niquests
 import pydantic
 import typing_extensions
 
@@ -92,27 +92,25 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class OpenAIClient(LLMClient):
     config: OpenAIConfig
-    httpx_client: niquests.AsyncSession
+    httpx_client: httpx.AsyncClient
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: OpenAIConfig,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or niquests.AsyncSession()
+        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
-        return self.httpx_client.prepare_request(
-            niquests.Request(
-                method="POST",
-                url=str(self.config.url),
-                json=payload,
-                headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
-            )
-        )
+    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
+        return self.httpx_client.build_request(
+            method="POST",
+            url=str(self.config.url),
+            json=payload,
+            headers={"Authorization": f"Bearer {self.config.auth_token}"} if self.config.auth_token else None,
+        )
 
     def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletionsMessage]:
@@ -139,21 +139,14 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
         try:
-            return (
-                ChatCompletionsNotStreamingResponse.model_validate_json(response.content)  # type: ignore[arg-type]
-                .choices[0]
-                .message.content
-            )
+            return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
         finally:
             await response.aclose()
 
-    async def _iter_partial_responses(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
+    async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
         text_chunks: typing.Final = []
         async for event in httpx_sse.EventSource(response).aiter_sse():
             if event.data == "[DONE]":
@@ -181,16 +172,13 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_partial_responses(response)
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                content: typing.Final = exception.response.content
-                exception.response.close()
-                _handle_status_error(status_code=exception.response.status_code, content=content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            content: typing.Final = await exception.response.aread()
+            await exception.response.aclose()
+            _handle_status_error(status_code=exception.response.status_code, content=content)
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aenter__()
         return self
 
     async def __aexit__(
@@ -199,4 +187,4 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
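The revert also restores httpx's stricter error contract: raise_for_status() raises httpx.HTTPStatusError, whose .response attribute is always populated, so the defensive `if exception.response and ...` checks needed for niquests.HTTPError go away. A minimal sketch of that contract (MockTransport and the example.test URL are illustrative stand-ins, not part of this repo):

import asyncio

import httpx


async def main() -> None:
    # A transport that always answers 400, standing in for a real server.
    transport = httpx.MockTransport(lambda _: httpx.Response(400, content=b"bad request"))
    async with httpx.AsyncClient(transport=transport) as client:
        response = await client.get("https://example.test")
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exception:
            # .response is always set on HTTPStatusError, so status_code and
            # content can be read without the None checks niquests required.
            print(exception.response.status_code, exception.response.content)


asyncio.run(main())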
48 changes: 21 additions & 27 deletions any_llm_client/clients/yandexgpt.py
@@ -5,7 +5,7 @@
 from http import HTTPStatus
 
 import annotated_types
-import niquests
+import httpx
 import pydantic
 import typing_extensions
 
@@ -64,27 +64,25 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
 @dataclasses.dataclass(slots=True, init=False)
 class YandexGPTClient(LLMClient):
     config: YandexGPTConfig
-    httpx_client: niquests.AsyncSession
+    httpx_client: httpx.AsyncClient
     request_retry: RequestRetryConfig
 
     def __init__(
         self,
         config: YandexGPTConfig,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> None:
         self.config = config
-        self.httpx_client = httpx_client or niquests.AsyncSession()
+        self.httpx_client = httpx_client or httpx.AsyncClient()
         self.request_retry = request_retry or RequestRetryConfig()
 
-    def _build_request(self, payload: dict[str, typing.Any]) -> niquests.PreparedRequest:
-        return self.httpx_client.prepare_request(
-            niquests.Request(
-                method="POST",
-                url=str(self.config.url),
-                json=payload,
-                headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
-            )
-        )
+    def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
+        return self.httpx_client.build_request(
+            method="POST",
+            url=str(self.config.url),
+            json=payload,
+            headers={"Authorization": self.config.auth_header, "x-data-logging-enabled": "false"},
+        )
 
     def _prepare_payload(
@@ -108,16 +106,13 @@ async def request_llm_message(self, messages: str | list[Message], temperature:
                 request_retry=self.request_retry,
                 build_request=lambda: self._build_request(payload),
             )
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
-            else:
-                raise
+        except httpx.HTTPStatusError as exception:
+            _handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
 
-        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text  # type: ignore[arg-type]
+        return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text
 
-    async def _iter_completion_messages(self, response: niquests.AsyncResponse) -> typing.AsyncIterable[str]:
-        async for one_line in response.iter_lines():
+    async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
+        async for one_line in response.aiter_lines():
             validated_response = YandexGPTResponse.model_validate_json(one_line)
             yield validated_response.result.alternatives[0].message.text
 
@@ -134,14 +129,13 @@ async def stream_llm_partial_messages(
                 build_request=lambda: self._build_request(payload),
             ) as response:
                 yield self._iter_completion_messages(response)
-        except niquests.HTTPError as exception:
-            if exception.response and exception.response.status_code and exception.response.content:
-                content: typing.Final = exception.response.content
-                exception.response.close()
-                _handle_status_error(status_code=exception.response.status_code, content=content)
+        except httpx.HTTPStatusError as exception:
+            content: typing.Final = await exception.response.aread()
+            await exception.response.aclose()
+            _handle_status_error(status_code=exception.response.status_code, content=content)
 
     async def __aenter__(self) -> typing_extensions.Self:
-        await self.httpx_client.__aenter__()  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aenter__()
         return self
 
     async def __aexit__(
@@ -150,4 +144,4 @@ async def __aexit__(
         exc_value: BaseException | None,
         traceback: types.TracebackType | None,
     ) -> None:
-        await self.httpx_client.__aexit__(exc_type, exc_value, traceback)  # type: ignore[no-untyped-call]
+        await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
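For reference, the streaming path restored above consumes newline-delimited JSON via httpx's Response.aiter_lines(). A self-contained sketch of the same pattern (the payload shape and URL are invented for illustration, not the real YandexGPT response schema):

import asyncio
import json

import httpx


def handler(_: httpx.Request) -> httpx.Response:
    # Two NDJSON chunks standing in for a streaming completion body.
    return httpx.Response(200, content=b'{"text": "Hel"}\n{"text": "Hello!"}\n')


async def main() -> None:
    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
        async with client.stream("POST", "https://example.test") as response:
            async for one_line in response.aiter_lines():
                print(json.loads(one_line)["text"])  # partial text, then full text


asyncio.run(main())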
24 changes: 11 additions & 13 deletions any_llm_client/http.py
@@ -3,20 +3,19 @@
 import typing
 
 import httpx
-import niquests
 import stamina
 
 from any_llm_client.retry import RequestRetryConfig
 
 
 async def make_http_request(
     *,
-    httpx_client: niquests.AsyncSession,
+    httpx_client: httpx.AsyncClient,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], niquests.PreparedRequest],
-) -> niquests.Response:
-    @stamina.retry(on=niquests.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> niquests.Response:
+    build_request: typing.Callable[[], httpx.Request],
+) -> httpx.Response:
+    @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
+    async def make_request_with_retries() -> httpx.Response:
         response: typing.Final = await httpx_client.send(build_request())
         response.raise_for_status()
         return response
@@ -27,19 +26,18 @@ async def make_request_with_retries() -> niquests.Response:
 @contextlib.asynccontextmanager
 async def make_streaming_http_request(
     *,
-    httpx_client: niquests.AsyncSession,
+    httpx_client: httpx.AsyncClient,
     request_retry: RequestRetryConfig,
-    build_request: typing.Callable[[], niquests.PreparedRequest],
-) -> typing.AsyncIterator[niquests.AsyncResponse]:
+    build_request: typing.Callable[[], httpx.Request],
+) -> typing.AsyncIterator[httpx.Response]:
     @stamina.retry(on=httpx.HTTPError, **dataclasses.asdict(request_retry))
-    async def make_request_with_retries() -> niquests.AsyncResponse:
+    async def make_request_with_retries() -> httpx.Response:
         response: typing.Final = await httpx_client.send(build_request(), stream=True)
         response.raise_for_status()
-        return response  # type: ignore[return-value]
+        return response
 
     response: typing.Final = await make_request_with_retries()
     try:
-        response.__aenter__()
         yield response
     finally:
-        await response.close()
+        await response.aclose()
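Both helpers lean on stamina for retries: the decorated inner function is re-run on any httpx.HTTPError, and each attempt rebuilds the request via the build_request callable so every send gets a fresh request object. A standalone sketch of the same wiring, with a hard-coded attempts=3 standing in for the unpacked RequestRetryConfig fields:

import httpx
import stamina


async def fetch_with_retries(httpx_client: httpx.AsyncClient, url: str) -> httpx.Response:
    @stamina.retry(on=httpx.HTTPError, attempts=3)
    async def make_request_with_retries() -> httpx.Response:
        # A fresh request object per attempt, mirroring build_request() above.
        response = await httpx_client.send(httpx_client.build_request("GET", url))
        response.raise_for_status()  # 4xx/5xx -> httpx.HTTPStatusError -> retried
        return response

    return await make_request_with_retries()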
12 changes: 6 additions & 6 deletions any_llm_client/main.py
@@ -1,7 +1,7 @@
 import functools
 import typing
 
-import niquests
+import httpx
 
 from any_llm_client.clients.mock import MockLLMClient, MockLLMConfig
 from any_llm_client.clients.openai import OpenAIClient, OpenAIConfig
@@ -18,7 +18,7 @@
     def get_client(
         config: AnyLLMConfig,
         *,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient: ...  # pragma: no cover
 else:
@@ -27,7 +27,7 @@ def get_client(
     def get_client(
         config: typing.Any,  # noqa: ANN401, ARG001
         *,
-        httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
+        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
     ) -> LLMClient:
         raise AssertionError("unknown LLM config type")
@@ -36,7 +36,7 @@ def get_client(
     def _(
         config: YandexGPTConfig,
         *,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient:
         return YandexGPTClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -45,7 +45,7 @@ def _(
     def _(
         config: OpenAIConfig,
         *,
-        httpx_client: niquests.AsyncSession | None = None,
+        httpx_client: httpx.AsyncClient | None = None,
         request_retry: RequestRetryConfig | None = None,
     ) -> LLMClient:
         return OpenAIClient(config=config, httpx_client=httpx_client, request_retry=request_retry)
@@ -54,7 +54,7 @@ def _(
     def _(
         config: MockLLMConfig,
         *,
-        httpx_client: niquests.AsyncSession | None = None,  # noqa: ARG001
+        httpx_client: httpx.AsyncClient | None = None,  # noqa: ARG001
         request_retry: RequestRetryConfig | None = None,  # noqa: ARG001
     ) -> LLMClient:
         return MockLLMClient(config=config)
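The registration pattern above (repeated `def _(config: ...)` stubs alongside `import functools`) looks like functools.singledispatch overloads keyed on the config type; under that assumption, get_client dispatches to the matching client class at runtime. A minimal sketch of that stdlib mechanism, with illustrative `describe`/`int`/`str` names:

import functools


@functools.singledispatch
def describe(config: object) -> str:
    raise AssertionError("unknown config type")


@describe.register
def _(config: int) -> str:
    return f"int config: {config}"


@describe.register
def _(config: str) -> str:
    return f"str config: {config}"


print(describe(42))     # int config: 42
print(describe("gpt"))  # str config: gpt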
1 change: 0 additions & 1 deletion pyproject.toml
@@ -20,7 +20,6 @@ dependencies = [
     "httpx>=0.27.2",
     "pydantic>=2.9.2",
     "stamina>=24.3.0",
-    "niquests>=3.11.0",
 ]
 dynamic = ["version"]
 
13 changes: 6 additions & 7 deletions tests/test_openai_client.py
@@ -3,7 +3,6 @@
 
 import faker
 import httpx
-import niquests
 import pydantic
 import pytest
 from polyfactory.factories.pydantic_factory import ModelFactory
@@ -37,7 +36,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
 
         result: typing.Final = await any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         ).request_llm_message(**LLMFuncRequestFactory.build())
 
         assert result == expected_result
@@ -49,7 +48,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -92,7 +91,7 @@ async def test_ok(self, faker: faker.Faker) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             config,
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         result: typing.Final = await consume_llm_partial_responses(client.stream_llm_partial_messages(**func_request))
@@ -108,7 +107,7 @@ async def test_fails_without_alternatives(self) -> None:
         )
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         with pytest.raises(pydantic.ValidationError):
@@ -121,7 +120,7 @@ class TestOpenAILLMErrors:
     async def test_fails_with_unknown_error(self, stream: bool, status_code: int) -> None:
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: httpx.Response(status_code))),
         )
 
         coroutine: typing.Final = (
@@ -146,7 +145,7 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes
         response: typing.Final = httpx.Response(400, content=content)
         client: typing.Final = any_llm_client.get_client(
             OpenAIConfigFactory.build(),
-            httpx_client=niquests.AsyncSession(transport=httpx.MockTransport(lambda _: response)),
+            httpx_client=httpx.AsyncClient(transport=httpx.MockTransport(lambda _: response)),
         )
 
         coroutine: typing.Final = (
4 changes: 2 additions & 2 deletions tests/test_static.py
@@ -27,12 +27,12 @@ def test_llm_error_str(faker: faker.Faker) -> None:
 
 
 def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
-    all_objects: typing.Final = (
+    all_objects = (
         any_llm_client.LLMClient.request_llm_message,
         any_llm_client.LLMClient.stream_llm_partial_messages,
         LLMFuncRequest,
     )
-    all_annotations: typing.Final = [typing.get_type_hints(one_object) for one_object in all_objects]
+    all_annotations = [typing.get_type_hints(one_object) for one_object in all_objects]
 
     for one_ignored_prop in ("return",):
         for annotations in all_annotations: