diff --git a/docs/source/en/guides/inference.md b/docs/source/en/guides/inference.md index d930fc0ffc..20d670088d 100644 --- a/docs/source/en/guides/inference.md +++ b/docs/source/en/guides/inference.md @@ -144,7 +144,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr | | [Token Classification](https://huggingface.co/tasks/token-classification) | ✅ | [`~InferenceClient.token_classification`] | | | [Translation](https://huggingface.co/tasks/translation) | ✅ | [`~InferenceClient.translation`] | | | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-image-classification) | | | -| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | | | +| Tabular | [Tabular Classification](https://huggingface.co/tasks/tabular-classification) | ✅ | [`~InferenceClient.tabular_classification`] | | | [Tabular Regression](https://huggingface.co/tasks/tabular-regression) | | | diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index db7fd6d3dd..d5abc1b412 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -980,9 +980,53 @@ def table_question_answering( ) return _bytes_to_dict(response) # type: ignore + def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]: + """ + Classify a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`): + The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... "total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... "sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... } + >>> client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + response = self.post(json={"table": table}, model=model, task="tabular-classification") + return _bytes_to_list(response) + def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]: """ - Perform sentiment-analysis on the given text. + Perform text classification (e.g. sentiment-analysis) on the given text. 
Args: text (`str`): @@ -1005,7 +1049,7 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> output = client.text_classification("I like you") + >>> client.text_classification("I like you") [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] ``` """ diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 8498a4a930..80fe4c0a91 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -992,9 +992,54 @@ async def table_question_answering( ) return _bytes_to_dict(response) # type: ignore + async def tabular_classification(self, table: Dict[str, Any], *, model: str) -> List[str]: + """ + Classify a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`): + The model to use for the tabular-classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... 
"total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... "sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... } + >>> await client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + response = await self.post(json={"table": table}, model=model, task="tabular-classification") + return _bytes_to_list(response) + async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]: """ - Perform sentiment-analysis on the given text. + Perform text classification (e.g. sentiment-analysis) on the given text. Args: text (`str`): @@ -1018,7 +1063,7 @@ async def text_classification(self, text: str, *, model: Optional[str] = None) - # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() - >>> output = await client.text_classification("I like you") + >>> await client.text_classification("I like you") [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] ``` """ diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 247548743b..9c05e90d2c 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -41,6 +41,7 @@ "sentence-similarity": "sentence-transformers/all-MiniLM-L6-v2", "summarization": "sshleifer/distilbart-cnn-12-6", "table-question-answering": "google/tapas-base-finetuned-wtq", + "tabular-classification": "julien-c/wine-quality", "text-classification": "distilbert-base-uncased-finetuned-sst-2-english", "text-to-image": "CompVis/stable-diffusion-v1-4", "text-to-speech": "espnet/kan-bayashi_ljspeech_vits", @@ -231,6 +232,24 @@ def test_summarization(self) -> None: " surpassed the Washington Monument to become the tallest man-made structure in the world.", ) + @pytest.mark.skip(reason="This model is not 
available on the free InferenceAPI") + def test_tabular_classification(self) -> None: + table = { + "fixed_acidity": ["7.4", "7.8", "10.3"], + "volatile_acidity": ["0.7", "0.88", "0.32"], + "citric_acid": ["0", "0", "0.45"], + "residual_sugar": ["1.9", "2.6", "6.4"], + "chlorides": ["0.076", "0.098", "0.073"], + "free_sulfur_dioxide": ["11", "25", "5"], + "total_sulfur_dioxide": ["34", "67", "13"], + "density": ["0.9978", "0.9968", "0.9976"], + "pH": ["3.51", "3.2", "3.23"], + "sulphates": ["0.56", "0.68", "0.82"], + "alcohol": ["9.4", "9.8", "12.6"], + } + output = self.client.tabular_classification(table=table) + self.assertEqual(output, ["5", "5", "5"]) + def test_table_question_answering(self) -> None: table = { "Repository": ["Transformers", "Datasets", "Tokenizers"],