From 20b9d72ebb1b962cb34a657cbdad15d003931c6e Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 12 Dec 2024 23:53:14 -0800 Subject: [PATCH] Pass value of `HF_TOKEN` environment variable when loading models with `gr.load` (#10092) * verbose * add changeset * changes * maybe fix * add changeset * changes * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/tall-maps-pump.md | 5 +++++ gradio/components/chatbot.py | 6 ++---- gradio/external.py | 8 +++++++- 3 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 .changeset/tall-maps-pump.md diff --git a/.changeset/tall-maps-pump.md b/.changeset/tall-maps-pump.md new file mode 100644 index 0000000000000..e3028516ad542 --- /dev/null +++ b/.changeset/tall-maps-pump.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Pass value of `HF_TOKEN` environment variable when loading models with `gr.load` diff --git a/gradio/components/chatbot.py b/gradio/components/chatbot.py index d224e6243aeb1..267073a503105 100644 --- a/gradio/components/chatbot.py +++ b/gradio/components/chatbot.py @@ -422,9 +422,7 @@ def _preprocess_messages_tuples( def preprocess( self, payload: ChatbotDataTuples | ChatbotDataMessages | None, - ) -> ( - list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict] | None - ): + ) -> list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict]: """ Parameters: payload: data as a ChatbotData object @@ -432,7 +430,7 @@ def preprocess( If type is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If type is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports. """ if payload is None: - return payload + return [] if self.type == "tuples": if not isinstance(payload, ChatbotDataTuples): raise Error("Data incompatible with the tuples format") diff --git a/gradio/external.py b/gradio/external.py index d4a73a41f0dfd..7cc01cc68fd15 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -49,7 +49,7 @@ def load( Parameters: name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided. src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`. - token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens. + token: optional token that is passed as the second parameter to the `src` function. If not explicitly provided, will use the HF_TOKEN environment variable or fallback to the locally-saved HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find your HF tokens here: https://huggingface.co/settings/tokens. accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space. kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor. Returns: @@ -78,6 +78,12 @@ def load( raise ValueError( "The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app." ) + if ( + token is None + and src in ["models", "huggingface"] + and os.environ.get("HF_TOKEN") is not None + ): + token = os.environ.get("HF_TOKEN") if not accept_token: if isinstance(src, Callable):