Fix InferenceAPI on image task #1270

Merged: 3 commits, Dec 15, 2022
5 changes: 4 additions & 1 deletion docs/source/package_reference/inference_api.mdx
@@ -2,4 +2,7 @@

The `huggingface_hub` library allows users to programmatically access the Inference API. For more information about the Accelerated Inference API, please refer to the documentation [here](https://huggingface.co/docs/api-inference/index).

-[[autodoc]] InferenceApi
+[[autodoc]] InferenceApi
+    - __init__
+    - __call__
+    - all
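For reference, the entries documented above map onto a single usage pattern. A minimal sketch, assuming a public model such as `bert-base-uncased` whose `pipeline_tag` is set on the Hub (model ID and output shape taken from the examples and tests below):

```python
from huggingface_hub.inference_api import InferenceApi

# __init__ resolves the model's task from the Hub; pass `token=...` for private models.
inference = InferenceApi(repo_id="bert-base-uncased")

# __call__ posts the payload to the task-specific endpoint and parses the response.
result = inference(inputs="Hi, I think [MASK] is cool")

# For fill-mask, the parsed response is a list of candidates with "sequence" and "score" keys.
print(result[0]["sequence"], result[0]["score"])
```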
1 change: 1 addition & 0 deletions setup.py
@@ -48,6 +48,7 @@ def get_version() -> str:
"pytest-cov",
"pytest-env",
"soundfile",
"Pillow",
]

# Typing extra dependencies list is duplicated in `.pre-commit-config.yaml`
88 changes: 70 additions & 18 deletions src/huggingface_hub/inference_api.py
@@ -1,9 +1,10 @@
+import io
from typing import Any, Dict, List, Optional, Union

import requests

from .hf_api import HfApi
-from .utils import logging, validate_hf_hub_args
+from .utils import build_hf_headers, is_pillow_available, logging, validate_hf_hub_args


logger = logging.get_logger(__name__)
@@ -74,6 +75,19 @@ class InferenceApi:

>>> # Overriding configured task
>>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")

>>> # Text-to-image
>>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
>>> inference("cat")
<PIL.PngImagePlugin.PngImageFile image (...)>

>>> # Return as raw response to parse the output yourself
>>> inference = InferenceApi("mio/amadeus")
>>> response = inference("hello world", raw_response=True)
>>> response.headers
{"Content-Type": "audio/flac", ...}
>>> response.content # raw bytes from server
b'(...)'
```
"""

@@ -99,20 +113,15 @@ def __init__(
https://huggingface.co/settings/token. Alternatively, you can
find both your organizations and personal API tokens using
`HfApi().whoami(token)`.
-gpu (``bool``, `optional`, defaults ``False``):
+gpu (`bool`, `optional`, defaults `False`):
Whether to use GPU instead of CPU for inference(requires Startup
plan at least).
-.. note::
-Setting `token` is required when you want to use a private model.
"""
self.options = {"wait_for_model": True, "use_gpu": gpu}

-self.headers = {}
-if isinstance(token, str):
-self.headers["Authorization"] = f"Bearer {token}"
+self.headers = build_hf_headers(token=token)

# Configure task
-model_info = HfApi().model_info(repo_id=repo_id, token=token)
+model_info = HfApi(token=token).model_info(repo_id=repo_id)
if not model_info.pipeline_tag and not task:
raise ValueError(
"Task not specified in the repository. Please add it to the model card"
@@ -136,28 +145,71 @@ def __init__(
self.api_url = f"{ENDPOINT}/pipeline/{self.task}/{repo_id}"

def __repr__(self):
items = (f"{k}='{v}'" for k, v in self.__dict__.items())
return f"{self.__class__.__name__}({', '.join(items)})"
# Do not add headers to repr to avoid leaking token.
return (
f"InferenceAPI(api_url='{self.api_url}', task='{self.task}',"
f" options={self.options})"
)

def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
-):
+raw_response: bool = False,
) -> Any:
"""Make a call to the Inference API.

Args:
inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
Inputs for the prediction.
params (`Dict`, *optional*):
Additional parameters for the models. Will be sent as `parameters` in the
payload.
data (`bytes`, *optional*):
Bytes content of the request. In this case, leave `inputs` and `params` empty.
raw_response (`bool`, defaults to `False`):
If `True`, the raw `Response` object is returned. You can parse its content
as preferred. By default, the content is parsed into a more practical format
(json dictionary or PIL Image for example).
"""
# Build payload
payload: Dict[str, Any] = {
"options": self.options,
}

if inputs:
payload["inputs"] = inputs

if params:
payload["parameters"] = params

-# TODO: Decide if we should raise an error instead of
-# returning the json.
+# Make API call
response = requests.post(
self.api_url, headers=self.headers, json=payload, data=data
-).json()
-return response
+)

+# Let the user handle the response
+if raw_response:
+return response
+
+# By default, parse the response for the user.
+content_type = response.headers.get("Content-Type") or ""
+if content_type.startswith("image"):
+if not is_pillow_available():
+raise ImportError(
+f"Task '{self.task}' returned as image but Pillow is not installed."
+" Please install it (`pip install Pillow`) or pass"
+" `raw_response=True` to get the raw `Response` object and parse"
+" the image by yourself."
+)
+
+from PIL import Image
+
+return Image.open(io.BytesIO(response.content))
+elif content_type == "application/json":
+return response.json()
+else:
+raise NotImplementedError(
+f"{content_type} output type is not implemented yet. You can pass"
+" `raw_response=True` to get the raw `Response` object and parse the"
+" output by yourself."
+)
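The `NotImplementedError` branch above makes `raw_response=True` the escape hatch for content types the client does not parse yet. A minimal sketch of that path, assuming the `mio/amadeus` endpoint returns `audio/flac` as in the docstring example:

```python
from huggingface_hub.inference_api import InferenceApi

# Text-to-audio output is not parsed by __call__, so ask for the raw Response instead.
inference = InferenceApi("mio/amadeus")
response = inference("hello world", raw_response=True)

# `response` is the untouched `requests.Response`: inspect headers, then parse the bytes yourself.
if response.headers.get("Content-Type") == "audio/flac":
    with open("output.flac", "wb") as f:
        f.write(response.content)
```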
2 changes: 2 additions & 0 deletions src/huggingface_hub/utils/__init__.py
@@ -59,6 +59,7 @@
get_graphviz_version,
get_hf_hub_version,
get_jinja_version,
+get_pillow_version,
get_pydot_version,
get_python_version,
get_tf_version,
@@ -69,6 +70,7 @@
is_graphviz_available,
is_jinja_available,
is_notebook,
+is_pillow_available,
is_pydot_available,
is_tf_available,
is_torch_available,
11 changes: 11 additions & 0 deletions src/huggingface_hub/utils/_runtime.py
@@ -51,6 +51,7 @@
"fastai": {"fastai"},
"fastcore": {"fastcore"},
"jinja": {"Jinja2"},
"pillow": {"Pillow"},
}

# Check once at runtime
@@ -118,6 +119,15 @@ def get_jinja_version() -> str:
return _get_version("jinja")


+# Pillow
+def is_pillow_available() -> bool:
+return _is_available("pillow")
+
+
+def get_pillow_version() -> str:
+return _get_version("pillow")
+
+
# Pydot
def is_pydot_available() -> bool:
return _is_available("pydot")
@@ -232,6 +242,7 @@ def dump_environment_info() -> Dict[str, Any]:
info["Jinja2"] = get_jinja_version()
info["Graphviz"] = get_graphviz_version()
info["Pydot"] = get_pydot_version()
info["Pillow"] = get_pillow_version()

print("\nCopy-and-paste the text below in your GitHub issue.\n")
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
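`_is_available` and `_get_version` themselves are outside this diff; the new `"pillow": {"Pillow"}` entry only tells us that each key maps to a set of candidate distribution names. A rough sketch of how such a lookup could work with `importlib.metadata`, offered as an assumption about the internals rather than the actual `_runtime.py` code:

```python
import importlib.metadata
from typing import Set

def get_installed_version(candidates: Set[str]) -> str:
    # Hypothetical helper: return the version of the first installed candidate.
    for name in candidates:
        try:
            return importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            continue
    return "N/A"  # assumed fallback; the real helper may behave differently

print(get_installed_version({"Pillow"}))  # e.g. "9.3.0" when Pillow is installed
```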
45 changes: 31 additions & 14 deletions tests/test_inference_api.py
@@ -12,20 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
+from pathlib import Path
+from unittest.mock import patch

+from PIL import Image

from huggingface_hub import hf_hub_download
from huggingface_hub.inference_api import InferenceApi

from .testing_utils import with_production_testing


+@with_production_testing
class InferenceApiTest(unittest.TestCase):
def read(self, filename: str) -> bytes:
-with open(filename, "rb") as f:
-bpayload = f.read()
-return bpayload
+return Path(filename).read_bytes()

+@classmethod
+@with_production_testing
+def setUpClass(cls) -> None:
+cls.image_file = hf_hub_download(
+repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png"
+)
+return super().setUpClass()

def test_simple_inference(self):
api = InferenceApi("bert-base-uncased")
inputs = "Hi, I think [MASK] is cool"
@@ -37,7 +47,6 @@ def test_simple_inference(self):
self.assertTrue("sequence" in result)
self.assertTrue("score" in result)

-@with_production_testing
def test_inference_with_params(self):
api = InferenceApi("typeform/distilbert-base-uncased-mnli")
inputs = (
@@ -50,7 +59,6 @@ def test_inference_with_params(self):
self.assertTrue("sequence" in result)
self.assertTrue("scores" in result)

-@with_production_testing
def test_inference_with_dict_inputs(self):
api = InferenceApi("distilbert-base-cased-distilled-squad")
inputs = {
@@ -62,7 +70,6 @@ def test_inference_with_dict_inputs(self):
self.assertTrue("score" in result)
self.assertTrue("answer" in result)

-@with_production_testing
def test_inference_with_audio(self):
api = InferenceApi("facebook/wav2vec2-base-960h")
file = hf_hub_download(
@@ -75,21 +82,33 @@
self.assertIsInstance(result, dict)
self.assertTrue("text" in result, f"We received {result} instead")

-@with_production_testing
def test_inference_with_image(self):
api = InferenceApi("google/vit-base-patch16-224")
-file = hf_hub_download(
-repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png"
-)
-data = self.read(file)
+data = self.read(self.image_file)
result = api(data=data)
self.assertIsInstance(result, list)
for classification in result:
self.assertIsInstance(classification, dict)
self.assertTrue("score" in classification)
self.assertTrue("label" in classification)

-@with_production_testing
+def test_text_to_image(self):
+api = InferenceApi("stabilityai/stable-diffusion-2-1")
+with patch("huggingface_hub.inference_api.requests") as mock:
+mock.post.return_value.headers = {"Content-Type": "image/jpeg"}
+mock.post.return_value.content = self.read(self.image_file)
+output = api("cat")
+self.assertIsInstance(output, Image.Image)
+
+def test_text_to_image_raw_response(self):
+api = InferenceApi("stabilityai/stable-diffusion-2-1")
+with patch("huggingface_hub.inference_api.requests") as mock:
+mock.post.return_value.headers = {"Content-Type": "image/jpeg"}
+mock.post.return_value.content = self.read(self.image_file)
+output = api("cat", raw_response=True)
+# Raw response is returned
+self.assertEqual(output, mock.post.return_value)

def test_inference_overriding_task(self):
api = InferenceApi(
"sentence-transformers/paraphrase-albert-small-v2",
@@ -99,14 +118,12 @@ def test_inference_overriding_task(self):
result = api(inputs)
self.assertIsInstance(result, list)

-@with_production_testing
def test_inference_overriding_invalid_task(self):
with self.assertRaises(
ValueError, msg="Invalid task invalid-task. Make sure it's valid."
):
InferenceApi("bert-base-uncased", task="invalid-task")

-@with_production_testing
def test_inference_missing_input(self):
api = InferenceApi("deepset/roberta-base-squad2")
result = api({"question": "What's my name?"})