Add pydantic support in response_format #2647

Open · wants to merge 7 commits into base: main · Changes from all commits
30 changes: 29 additions & 1 deletion src/huggingface_hub/_webhooks_payload.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Contains data structures to parse the webhooks payload."""

from typing import List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Union

from .utils import is_pydantic_available

@@ -32,6 +32,34 @@ def __init__(self, *args, **kwargs) -> None:
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)

@classmethod
def model_json_schema(cls, *args, **kwargs) -> Dict[str, Any]:
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)

@classmethod
def schema(cls, *args, **kwargs) -> Dict[str, Any]:
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)

@classmethod
def model_validate_json(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)

@classmethod
def parse_raw(cls, json_data: Union[str, bytes, bytearray], *args, **kwargs) -> "BaseModel":
raise ImportError(
"You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
" should be installed separately. Please run `pip install --upgrade pydantic` and retry."
)


# This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
are not in use anymore. To keep in sync when format is updated in
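
Why this file grows four extra classmethods: the inference clients below import `BaseModel` from this module and probe it with `hasattr(...)` before calling `model_json_schema`/`schema` and `model_validate_json`/`parse_raw`, so the pydantic-less stand-in must expose the same names and fail loudly with an actionable message rather than an `AttributeError`. A minimal sketch of the pattern (`DummyBaseModel` and `_pydantic_missing` are illustrative names, not from this diff):

```py
from typing import Any, Dict


def _pydantic_missing(method: str) -> ImportError:
    # Build a descriptive error pointing at the missing optional dependency.
    return ImportError(
        f"`{method}` requires `pydantic`. This is an optional dependency that"
        " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
    )


class DummyBaseModel:
    """Illustrative stand-in for `pydantic.BaseModel` when pydantic is absent."""

    @classmethod
    def model_json_schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        raise _pydantic_missing("model_json_schema")  # pydantic v2 name

    @classmethod
    def schema(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        raise _pydantic_missing("schema")  # pydantic v1 name
```
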
63 changes: 57 additions & 6 deletions src/huggingface_hub/inference/_client.py
@@ -37,11 +37,12 @@
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Type, Union, overload

from requests import HTTPError
from requests.structures import CaseInsensitiveDict

from huggingface_hub._webhooks_payload import BaseModel
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.errors import BadRequestError, InferenceTimeoutError
from huggingface_hub.inference._common import (
@@ -538,7 +539,7 @@ def chat_completion(
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -590,8 +591,8 @@ def chat_completion(
presence_penalty (`float`, *optional*):
Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
text so far, increasing the model's likelihood to talk about new topics.
response_format ([`ChatCompletionInputGrammarType`], *optional*):
Grammar constraints. Can be either a JSONSchema or a regex.
response_format ([`ChatCompletionInputGrammarType`] or a `pydantic.BaseModel` subclass, *optional*):
Grammar constraints. Can be either a JSONSchema, a regex, or a Pydantic model class.
seed (Optional[`int`], *optional*):
Seed for reproducible control flow. Defaults to None.
stop (Optional[`str`], *optional*):
@@ -820,7 +821,7 @@ def chat_completion(
)
```

Example using response_format:
Example using response_format (dict):
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
@@ -850,7 +851,44 @@ def chat_completion(
>>> response.choices[0].message.content
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
```

Example using response_format (pydantic):
```py
>>> from huggingface_hub import InferenceClient
>>> from pydantic import BaseModel, conint
>>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
>>> messages = [
... {
... "role": "user",
... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
... },
... ]
>>> class ActivitySummary(BaseModel):
... location: str
... activity: str
... animals_seen: conint(ge=1, le=5)
... animals: list[str]
>>> response = client.chat_completion(
... messages=messages,
... response_format=ActivitySummary,
... max_tokens=500,
... )
>>> response.choices[0].message.parsed
ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
```
"""
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
response_model = response_format
# pydantic v2 uses model_json_schema
response_format = ChatCompletionInputGrammarType(
type="json",
value=response_model.model_json_schema()
if hasattr(response_model, "model_json_schema")
else response_model.schema(),
)
else:
response_model = None

model_url = self._resolve_chat_completion_url(model)

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -886,7 +924,20 @@ def chat_completion(
if stream:
return _stream_chat_completion_response(data) # type: ignore[arg-type]

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
if response_model:
for choice in chat_completion_output.choices:
if choice.message.content:
try:
# pydantic v2 uses model_validate_json
choice.message.parsed = (
response_model.model_validate_json(choice.message.content)
if hasattr(response_model, "model_validate_json")
else response_model.parse_raw(choice.message.content)
)
except ValueError:
choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
return chat_completion_output

def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
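
The heart of the sync change is a pair of version shims: `model_json_schema` vs. `.schema()` when building the JSON grammar, and `model_validate_json` vs. `parse_raw` when validating the completion, with `refusal` set when validation raises. A self-contained sketch of that round-trip with no network call (`to_grammar` and `parse_content` are illustrative helper names, not part of this PR; assumes pydantic is installed):

```py
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel


class ActivitySummary(BaseModel):
    location: str
    activity: str
    animals_seen: int


def to_grammar(model: Type[BaseModel]) -> Dict[str, Any]:
    # pydantic v2 exposes model_json_schema(); v1 used .schema()
    schema_fn = getattr(model, "model_json_schema", None) or model.schema
    return {"type": "json", "value": schema_fn()}


def parse_content(model: Type[BaseModel], content: str) -> Optional[BaseModel]:
    try:
        # pydantic v2 exposes model_validate_json(); v1 used .parse_raw()
        if hasattr(model, "model_validate_json"):
            return model.model_validate_json(content)
        return model.parse_raw(content)
    except ValueError:  # pydantic's ValidationError subclasses ValueError
        return None


raw = '{"location": "park", "activity": "bike ride", "animals_seen": 3}'
print(to_grammar(ActivitySummary)["value"]["title"])  # ActivitySummary
print(parse_content(ActivitySummary, raw))            # validated, typed instance
```
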
76 changes: 70 additions & 6 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -24,10 +24,23 @@
import re
import time
import warnings
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload
from typing import (
TYPE_CHECKING,
Any,
AsyncIterable,
Dict,
List,
Literal,
Optional,
Set,
Type,
Union,
overload,
)

from requests.structures import CaseInsensitiveDict

from huggingface_hub._webhooks_payload import BaseModel
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.errors import InferenceTimeoutError
from huggingface_hub.inference._common import (
@@ -574,7 +587,7 @@ async def chat_completion(
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[ChatCompletionInputGrammarType] = None,
response_format: Optional[Union[ChatCompletionInputGrammarType, Type[BaseModel]]] = None,
seed: Optional[int] = None,
stop: Optional[List[str]] = None,
stream_options: Optional[ChatCompletionInputStreamOptions] = None,
@@ -626,8 +639,8 @@ async def chat_completion(
presence_penalty (`float`, *optional*):
Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
text so far, increasing the model's likelihood to talk about new topics.
response_format ([`ChatCompletionInputGrammarType`], *optional*):
Grammar constraints. Can be either a JSONSchema or a regex.
response_format ([`ChatCompletionInputGrammarType`] or a `pydantic.BaseModel` subclass, *optional*):
Grammar constraints. Can be either a JSONSchema, a regex, or a Pydantic model class.
seed (Optional[`int`], *optional*):
Seed for reproducible control flow. Defaults to None.
stop (Optional[`str`], *optional*):
@@ -861,7 +874,7 @@ async def chat_completion(
)
```

Example using response_format:
Example using response_format (dict):
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
@@ -892,7 +905,45 @@ async def chat_completion(
>>> response.choices[0].message.content
'{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
```

Example using response_format (pydantic):
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> from pydantic import BaseModel, conint
>>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
>>> messages = [
... {
... "role": "user",
... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?",
... },
... ]
>>> class ActivitySummary(BaseModel):
... location: str
... activity: str
... animals_seen: conint(ge=1, le=5)
... animals: list[str]
>>> response = await client.chat_completion(
... messages=messages,
... response_format=ActivitySummary,
... max_tokens=500,
... )
>>> response.choices[0].message.parsed
ActivitySummary(location='park', activity='bike ride', animals_seen=3, animals=['puppy', 'cat', 'raccoon'])
```
"""
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
response_model = response_format
# pydantic v2 uses model_json_schema
response_format = ChatCompletionInputGrammarType(
type="json",
value=response_model.model_json_schema()
if hasattr(response_model, "model_json_schema")
else response_model.schema(),
)
else:
response_model = None

model_url = self._resolve_chat_completion_url(model)

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
@@ -928,7 +979,20 @@ async def chat_completion(
if stream:
return _async_stream_chat_completion_response(data) # type: ignore[arg-type]

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]
if response_model:
for choice in chat_completion_output.choices:
if choice.message.content:
try:
# pydantic v2 uses model_validate_json
choice.message.parsed = (
response_model.model_validate_json(choice.message.content)
if hasattr(response_model, "model_validate_json")
else response_model.parse_raw(choice.message.content)
)
except ValueError:
choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
return chat_completion_output

def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
# Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
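
The async path mirrors the sync one line for line. A hedged usage sketch of the new parameter end to end, assuming this PR as written, pydantic v2, network access, and a valid HF token (the model id is the one used in the docstring examples above; `parsed`/`refusal` are the fields added in the last file of this diff):

```py
import asyncio
from typing import List

from pydantic import BaseModel

from huggingface_hub import AsyncInferenceClient


class ActivitySummary(BaseModel):
    location: str
    activity: str
    animals: List[str]


async def main() -> None:
    client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct")
    response = await client.chat_completion(
        messages=[{"role": "user", "content": "I saw a puppy, a cat and a raccoon in the park."}],
        response_format=ActivitySummary,
        max_tokens=500,
    )
    message = response.choices[0].message
    if message.parsed is not None:
        print(message.parsed)   # validated ActivitySummary instance
    else:
        print(message.refusal)  # validation failed; raw text stays in message.content


asyncio.run(main())
```
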
src/huggingface_hub/inference/_generated/types/chat_completion.py
@@ -6,6 +6,8 @@
from dataclasses import dataclass
from typing import Any, List, Literal, Optional, Union

from huggingface_hub._webhooks_payload import BaseModel

from .base import BaseInferenceType


@@ -196,6 +198,8 @@ class ChatCompletionOutputMessage(BaseInferenceType):
role: str
content: Optional[str] = None
tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
parsed: Optional[BaseModel] = None
refusal: Optional[str] = None


@dataclass
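
The two new fields mirror the `parsed`/`refusal` pair familiar from OpenAI's structured-output responses: after a call with a pydantic `response_format`, `parsed` holds the validated instance on success, `refusal` holds a short failure message when validation raised, and `content` always keeps the raw text. A small illustration of the intended consumer pattern (the message below is hand-built stand-in data, not a real API response):

```py
from dataclasses import dataclass
from typing import Optional

from pydantic import BaseModel


class ActivitySummary(BaseModel):
    location: str
    activity: str


@dataclass
class FakeMessage:
    # Hand-built stand-in mirroring the new ChatCompletionOutputMessage fields.
    role: str
    content: Optional[str] = None
    parsed: Optional[BaseModel] = None
    refusal: Optional[str] = None


message = FakeMessage(
    role="assistant",
    content='{"location": "park", "activity": "bike ride"}',
    parsed=ActivitySummary(location="park", activity="bike ride"),
)

if message.parsed is not None:
    # isinstance() narrows the generic BaseModel back to the concrete schema.
    assert isinstance(message.parsed, ActivitySummary)
    print(message.parsed.location)  # park
else:
    print(message.refusal or message.content)
```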