diff --git a/src/huggingface_hub/inference/_common.py b/src/huggingface_hub/inference/_common.py index c312449187..5973083a71 100644 --- a/src/huggingface_hub/inference/_common.py +++ b/src/huggingface_hub/inference/_common.py @@ -48,9 +48,7 @@ is_numpy_available, is_pillow_available, ) -from ._text_generation import ( - TextGenerationStreamResponse, -) +from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error if TYPE_CHECKING: @@ -275,7 +273,10 @@ def _stream_text_generation_response( if payload.startswith("data:"): # Decode payload json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) - # Parse payload + # Either an error is being returned + if json_payload.get("error") is not None: + raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) + # Or parse token payload output = TextGenerationStreamResponse(**json_payload) yield output.token.text if not details else output @@ -295,7 +296,10 @@ async def _async_stream_text_generation_response( if payload.startswith("data:"): # Decode payload json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) - # Parse payload + # Either an error is being returned + if json_payload.get("error") is not None: + raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) + # Or parse token payload output = TextGenerationStreamResponse(**json_payload) yield output.token.text if not details else output diff --git a/src/huggingface_hub/inference/_text_generation.py b/src/huggingface_hub/inference/_text_generation.py index aba07834bb..a67c127b87 100644 --- a/src/huggingface_hub/inference/_text_generation.py +++ b/src/huggingface_hub/inference/_text_generation.py @@ -447,6 +447,10 @@ class IncompleteGenerationError(TextGenerationError): pass +class UnknownError(TextGenerationError): + pass + + def raise_text_generation_error(http_error: HTTPError) -> NoReturn: """ Try to parse text-generation-inference error message and raise
HTTPError in any case. @@ -460,21 +464,27 @@ def raise_text_generation_error(http_error: HTTPError) -> NoReturn: try: # Hacky way to retrieve payload in case of aiohttp error payload = getattr(http_error, "response_error_payload", None) or http_error.response.json() - message = payload.get("error") + error = payload.get("error") error_type = payload.get("error_type") except Exception: # no payload raise http_error # If error_type => more information than `hf_raise_for_status` if error_type is not None: - if error_type == "generation": - raise GenerationError(message) from http_error # type: ignore - if error_type == "incomplete_generation": - raise IncompleteGenerationError(message) from http_error # type: ignore - if error_type == "overloaded": - raise OverloadedError(message) from http_error # type: ignore - if error_type == "validation": - raise ValidationError(message) from http_error # type: ignore + exception = _parse_text_generation_error(error, error_type) + raise exception from http_error # Otherwise, fallback to default error raise http_error + + +def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError: + if error_type == "generation": + return GenerationError(error) # type: ignore + if error_type == "incomplete_generation": + return IncompleteGenerationError(error) # type: ignore + if error_type == "overloaded": + return OverloadedError(error) # type: ignore + if error_type == "validation": + return ValidationError(error) # type: ignore + return UnknownError(error) # type: ignore