Custom headers/cookies in InferenceClient (#1507)
Wauplin authored Jun 14, 2023
1 parent b4e01b0 commit e8de966
Showing 3 changed files with 64 additions and 8 deletions.
docs/source/guides/inference.mdx (2 changes: 1 addition & 1 deletion)
@@ -198,7 +198,7 @@ You can catch it and handle it in your code:
 Some tasks require binary inputs, for example, when dealing with images or audio files. In this case, [`InferenceClient`]
 tries to be as permissive as possible and accept different types:
 - raw `bytes`
-- a file-like object, opened as binary (`with open("audio.wav", "rb") as f: ...`)
+- a file-like object, opened as binary (`with open("audio.flac", "rb") as f: ...`)
 - a path (`str` or `Path`) pointing to a local file
 - a URL (`str`) pointing to a remote file (e.g. `https://...`). In this case, the file will be downloaded locally before
 sending it to the Inference API.
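
For concreteness, those four input types map to calls like the following (a minimal sketch; the file names, URL, and payload are placeholders, not part of the diff):

```py
from pathlib import Path

from huggingface_hub import InferenceClient

client = InferenceClient()

client.audio_classification(b"...")  # raw bytes (placeholder payload)
with open("audio.flac", "rb") as f:  # file-like object opened as binary
    client.audio_classification(f)
client.audio_classification(Path("audio.flac"))  # local path
client.audio_classification("https://example.com/audio.flac")  # remote URL, downloaded before sending
```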
src/huggingface_hub/_inference.py (34 changes: 27 additions & 7 deletions)
@@ -44,6 +44,7 @@
 from typing import TYPE_CHECKING, Any, BinaryIO, ContextManager, Dict, Generator, List, Optional, Union, overload

 from requests import HTTPError, Response
+from requests.structures import CaseInsensitiveDict

 from ._inference_types import ClassificationOutput, ConversationalOutput, ImageSegmentationOutput
 from .constants import INFERENCE_ENDPOINT
@@ -96,17 +97,31 @@ class InferenceClient:
             or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is
             automatically selected for the task.
         token (`str`, *optional*):
-            Hugging Face token. Will default to the locally saved token.
+            Hugging Face token. Will default to the locally saved token. Pass `token=False` if you don't want to send
+            your token to the server.
         timeout (`float`, `optional`):
             The maximum number of seconds to wait for a response from the server. Loading a new model in Inference
             API can take up to several minutes. Defaults to None, meaning it will loop until the server is available.
+        headers (`Dict[str, str]`, `optional`):
+            Additional headers to send to the server. By default only the authorization and user-agent headers are sent.
+            Values in this dictionary will override the default values.
+        cookies (`Dict[str, str]`, `optional`):
+            Additional cookies to send to the server.
     """

     def __init__(
-        self, model: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None
+        self,
+        model: Optional[str] = None,
+        token: Union[str, bool, None] = None,
+        timeout: Optional[float] = None,
+        headers: Optional[Dict[str, str]] = None,
+        cookies: Optional[Dict[str, str]] = None,
     ) -> None:
         self.model: Optional[str] = model
-        self.headers = build_hf_headers(token=token)
+        self.headers = CaseInsensitiveDict(build_hf_headers(token=token))  # contains 'authorization' + 'user-agent'
+        if headers is not None:
+            self.headers.update(headers)
+        self.cookies = cookies
         self.timeout = timeout

     def __repr__(self):
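
Taken together, the new constructor can be exercised roughly like this (a hedged sketch based on the diff above; the header, cookie, and user-agent values are placeholders):

```py
from huggingface_hub import InferenceClient

# Extra headers/cookies are sent on every request; `token=False` skips
# sending the locally saved token (new in this commit).
client = InferenceClient(
    headers={"X-My-Header": "foo"},
    cookies={"my-cookie": "bar"},
    token=False,
)

# Because the defaults live in a CaseInsensitiveDict, any casing of a
# default key overrides it.
client = InferenceClient(headers={"USER-Agent": "my-app/1.0"})
assert client.headers["user-agent"] == "my-app/1.0"
```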
@@ -157,7 +172,12 @@ def post(
         with _open_as_binary(data) as data_as_binary:
             try:
                 response = get_session().post(
-                    url, json=json, data=data_as_binary, headers=self.headers, timeout=self.timeout
+                    url,
+                    json=json,
+                    data=data_as_binary,
+                    headers=self.headers,
+                    cookies=self.cookies,
+                    timeout=self.timeout,
                 )
             except TimeoutError as error:
                 # Convert any `TimeoutError` to an `InferenceTimeoutError`
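
Because every task method funnels its HTTP call through `post()`, the constructor-level headers and cookies ride along on all requests. A short sketch (header and cookie names are placeholders):

```py
from huggingface_hub import InferenceClient

client = InferenceClient(headers={"X-Request-Source": "batch-job"}, cookies={"session": "abc"})

# Both calls below go through `post()` and therefore carry the custom
# header and cookie automatically.
client.automatic_speech_recognition("hello_world.flac")
client.text_to_speech("Hello world")
```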
@@ -214,7 +234,7 @@ def audio_classification(
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
-        >>> client.audio_classification("audio.wav")
+        >>> client.audio_classification("audio.flac")
         [{'score': 0.4976358711719513, 'label': 'hap'}, {'score': 0.3677836060523987, 'label': 'neu'},...]
         ```
         """
@@ -250,7 +270,7 @@ def automatic_speech_recognition(
         ```py
         >>> from huggingface_hub import InferenceClient
         >>> client = InferenceClient()
-        >>> client.automatic_speech_recognition("hello_world.wav")
+        >>> client.automatic_speech_recognition("hello_world.flac")
         "hello world"
         ```
         """
@@ -760,7 +780,7 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
         >>> client = InferenceClient()
         >>> audio = client.text_to_speech("Hello world")
-        >>> Path("hello_world.wav").write_bytes(audio)
+        >>> Path("hello_world.flac").write_bytes(audio)
         ```
         """
         response = self.post(json={"inputs": text}, model=model, task="text-to-speech")
tests/test_inference_client.py (36 changes: 36 additions & 0 deletions)
@@ -14,13 +14,15 @@
 import io
 import unittest
 from pathlib import Path
+from unittest.mock import MagicMock, patch

 import numpy as np
 import pytest
 from PIL import Image

 from huggingface_hub import InferenceClient, hf_hub_download
 from huggingface_hub._inference import _open_as_binary
+from huggingface_hub.utils import build_hf_headers

 from .testing_utils import with_production_testing

@@ -251,3 +253,37 @@ def test_recommended_model_from_supported_task(self) -> None:
     def test_unsupported_task(self) -> None:
         with self.assertRaises(NotImplementedError):
             InferenceClient()._resolve_url(task="unknown-task")


+class TestHeadersAndCookies(unittest.TestCase):
+    def test_headers_and_cookies(self) -> None:
+        client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"})
+        self.assertEqual(client.headers["X-My-Header"], "foo")
+        self.assertEqual(client.cookies["my-cookie"], "bar")
+
+    def test_headers_overwrite(self) -> None:
+        # Default user agent
+        self.assertTrue(InferenceClient().headers["user-agent"].startswith("unknown/None;"))
+
+        # Overwritten user-agent
+        self.assertEqual(InferenceClient(headers={"user-agent": "bar"}).headers["user-agent"], "bar")
+
+        # Case-insensitive overwrite
+        self.assertEqual(InferenceClient(headers={"USER-agent": "bar"}).headers["user-agent"], "bar")
+
+    @patch("huggingface_hub._inference.get_session")
+    def test_mocked_post(self, get_session_mock: MagicMock) -> None:
+        """Test that headers and cookies are correctly passed to the request."""
+        client = InferenceClient(headers={"X-My-Header": "foo"}, cookies={"my-cookie": "bar"})
+        response = client.post(data=b"content", model="username/repo_name")
+        self.assertEqual(response, get_session_mock().post.return_value)
+
+        expected_user_agent = build_hf_headers()["user-agent"]
+        get_session_mock().post.assert_called_once_with(
+            "https://api-inference.huggingface.co/models/username/repo_name",
+            json=None,
+            data=b"content",
+            headers={"user-agent": expected_user_agent, "X-My-Header": "foo"},
+            cookies={"my-cookie": "bar"},
+            timeout=None,
+        )
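
A note on the mocking pattern in `test_mocked_post`: a `MagicMock` returns the same `return_value` child on every call, so `get_session_mock()` inside the assertions refers to the very session object that `post()` used. A standalone sketch of that behavior:

```py
from unittest.mock import MagicMock

m = MagicMock()
assert m() is m()  # repeated calls return the same child mock

m().post(data=b"content")  # record a call on the shared child
m().post.assert_called_once_with(data=b"content")
```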
