diff --git a/setup.py b/setup.py
index dbce1a1a32..f5b87210af 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@ def get_version() -> str:
 
 extras["testing"] = [
     "pytest",
+    "datasets",
 ]
 
 extras["quality"] = [
diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py
index 004fa31f7b..e62734221d 100644
--- a/src/huggingface_hub/inference_api.py
+++ b/src/huggingface_hub/inference_api.py
@@ -30,12 +30,13 @@
     "text-to-speech",
     "automatic-speech-recognition",
     "audio-to-audio",
-    "audio-source-separation",
+    "audio-classification",
     "voice-activity-detection",
     # Computer vision
     "image-classification",
     "object-detection",
     "image-segmentation",
+    "text-to-image",
     # Others
     "structured-data-classification",
 ]
@@ -122,20 +123,40 @@ def __repr__(self):
 
     def __call__(
         self,
-        inputs: Union[str, Dict, List[str], List[List[str]]],
+        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
         params: Optional[Dict] = None,
+        data: Optional[bytes] = None,
     ):
+        """Post a request to `self.api_url` and return the decoded JSON.
+
+        Args:
+            inputs: JSON-serializable inputs, sent under the "inputs" key
+                when it is not None.
+            params: Optional dict, sent under the "parameters" key when
+                non-empty.
+            data: Optional raw bytes (e.g. an audio or image file) sent as
+                the request body.
+
+        Returns:
+            The response body decoded from JSON.
+        """
         payload = {
-            "inputs": inputs,
             "options": self.options,
         }
 
+        # Compare against None rather than relying on truthiness so that
+        # explicitly passed empty inputs ("", [], {}) are still forwarded.
+        if inputs is not None:
+            payload["inputs"] = inputs
+
         if params:
             payload["parameters"] = params
 
+        # NOTE: `requests` ignores `json` whenever a non-empty `data` is
+        # passed, so with a binary payload only the raw bytes are sent.
         # TODO: Decide if we should raise an error instead of
         # returning the json.
         response = requests.post(
-            self.api_url, headers=self.headers, json=payload
+            self.api_url, headers=self.headers, json=payload, data=data
         ).json()
         return response
diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py
index f12a6ed4ff..0a270535f1 100644
--- a/tests/test_inference_api.py
+++ b/tests/test_inference_api.py
@@ -12,15 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import unittest
 
+import datasets
+
 from huggingface_hub.inference_api import InferenceApi
 
 from .testing_utils import with_production_testing
 
 
 class InferenceApiTest(unittest.TestCase):
+    def read(self, filename: str) -> bytes:
+        with open(filename, "rb") as f:
+            bpayload = f.read()
+        return bpayload
+
     @with_production_testing
     def test_simple_inference(self):
         api = InferenceApi("bert-base-uncased")
@@ -55,6 +61,29 @@ def test_inference_with_dict_inputs(self):
         self.assertTrue("score" in result)
         self.assertTrue("answer" in result)
 
+    @with_production_testing
+    def test_inference_with_audio(self):
+        api = InferenceApi("facebook/wav2vec2-large-960h-lv60-self")
+        dataset = datasets.load_dataset(
+            "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation"
+        )
+        data = self.read(dataset[0]["file"])
+        result = api(data=data)
+        self.assertIsInstance(result, dict)
+        self.assertTrue("text" in result)
+
+    @with_production_testing
+    def test_inference_with_image(self):
+        api = InferenceApi("google/vit-base-patch16-224")
+        dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test")
+        data = self.read(dataset[0]["file"])
+        result = api(data=data)
+        self.assertIsInstance(result, list)
+        for classification in result:
+            self.assertIsInstance(classification, dict)
+            self.assertTrue("score" in classification)
+            self.assertTrue("label" in classification)
+
     @with_production_testing
     def test_inference_overriding_task(self):
         api = InferenceApi(