Skip to content

Commit

Permalink
Add type hints for providers (#2788)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored Jan 27, 2025
1 parent a259e88 commit d2cab33
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import HFInferenceTask, get_provider_helper
from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method

Expand Down Expand Up @@ -164,7 +164,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Optional[str] = None,
provider: Optional[PROVIDER_T] = None,
token: Optional[str] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
ZeroShotClassificationOutputElement,
ZeroShotImageClassificationOutputElement,
)
from huggingface_hub.inference._providers import HFInferenceTask, get_provider_helper
from huggingface_hub.inference._providers import PROVIDER_T, HFInferenceTask, get_provider_helper
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
from huggingface_hub.utils._deprecation import _deprecate_arguments, _deprecate_method

Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(
self,
model: Optional[str] = None,
*,
provider: Optional[str] = None,
provider: Optional[PROVIDER_T] = None,
token: Optional[str] = None,
timeout: Optional[float] = None,
headers: Optional[Dict[str, str]] = None,
Expand Down
14 changes: 11 additions & 3 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Literal

from .._common import TaskProviderHelper
from .fal_ai import FalAIAutomaticSpeechRecognitionTask, FalAITextToImageTask, FalAITextToVideoTask
Expand All @@ -8,7 +8,15 @@
from .together import TogetherTextGenerationTask, TogetherTextToImageTask


PROVIDERS: Dict[str, Dict[str, TaskProviderHelper]] = {
PROVIDER_T = Literal[
"fal-ai",
"hf-inference",
"replicate",
"sambanova",
"together",
]

PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = {
"fal-ai": {
"text-to-image": FalAITextToImageTask(),
"automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(),
Expand Down Expand Up @@ -58,7 +66,7 @@
}


def get_provider_helper(provider: str, task: str) -> TaskProviderHelper:
def get_provider_helper(provider: PROVIDER_T, task: str) -> TaskProviderHelper:
"""Get provider helper instance by name and task.
Args:
Expand Down

0 comments on commit d2cab33

Please sign in to comment.