Skip to content

Commit

Permalink
Merge branch 'main' of github.com:huggingface/huggingface_hub
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Sep 7, 2023
2 parents 8cf4cb6 + c2a7b8a commit 816fc20
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
| | [Object Detection](https://huggingface.co/tasks/object-detection) || [`~InferenceClient.object_detection`] |
| | [Text-to-Image](https://huggingface.co/tasks/text-to-image) || [`~InferenceClient.text_to_image`] |
| | [Zero-Shot-Image-Classification](https://huggingface.co/tasks/zero-shot-image-classification) || [`~InferenceClient.zero_shot_image_classification`] |
| Multimodal | [Documentation Question Answering](https://huggingface.co/tasks/document-question-answering) | | |
| Multimodal | [Documentation Question Answering](https://huggingface.co/tasks/document-question-answering) | ✅ | [`~InferenceClient.document_question_answering`]
| | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) || [`~InferenceClient.visual_question_answering`] |
| NLP | [Conversational](https://huggingface.co/tasks/conversational) || [`~InferenceClient.conversational`] |
| | [Feature Extraction](https://huggingface.co/tasks/feature-extraction) || [`~InferenceClient.feature_extraction`] |
Expand Down
41 changes: 41 additions & 0 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,47 @@ def visual_question_answering(
response = self.post(json=payload, model=model, task="visual-question-answering")
return _bytes_to_list(response)

def document_question_answering(
self,
image: ContentT,
question: str,
*,
model: Optional[str] = None,
) -> List[QuestionAnsweringOutput]:
"""
Answer questions on document images.
Args:
image (`Union[str, Path, bytes, BinaryIO]`):
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
question (`str`):
Question to be answered.
model (`str`, *optional*):
The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
Defaults to None.
Returns:
`List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
Raises:
[`InferenceTimeoutError`]:
If the model is unavailable or the request times out.
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
Example:
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
[{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
```
"""
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
response = self.post(json=payload, model=model, task="document-question-answering")
return _bytes_to_list(response)

def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
"""
Generate embeddings for a given text.
Expand Down
42 changes: 42 additions & 0 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,48 @@ async def visual_question_answering(
response = await self.post(json=payload, model=model, task="visual-question-answering")
return _bytes_to_list(response)

async def document_question_answering(
self,
image: ContentT,
question: str,
*,
model: Optional[str] = None,
) -> List[QuestionAnsweringOutput]:
"""
Answer questions on document images.
Args:
image (`Union[str, Path, bytes, BinaryIO]`):
The input image for the context. It can be raw bytes, an image file, or a URL to an online image.
question (`str`):
Question to be answered.
model (`str`, *optional*):
The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used.
Defaults to None.
Returns:
`List[Dict]`: a list of dictionaries containing the predicted label, associated probability, word ids, and page number.
Raises:
[`InferenceTimeoutError`]:
If the model is unavailable or the request times out.
`aiohttp.ClientResponseError`:
If the request fails with an HTTP error status code other than HTTP 503.
Example:
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()
>>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?")
[{'score': 0.42515629529953003, 'answer': 'us-001', 'start': 16, 'end': 16}]
```
"""
payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)}
response = await self.post(json=payload, model=model, task="document-question-answering")
return _bytes_to_list(response)

async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
"""
Generate embeddings for a given text.
Expand Down
6 changes: 6 additions & 0 deletions tests/test_inference_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"audio-to-audio": "speechbrain/sepformer-wham",
"automatic-speech-recognition": "facebook/wav2vec2-base-960h",
"conversational": "facebook/blenderbot-400M-distill",
"document-question-answering": "naver-clova-ix/donut-base-finetuned-docvqa",
"feature-extraction": "facebook/bart-base",
"image-classification": "google/vit-base-patch16-224",
"image-segmentation": "facebook/detr-resnet-50-panoptic",
Expand All @@ -56,6 +57,7 @@ class InferenceClientTest(unittest.TestCase):
def setUpClass(cls) -> None:
super().setUpClass()
cls.image_file = hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png")
cls.document_file = hf_hub_download(repo_id="impira/docquery", repo_type="space", filename="contract.jpeg")
cls.audio_file = hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="sample1.flac")


Expand Down Expand Up @@ -124,6 +126,10 @@ def test_conversational(self) -> None:
},
)

def test_document_question_answering(self) -> None:
output = self.client.document_question_answering(self.document_file, "What is the purchase amount?")
self.assertEqual(output, [{"answer": "$1,000,000,000"}])

def test_feature_extraction(self) -> None:
embedding = self.client.feature_extraction("Hi, who are you?")
self.assertIsInstance(embedding, np.ndarray)
Expand Down

0 comments on commit 816fc20

Please sign in to comment.