Add tabular classification to inference client (#1614)
* Add tabular classification to inference client

* skip test

* docs

* skip

---------

Co-authored-by: Lucain Pouget <[email protected]>
martinbrose and Wauplin authored Sep 7, 2023
1 parent 816fc20 commit a9d6552
Showing 4 changed files with 113 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/guides/inference.md
@@ -144,7 +144,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
| | [Token Classification](https://huggingface.co/tasks/token-classification) || [`~InferenceClient.token_classification`] |
| | [Translation](https://huggingface.co/tasks/translation) || [`~InferenceClient.translation`] |
| | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-image-classification) | | |
| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | | |
| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | | [`~InferenceClient.tabular_classification`] |
| | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | | |


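Since the task table now points at [`~InferenceClient.tabular_classification`], a short usage sketch may help readers whose data lives in a pandas DataFrame rather than a plain dict. The use of pandas and the sample values (copied from the docstring example in this commit) are assumptions for illustration, not part of the diff:

```py
# Sketch: building the `table` payload for tabular_classification from a
# pandas DataFrame (pandas itself is an assumption, not required by the client).
import pandas as pd
from huggingface_hub import InferenceClient

df = pd.DataFrame(
    {
        "fixed_acidity": [7.4, 7.8, 10.3],
        "volatile_acidity": [0.7, 0.88, 0.32],
        "citric_acid": [0.0, 0.0, 0.45],
        "residual_sugar": [1.9, 2.6, 6.4],
        "chlorides": [0.076, 0.098, 0.073],
        "free_sulfur_dioxide": [11, 25, 5],
        "total_sulfur_dioxide": [34, 67, 13],
        "density": [0.9978, 0.9968, 0.9976],
        "pH": [3.51, 3.2, 3.23],
        "sulphates": [0.56, 0.68, 0.82],
        "alcohol": [9.4, 9.8, 12.6],
    }
)

# The endpoint expects one list of stringified cell values per column.
table = df.astype(str).to_dict(orient="list")

client = InferenceClient()
labels = client.tabular_classification(table=table, model="julien-c/wine-quality")
print(labels)  # one predicted label per row, e.g. ["5", "5", "5"]
```

Each call returns one label per row, matching the `List[str]` return type added in `_client.py` below.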
48 changes: 46 additions & 2 deletions src/huggingface_hub/inference/_client.py
@@ -980,9 +980,53 @@ def table_question_answering(
)
return _bytes_to_dict(response) # type: ignore

def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
"""
Classifying a target category (a group) based on a set of attributes.
Args:
table (`Dict[str, Any]`):
Set of attributes to classify.
model (`str`):
The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint.
Returns:
`List`: a list of labels, one per row in the initial table.
Raises:
[`InferenceTimeoutError`]:
If the model is unavailable or the request times out.
`HTTPError`:
If the request fails with an HTTP error status code other than HTTP 503.
Example:
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> table = {
... "fixed_acidity": ["7.4", "7.8", "10.3"],
... "volatile_acidity": ["0.7", "0.88", "0.32"],
... "citric_acid": ["0", "0", "0.45"],
... "residual_sugar": ["1.9", "2.6", "6.4"],
... "chlorides": ["0.076", "0.098", "0.073"],
... "free_sulfur_dioxide": ["11", "25", "5"],
... "total_sulfur_dioxide": ["34", "67", "13"],
... "density": ["0.9978", "0.9968", "0.9976"],
... "pH": ["3.51", "3.2", "3.23"],
... "sulphates": ["0.56", "0.68", "0.82"],
... "alcohol": ["9.4", "9.8", "12.6"],
... }
>>> client.tabular_classification(table=table, model="julien-c/wine-quality")
["5", "5", "5"]
```
"""
response = self.post(json={"table": table}, model=model, task="tabular-classification")
return _bytes_to_list(response)

def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
"""
Perform sentiment-analysis on the given text.
Perform text classification (e.g. sentiment-analysis) on the given text.
Args:
text (`str`):
@@ -1005,7 +1049,7 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> output = client.text_classification("I like you")
>>> client.text_classification("I like you")
[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
```
"""
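The body of the new method is a thin wrapper: it POSTs `{"table": table}` to the tabular-classification task and JSON-decodes the byte response. A rough standalone equivalent using `requests` is sketched below; the serverless Inference API URL layout and the `HF_TOKEN` environment variable are assumptions here and not part of this commit (a deployed Inference Endpoint would use its own base URL):

```py
# Sketch of what client.tabular_classification(...) does over the wire,
# written against plain `requests`. URL layout and token handling are
# assumptions; the client resolves these internally.
import os
from typing import Any, Dict, List

import requests


def tabular_classification_raw(table: Dict[str, Any], model: str) -> List[str]:
    url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
    response = requests.post(url, headers=headers, json={"table": table}, timeout=30)
    response.raise_for_status()  # the client surfaces HTTP errors in a similar way
    return response.json()  # one label per row, e.g. ["5", "5", "5"]
```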
49 changes: 47 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -992,9 +992,54 @@ async def table_question_answering(
)
return _bytes_to_dict(response) # type: ignore

async def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]:
"""
Classifying a target category (a group) based on a set of attributes.
Args:
table (`Dict[str, Any]`):
Set of attributes to classify.
model (`str`):
The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to
a deployed Inference Endpoint.
Returns:
`List`: a list of labels, one per row in the initial table.
Raises:
[`InferenceTimeoutError`]:
If the model is unavailable or the request times out.
`aiohttp.ClientResponseError`:
If the request fails with an HTTP error status code other than HTTP 503.
Example:
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()
>>> table = {
... "fixed_acidity": ["7.4", "7.8", "10.3"],
... "volatile_acidity": ["0.7", "0.88", "0.32"],
... "citric_acid": ["0", "0", "0.45"],
... "residual_sugar": ["1.9", "2.6", "6.4"],
... "chlorides": ["0.076", "0.098", "0.073"],
... "free_sulfur_dioxide": ["11", "25", "5"],
... "total_sulfur_dioxide": ["34", "67", "13"],
... "density": ["0.9978", "0.9968", "0.9976"],
... "pH": ["3.51", "3.2", "3.23"],
... "sulphates": ["0.56", "0.68", "0.82"],
... "alcohol": ["9.4", "9.8", "12.6"],
... }
>>> await client.tabular_classification(table=table, model="julien-c/wine-quality")
["5", "5", "5"]
```
"""
response = await self.post(json={"table": table}, model=model, task="tabular-classification")
return _bytes_to_list(response)

async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]:
"""
Perform sentiment-analysis on the given text.
Perform text classification (e.g. sentiment-analysis) on the given text.
Args:
text (`str`):
@@ -1018,7 +1063,7 @@ async def text_classification(self, text: str, *, model: Optional[str] = None) -
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()
>>> output = await client.text_classification("I like you")
>>> await client.text_classification("I like you")
[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]
```
"""
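Because the generated async client mirrors the sync signature, several tables can be classified concurrently. A minimal sketch with `asyncio.gather` follows; the batching pattern is an assumption for illustration, not part of the generated code:

```py
# Sketch: classifying several tables concurrently with the new async method.
import asyncio
from typing import Any, Dict, List

from huggingface_hub import AsyncInferenceClient


async def classify_all(tables: List[Dict[str, Any]], model: str) -> List[List[str]]:
    client = AsyncInferenceClient()
    # One request per table, awaited concurrently; each returns one label per row.
    results = await asyncio.gather(
        *(client.tabular_classification(table=table, model=model) for table in tables)
    )
    return list(results)


# labels = asyncio.run(classify_all([table_a, table_b], model="julien-c/wine-quality"))
```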
19 changes: 19 additions & 0 deletions tests/test_inference_client.py
@@ -41,6 +41,7 @@
"sentence-similarity": "sentence-transformers/all-MiniLM-L6-v2",
"summarization": "sshleifer/distilbart-cnn-12-6",
"table-question-answering": "google/tapas-base-finetuned-wtq",
"tabular-classification": "julien-c/wine-quality",
"text-classification": "distilbert-base-uncased-finetuned-sst-2-english",
"text-to-image": "CompVis/stable-diffusion-v1-4",
"text-to-speech": "espnet/kan-bayashi_ljspeech_vits",
@@ -231,6 +232,24 @@ def test_summarization(self) -> None:
" surpassed the Washington Monument to become the tallest man-made structure in the world.",
)

@pytest.mark.skip(reason="This model is not available on the free InferenceAPI")
def test_tabular_classification(self) -> None:
table = {
"fixed_acidity": ["7.4", "7.8", "10.3"],
"volatile_acidity": ["0.7", "0.88", "0.32"],
"citric_acid": ["0", "0", "0.45"],
"residual_sugar": ["1.9", "2.6", "6.4"],
"chlorides": ["0.076", "0.098", "0.073"],
"free_sulfur_dioxide": ["11", "25", "5"],
"total_sulfur_dioxide": ["34", "67", "13"],
"density": ["0.9978", "0.9968", "0.9976"],
"pH": ["3.51", "3.2", "3.23"],
"sulphates": ["0.56", "0.68", "0.82"],
"alcohol": ["9.4", "9.8", "12.6"],
}
output = self.client.tabular_classification(table=table)
self.assertEqual(output, ["5", "5", "5"])

def test_table_question_answering(self) -> None:
table = {
"Repository": ["Transformers", "Datasets", "Tokenizers"],
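The VCR test above is skipped because `julien-c/wine-quality` is not served by the free Inference API. If offline coverage is wanted, one option is to stub `InferenceClient.post` and assert on the decoded payload; the sketch below is such a mock-based variant and is not part of this commit:

```py
# Sketch: exercising tabular_classification without hitting the Inference API.
# The stubbed response bytes mimic the server's JSON list of labels.
from unittest.mock import patch

from huggingface_hub import InferenceClient


def test_tabular_classification_mocked() -> None:
    client = InferenceClient()
    with patch.object(InferenceClient, "post", return_value=b'["5", "5", "5"]') as mock_post:
        output = client.tabular_classification(
            table={"alcohol": ["9.4", "9.8", "12.6"]},  # shape only; post() is stubbed
            model="julien-c/wine-quality",
        )
    assert output == ["5", "5", "5"]
    mock_post.assert_called_once()
```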
