added zero shot image classification #1528

Merged: 6 commits, Jun 29, 2023
12 changes: 7 additions & 5 deletions docs/source/guides/inference.md
@@ -63,7 +63,7 @@ What if you want to use a specific model? You can specify it either as a parameter
# Initialize client for a specific model
>>> client = InferenceClient(model="prompthero/openjourney-v4")
>>> client.text_to_image(...)
# Or use a generic client but pass your model as an argument
>>> client = InferenceClient()
>>> client.text_to_image(..., model="prompthero/openjourney-v4")
```
@@ -88,7 +88,7 @@ code as before, changing only the `model` parameter:
```python
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if")
# or
>>> client = InferenceClient()
>>> client.text_to_image(..., model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if")
```
@@ -129,6 +129,8 @@ has a simple API that supports the most common tasks. Here is a list of the currently supported tasks:
| | [Image-to-Text](https://huggingface.co/tasks/image-to-text) | ✅ | [`~InferenceClient.image_to_text`] |
| | [Object Detection](https://huggingface.co/tasks/object-detection) | | |
| | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | ✅ | [`~InferenceClient.text_to_image`] |
| | [Zero-Shot-Image-Classification](https://huggingface.co/tasks/zero-shot-image-classification) | ✅ | [`~InferenceClient.zero_shot_image_classification`] |
| Multimodal | [Document Question Answering](https://huggingface.co/tasks/document-question-answering) | | |
| | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | | |
| NLP | [Conversational](https://huggingface.co/tasks/conversational) | ✅ | [`~InferenceClient.conversational`] |
@@ -223,7 +225,7 @@ Here is a short guide to help you migrate from [`InferenceApi`] to [`InferenceClient`].

### Initialization

Change from

```python
>>> from huggingface_hub import InferenceApi
@@ -239,7 +241,7 @@ to

### Run on a specific task

Change from

```python
>>> from huggingface_hub import InferenceApi
@@ -306,4 +308,4 @@ to
>>> response = client.post(json={"inputs": inputs, "parameters": params}, model="typeform/distilbert-base-uncased-mnli")
>>> response.json()
{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
```
47 changes: 47 additions & 0 deletions src/huggingface_hub/inference/_client.py
@@ -1174,6 +1174,53 @@ def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes:
        response = self.post(json={"inputs": text}, model=model, task="text-to-speech")
        return response.content

    def zero_shot_image_classification(
        self, image: ContentT, labels: List[str], *, model: Optional[str] = None
    ) -> List[ClassificationOutput]:
        """
        Classify an image against a set of candidate text labels.

        Args:
            image (`Union[str, Path, bytes, BinaryIO]`):
                The input image to classify. It can be raw bytes, an image file, or a URL to an online image.
            labels (`List[str]`):
                A list of candidate labels. At least two labels must be provided.
            model (`str`, *optional*):
                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.

        Returns:
            `List[Dict]`: List of classification outputs containing the predicted labels and their confidence.

        Raises:
            [`InferenceTimeoutError`]:
                If the model is unavailable or the request times out.
            `HTTPError`:
                If the request fails with an HTTP error status code other than HTTP 503.

        Example:
        ```py
        >>> from huggingface_hub import InferenceClient
        >>> client = InferenceClient()

        >>> client.zero_shot_image_classification(
        ...     "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg",
        ...     labels=["dog", "cat", "horse"],
        ... )
        [{"label": "dog", "score": 0.956}, ...]
        ```
        """
        # Zero-shot classification needs at least two candidate labels to compare.
        if len(labels) < 2:
            raise ValueError("You must specify at least 2 classes to compare.")

        response = self.post(
            json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}},
            model=model,
            task="zero-shot-image-classification",
        )
        return response.json()

    def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
        model = model or self.model
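To see the request shape the new method produces, here is a minimal standalone sketch of the payload construction. The `b64_encode` helper below is a stand-in for the library's private `_b64_encode`, and the byte string is a placeholder rather than a real image:

```python
import base64


def b64_encode(data: bytes) -> str:
    # Stand-in for huggingface_hub's private _b64_encode helper.
    return base64.b64encode(data).decode()


labels = ["dog", "cat", "horse"]
payload = {
    "image": b64_encode(b"<raw image bytes>"),
    "parameters": {"candidate_labels": ",".join(labels)},
}
print(payload["parameters"]["candidate_labels"])  # dog,cat,horse
```

Note that the labels are sent as a single comma-separated string, so a label that itself contains a comma would not round-trip intact.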
2 changes: 2 additions & 0 deletions src/huggingface_hub/utils/_errors.py
@@ -59,6 +59,8 @@ def __init__(self, message: str, response: Optional[Response] = None):
        if server_message_from_headers is not None:  # from headers
            _server_message += server_message_from_headers + "\n"
        if server_message_from_body is not None:  # from body "error"
            if isinstance(server_message_from_body, list):
                server_message_from_body = "\n".join(server_message_from_body)
            if server_message_from_body not in _server_message:
                _server_message += server_message_from_body + "\n"
        if server_multiple_messages_from_body is not None:  # from body "errors"
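The two added lines handle servers that return the body "error" field as a list of strings rather than a single string. A standalone sketch of that normalization (the message values are hypothetical):

```python
def normalize_server_message(message):
    # Mirror the added check: join list-valued error bodies into one string.
    if isinstance(message, list):
        message = "\n".join(message)
    return message


print(normalize_server_message(["Bad request:", "`labels` is required"]))
print(normalize_server_message("single error message"))
```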


9 changes: 9 additions & 0 deletions tests/test_inference_client.py
Expand Up @@ -41,6 +41,7 @@
    "text-classification": "distilbert-base-uncased-finetuned-sst-2-english",
    "text-to-image": "CompVis/stable-diffusion-v1-4",
    "text-to-speech": "espnet/kan-bayashi_ljspeech_vits",
    "zero-shot-image-classification": "openai/clip-vit-base-patch32",
}


@@ -197,6 +198,14 @@ def test_text_to_speech(self) -> None:
        audio = self.client.text_to_speech("Hello world")
        self.assertIsInstance(audio, bytes)

    def test_zero_shot_image_classification(self) -> None:
        output = self.client.zero_shot_image_classification(self.image_file, ["tree", "woman", "cat"])
        self.assertIsInstance(output, list)
        self.assertGreater(len(output), 0)
        for item in output:
            self.assertIsInstance(item["label"], str)
            self.assertIsInstance(item["score"], float)
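The shape asserted by this test can be checked offline against a canned response. The values below are hypothetical, not real model output:

```python
# Hypothetical canned response with the same shape the test expects.
output = [
    {"label": "cat", "score": 0.92},
    {"label": "tree", "score": 0.05},
    {"label": "woman", "score": 0.03},
]
assert isinstance(output, list) and len(output) > 0
for item in output:
    assert isinstance(item["label"], str)
    assert isinstance(item["score"], float)
print("shape ok")
```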


class TestOpenAsBinary(InferenceClientTest):
    def test_open_as_binary_with_none(self) -> None: