From aa2e25f9a4a632357051397ea34d269eafba026d Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com>
Date: Fri, 4 Oct 2024 14:39:22 +0000
Subject: [PATCH] chore(internal): add support for parsing bool response content (#1774)

---
 src/openai/_legacy_response.py |  3 ++
 src/openai/_response.py        |  3 ++
 tests/test_legacy_response.py  | 25 +++++++++++++++++
 tests/test_response.py         | 50 ++++++++++++++++++++++++++++++++++
 4 files changed, 81 insertions(+)

diff --git a/src/openai/_legacy_response.py b/src/openai/_legacy_response.py
index c7dbd54e23..5260e90bc1 100644
--- a/src/openai/_legacy_response.py
+++ b/src/openai/_legacy_response.py
@@ -258,6 +258,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
         if cast_to == float:
             return cast(R, float(response.text))
 
+        if cast_to == bool:
+            return cast(R, response.text.lower() == "true")
+
         origin = get_origin(cast_to) or cast_to
 
         if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent):
diff --git a/src/openai/_response.py b/src/openai/_response.py
index 20ce69ac8a..eac3fbae6c 100644
--- a/src/openai/_response.py
+++ b/src/openai/_response.py
@@ -192,6 +192,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
         if cast_to == float:
             return cast(R, float(response.text))
 
+        if cast_to == bool:
+            return cast(R, response.text.lower() == "true")
+
         origin = get_origin(cast_to) or cast_to
 
         # handle the legacy binary response case
diff --git a/tests/test_legacy_response.py b/tests/test_legacy_response.py
index f50a77c24d..9da1a80659 100644
--- a/tests/test_legacy_response.py
+++ b/tests/test_legacy_response.py
@@ -34,6 +34,31 @@ def test_response_parse_mismatched_basemodel(client: OpenAI) -> None:
         response.parse(to=PydanticModel)
 
 
+@pytest.mark.parametrize(
+    "content, expected",
+    [
+        ("false", False),
+        ("true", True),
+        ("False", False),
+        ("True", True),
+        ("TrUe", True),
+        ("FalSe", False),
+    ],
+)
+def test_response_parse_bool(client: OpenAI, content: str, expected: bool) -> None:
+    response = LegacyAPIResponse(
+        raw=httpx.Response(200, content=content),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    result = response.parse(to=bool)
+    assert result is expected
+
+
 def test_response_parse_custom_stream(client: OpenAI) -> None:
     response = LegacyAPIResponse(
         raw=httpx.Response(200, content=b"foo"),
diff --git a/tests/test_response.py b/tests/test_response.py
index e1fe332f2f..43f24c150d 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -237,6 +237,56 @@ async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) ->
     assert obj.bar == 2
 
 
+@pytest.mark.parametrize(
+    "content, expected",
+    [
+        ("false", False),
+        ("true", True),
+        ("False", False),
+        ("True", True),
+        ("TrUe", True),
+        ("FalSe", False),
+    ],
+)
+def test_response_parse_bool(client: OpenAI, content: str, expected: bool) -> None:
+    response = APIResponse(
+        raw=httpx.Response(200, content=content),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    result = response.parse(to=bool)
+    assert result is expected
+
+
+@pytest.mark.parametrize(
+    "content, expected",
+    [
+        ("false", False),
+        ("true", True),
+        ("False", False),
+        ("True", True),
+        ("TrUe", True),
+        ("FalSe", False),
+    ],
+)
+async def test_async_response_parse_bool(client: AsyncOpenAI, content: str, expected: bool) -> None:
+    response = AsyncAPIResponse(
+        raw=httpx.Response(200, content=content),
+        client=client,
+        stream=False,
+        stream_cls=None,
+        cast_to=str,
+        options=FinalRequestOptions.construct(method="get", url="/foo"),
+    )
+
+    result = await response.parse(to=bool)
+    assert result is expected
+
+
 class OtherModel(BaseModel):
     a: str
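
For reference, the new branch in _parse() returns True only when the raw response body, lowercased, equals "true"; any other body (for example "false", "0", or an empty string) parses to False. A minimal standalone sketch of that rule follows; parse_bool is a hypothetical helper named here purely for illustration, not part of the SDK, which applies this comparison internally when parse(to=bool) is requested.

# Standalone sketch of the bool-parsing rule added to _parse() above.
# parse_bool is a hypothetical helper for illustration only.
def parse_bool(text: str) -> bool:
    # Case-insensitive match: only "true" (in any casing) yields True.
    return text.lower() == "true"

assert parse_bool("true") is True
assert parse_bool("TrUe") is True
assert parse_bool("FalSe") is False
assert parse_bool("0") is False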