Skip to content

Commit

Permalink
Pass value of HF_TOKEN environment variable when loading models wit…
Browse files Browse the repository at this point in the history
…h `gr.load` (#10092)

* verbose

* add changeset

* changes

* maybe fix

* add changeset

* changes

* add changeset

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Dec 13, 2024
1 parent 424365b commit 20b9d72
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .changeset/tall-maps-pump.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Pass value of `HF_TOKEN` environment variable when loading models with `gr.load`
6 changes: 2 additions & 4 deletions gradio/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,15 @@ def _preprocess_messages_tuples(
def preprocess(
self,
payload: ChatbotDataTuples | ChatbotDataMessages | None,
) -> (
list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict] | None
):
) -> list[list[str | tuple[str] | tuple[str, str] | None]] | list[MessageDict]:
"""
Parameters:
payload: data as a ChatbotData object
Returns:
If type is 'tuples', passes the messages in the chatbot as a `list[list[str | None | tuple]]`, i.e. a list of lists. The inner list has 2 elements: the user message and the response message. Each message can be (1) a string in valid Markdown, (2) a tuple if there are displayed files: (a filepath or URL to a file, [optional string alt text]), or (3) None, if there is no message displayed. If type is 'messages', passes the value as a list of dictionaries with 'role' and 'content' keys. The `content` key's value supports everything the `tuples` format supports.
"""
if payload is None:
return payload
return []
if self.type == "tuples":
if not isinstance(payload, ChatbotDataTuples):
raise Error("Data incompatible with the tuples format")
Expand Down
8 changes: 7 additions & 1 deletion gradio/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load(
Parameters:
name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided.
src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`.
token: optional token that is passed as the second parameter to the `src` function. For Hugging Face repos, uses the local HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find HF tokens here: https://huggingface.co/settings/tokens.
token: optional token that is passed as the second parameter to the `src` function. If not explicitly provided, will use the HF_TOKEN environment variable or fallback to the locally-saved HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find your HF tokens here: https://huggingface.co/settings/tokens.
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "Spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
Returns:
Expand Down Expand Up @@ -78,6 +78,12 @@ def load(
raise ValueError(
"The `src` parameter must be one of 'huggingface', 'models', 'spaces', or a function that accepts a model name (and optionally, a token), and returns a Gradio app."
)
if (
token is None
and src in ["models", "huggingface"]
and os.environ.get("HF_TOKEN") is not None
):
token = os.environ.get("HF_TOKEN")

if not accept_token:
if isinstance(src, Callable):
Expand Down

0 comments on commit 20b9d72

Please sign in to comment.