From 161f1e59a9cb536d992c590fc2f8482116c4c523 Mon Sep 17 00:00:00 2001
From: Robert Craigie <robert@craigie.dev>
Date: Mon, 19 Aug 2024 16:20:50 -0400
Subject: [PATCH] fix(json schema): remove `None` defaults

---
 src/openai/lib/_pydantic.py        |  7 ++++
 tests/lib/chat/test_completions.py | 60 +++++++++++++++++++++++++++++-
 tests/lib/schema_types/query.py    |  3 +-
 tests/lib/test_pydantic.py         |  6 ++-
 4 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/src/openai/lib/_pydantic.py b/src/openai/lib/_pydantic.py
index ad3b6eb29f..f989ce3ed0 100644
--- a/src/openai/lib/_pydantic.py
+++ b/src/openai/lib/_pydantic.py
@@ -5,6 +5,7 @@
 
 import pydantic
 
+from .._types import NOT_GIVEN
 from .._utils import is_dict as _is_dict, is_list
 from .._compat import model_json_schema
 
@@ -76,6 +77,12 @@ def _ensure_strict_json_schema(
                 for i, entry in enumerate(all_of)
             ]
 
+    # strip `None` defaults as there's no meaningful distinction here
+    # the schema will still be `nullable` and the model will default
+    # to using `None` anyway
+    if json_schema.get("default", NOT_GIVEN) is None:
+        json_schema.pop("default")
+
     # we can't use `$ref`s if there are also other properties defined, e.g.
     # `{"$ref": "...", "description": "my description"}`
     #
diff --git a/tests/lib/chat/test_completions.py b/tests/lib/chat/test_completions.py
index f003866653..aea449b097 100644
--- a/tests/lib/chat/test_completions.py
+++ b/tests/lib/chat/test_completions.py
@@ -3,7 +3,7 @@
 import os
 import json
 from enum import Enum
-from typing import Any, Callable
+from typing import Any, Callable, Optional
 from typing_extensions import Literal, TypeVar
 
 import httpx
@@ -135,6 +135,63 @@ class Location(BaseModel):
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_parse_pydantic_model_optional_default(
+    client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    class Location(BaseModel):
+        city: str
+        temperature: float
+        units: Optional[Literal["c", "f"]] = None
+
+    completion = _make_snapshot_request(
+        lambda c: c.beta.chat.completions.parse(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in SF?",
+                },
+            ],
+            response_format=Location,
+        ),
+        content_snapshot=snapshot(
+            '{"id": "chatcmpl-9y39Q2jGzWmeEZlm5CoNVOuQzcxP4", "object": "chat.completion", "created": 1724098820, "model": "gpt-4o-2024-08-06", "choices": [{"index": 0, "message": {"role": "assistant", "content": "{\\"city\\":\\"San Francisco\\",\\"temperature\\":62,\\"units\\":\\"f\\"}", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 17, "completion_tokens": 14, "total_tokens": 31}, "system_fingerprint": "fp_2a322c9ffc"}'
+        ),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(completion, monkeypatch) == snapshot(
+        """\
+ParsedChatCompletion[Location](
+    choices=[
+        ParsedChoice[Location](
+            finish_reason='stop',
+            index=0,
+            logprobs=None,
+            message=ParsedChatCompletionMessage[Location](
+                content='{"city":"San Francisco","temperature":62,"units":"f"}',
+                function_call=None,
+                parsed=Location(city='San Francisco', temperature=62.0, units='f'),
+                refusal=None,
+                role='assistant',
+                tool_calls=[]
+            )
+        )
+    ],
+    created=1724098820,
+    id='chatcmpl-9y39Q2jGzWmeEZlm5CoNVOuQzcxP4',
+    model='gpt-4o-2024-08-06',
+    object='chat.completion',
+    service_tier=None,
+    system_fingerprint='fp_2a322c9ffc',
+    usage=CompletionUsage(completion_tokens=14, prompt_tokens=17, total_tokens=31)
+)
+"""
+    )
+
+
 @pytest.mark.respx(base_url=base_url)
 def test_parse_pydantic_model_enum(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
     class Color(Enum):
@@ -320,6 +377,7 @@ def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, m
                                 value=DynamicValue(column_name='expected_delivery_date')
                             )
                         ],
+                        name=None,
                         order_by=<OrderBy.asc: 'asc'>,
                         table_name=<Table.orders: 'orders'>
                     )
diff --git a/tests/lib/schema_types/query.py b/tests/lib/schema_types/query.py
index d2284424f0..03439fb17f 100644
--- a/tests/lib/schema_types/query.py
+++ b/tests/lib/schema_types/query.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Union
+from typing import List, Union, Optional
 
 from pydantic import BaseModel
 
@@ -45,6 +45,7 @@ class Condition(BaseModel):
 
 
 class Query(BaseModel):
+    name: Optional[str] = None
     table_name: Table
     columns: List[Column]
     conditions: List[Condition]
diff --git a/tests/lib/test_pydantic.py b/tests/lib/test_pydantic.py
index 531a89df58..99b9e96d21 100644
--- a/tests/lib/test_pydantic.py
+++ b/tests/lib/test_pydantic.py
@@ -62,6 +62,7 @@ def test_most_types() -> None:
                         "Table": {"enum": ["orders", "customers", "products"], "title": "Table", "type": "string"},
                     },
                     "properties": {
+                        "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"},
                         "table_name": {"$ref": "#/$defs/Table"},
                         "columns": {
                             "items": {"$ref": "#/$defs/Column"},
@@ -75,7 +76,7 @@ def test_most_types() -> None:
                         },
                         "order_by": {"$ref": "#/$defs/OrderBy"},
                     },
-                    "required": ["table_name", "columns", "conditions", "order_by"],
+                    "required": ["name", "table_name", "columns", "conditions", "order_by"],
                     "title": "Query",
                     "type": "object",
                     "additionalProperties": False,
@@ -91,6 +92,7 @@ def test_most_types() -> None:
                     "title": "Query",
                     "type": "object",
                     "properties": {
+                        "name": {"title": "Name", "type": "string"},
                         "table_name": {"$ref": "#/definitions/Table"},
                         "columns": {"type": "array", "items": {"$ref": "#/definitions/Column"}},
                         "conditions": {
@@ -100,7 +102,7 @@ def test_most_types() -> None:
                         },
                         "order_by": {"$ref": "#/definitions/OrderBy"},
                     },
-                    "required": ["table_name", "columns", "conditions", "order_by"],
+                    "required": ["name", "table_name", "columns", "conditions", "order_by"],
                     "definitions": {
                         "Table": {
                             "title": "Table",