Support gr.LoginButton for gr.load() #10577

Open · wants to merge 21 commits into base: main
5 changes: 5 additions & 0 deletions .changeset/four-wasps-hug.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Support `gr.LoginButton` for `gr.load()`
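
Below is a minimal usage sketch of the feature this changeset describes (not taken from the PR itself): passing a `gr.LoginButton` instance as `accept_token` so that `gr.load()` builds an app that calls the model with the signed-in user's Hugging Face token. The model id is the one used as an example in the docstring; OAuth sign-in only works where Gradio's OAuth is available, e.g. on Spaces.

```python
import gradio as gr

# Hedged sketch: gr.load() renders the LoginButton inside the Blocks app it
# creates, and uses the signed-in user's OAuth token instead of `token`.
demo = gr.load(
    "models/google/vit-base-patch16-224",
    accept_token=gr.LoginButton(),
)
demo.launch()
```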
85 changes: 52 additions & 33 deletions gradio/external.py
@@ -20,7 +20,7 @@
from gradio_client.utils import encode_url_or_file_to_base64
from packaging import version

import gradio
import gradio as gr
from gradio import components, external_utils, utils
from gradio.components.multimodal_textbox import MultimodalValue
from gradio.context import Context
@@ -36,6 +36,7 @@
from gradio.blocks import Blocks
from gradio.chat_interface import ChatInterface
from gradio.components.chatbot import MessageDict
from gradio.components.login_button import LoginButton
from gradio.interface import Interface


@@ -47,7 +48,7 @@ def load(
| None = None,
token: str | None = None,
hf_token: str | None = None,
accept_token: bool = False,
accept_token: bool | LoginButton = False,
provider: PROVIDER_T | None = None,
**kwargs,
) -> Blocks:
@@ -57,7 +58,7 @@ def load(
name: the name of the model (e.g. "google/vit-base-patch16-224") or Space (e.g. "flax-community/spanish-gpt2"). This is the first parameter passed into the `src` function. Can also be formatted as {src}/{repo name} (e.g. "models/google/vit-base-patch16-224") if `src` is not provided.
src: function that accepts a string model `name` and a string or None `token` and returns a Gradio app. Alternatively, this parameter takes one of two strings for convenience: "models" (for loading a Hugging Face model through the Inference API) or "spaces" (for loading a Hugging Face Space). If None, uses the prefix of the `name` parameter to determine `src`.
token: optional token that is passed as the second parameter to the `src` function. If not explicitly provided, will use the HF_TOKEN environment variable or fallback to the locally-saved HF token when loading models but not Spaces (when loading Spaces, only provide a token if you are loading a trusted private Space as the token can be read by the Space you are loading). Find your HF tokens here: https://huggingface.co/settings/tokens.
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space.
accept_token: if True, a Textbox component is first rendered to allow the user to provide a token, which will be used instead of the `token` parameter when calling the loaded model or Space. Can also be an instance of `gr.LoginButton`, rendered in the same Blocks scope, which allows the user to log in with their Hugging Face account; that account's token will then be used instead of the `token` parameter when calling the loaded model or Space.
kwargs: additional keyword parameters to pass into the `src` function. If `src` is "models" or "spaces", these parameters are passed into the `gr.Interface` or `gr.ChatInterface` constructor.
provider: the name of the third-party (non-Hugging Face) provider to use for model inference (e.g. "replicate", "sambanova", "fal-ai", etc.). Should be one of the providers supported by `huggingface_hub.InferenceClient`. This parameter is only used when `src` is "models".
Returns:
@@ -93,15 +94,31 @@ def load(
):
token = os.environ.get("HF_TOKEN")

if isinstance(src, Callable):
return src(name, token, **kwargs)

if not accept_token:
if isinstance(src, Callable):
return src(name, token, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token, provider=provider, **kwargs
)
else:
import gradio as gr
elif isinstance(accept_token, gr.LoginButton):
with gr.Blocks(fill_height=True) as demo:
if not accept_token.is_rendered:
accept_token.render()

@gr.render(triggers=[demo.load])
def create_blocks(oauth_token: gr.OAuthToken | None):
token_value = None if oauth_token is None else oauth_token.token
return load_blocks_from_huggingface(
name=name,
src=src,
hf_token=token_value,
provider=provider,
**kwargs,
)

return demo
else:
with gr.Blocks(fill_height=True) as demo:
with gr.Accordion("Enter your token and press enter") as accordion:
textbox = gr.Textbox(
@@ -140,10 +157,12 @@ def load_token(token_value):

@gr.render(inputs=[textbox], triggers=[textbox.submit])
def create(token_value):
if isinstance(src, Callable):
return src(name, token_value, **kwargs)
return load_blocks_from_huggingface(
name=name, src=src, hf_token=token_value, **kwargs
name=name,
src=src,
hf_token=token_value,
provider=provider,
**kwargs,
)

return demo
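
The `LoginButton` branch above leans on a pattern worth calling out: a function decorated with `@gr.render(triggers=[demo.load])` whose parameter is annotated `gr.OAuthToken | None` has the visitor's OAuth token injected at page load. A standalone sketch of that pattern (the Markdown messages are illustrative, not from this PR):

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.LoginButton()

    # As in create_blocks() above, Gradio injects the visitor's OAuth token
    # (or None when signed out) based on the gr.OAuthToken annotation.
    @gr.render(triggers=[demo.load])
    def show_status(oauth_token: gr.OAuthToken | None):
        if oauth_token is None:
            gr.Markdown("Signed out: the fallback `token` would be used.")
        else:
            gr.Markdown("Signed in: your account's token will be used.")

demo.launch()
```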
@@ -152,43 +171,40 @@ def create(token_value):
def load_blocks_from_huggingface(
name: str,
src: str,
hf_token: str | Literal[False] | None = None,
hf_token: str | None = None,
alias: str | None = None,
provider: PROVIDER_T | None = None,
**kwargs,
) -> Blocks:
"""Creates and returns a Blocks instance from a Hugging Face model or Space repo."""
factory_methods: dict[str, Callable] = {
# for each repo type, we have a method that returns the Interface given the model name & optionally an hf_token
"huggingface": from_model,
"models": from_model,
"spaces": from_spaces,
}
if hf_token is not None and hf_token is not False:
if hf_token is not None:
if Context.hf_token is not None and Context.hf_token != hf_token:
warnings.warn(
"""You are loading a model/Space with a different access token than the one you used to load a previous model/Space. This is not recommended, as it may cause unexpected behavior."""
)
Context.hf_token = hf_token

if src == "spaces" and hf_token is None:
hf_token = False # Since Spaces can read the token, we don't want to pass it in unless the user explicitly provides it
blocks: gradio.Blocks = factory_methods[src](
name, hf_token=hf_token, alias=alias, provider=provider, **kwargs
)
if src == "spaces":
# Spaces can read the token, so we don't want to pass it in unless the user explicitly provides it
token = False if hf_token is None else hf_token
blocks = from_spaces(
name, hf_token=token, alias=alias, provider=provider, **kwargs
)
else:
blocks = from_model(
name, hf_token=hf_token, alias=alias, provider=provider, **kwargs
)
return blocks


def from_model(
model_name: str,
hf_token: str | Literal[False] | None,
hf_token: str | None,
alias: str | None,
provider: PROVIDER_T | None = None,
**kwargs,
) -> Blocks:
headers = {"X-Wait-For-Model": "true"}
if hf_token is False:
hf_token = None
client = huggingface_hub.InferenceClient(
model=model_name, headers=headers, token=hf_token, provider=provider
)
@@ -288,6 +304,7 @@ def custom_post_binary(data):
"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
]
]
postprocess = lambda x: x.summary_text # noqa: E731
fn = client.summarization
# Example: distilbert-base-uncased-finetuned-sst-2-english
elif p == "text-classification":
@@ -451,6 +468,7 @@ def query_huggingface_inference_endpoints(*data):
data = preprocess(*data)
try:
data = fn(*data) # type: ignore
print("data after fn", data)
except huggingface_hub.utils.HfHubHTTPError as e: # type: ignore
if "429" in str(e):
raise TooManyRequestsError() from e
@@ -466,10 +484,11 @@ def query_huggingface_inference_endpoints(*data):
"outputs": outputs,
"title": model_name,
"examples": examples,
"cache_examples": False,
}

kwargs = dict(interface_info, **kwargs)
interface = gradio.Interface(**kwargs)
interface = gr.Interface(**kwargs)
return interface
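
For reference, here is what the pipeline mapping above produces end to end. The repo id below is a hypothetical example of a summarization model (it does not appear in this diff): loading it routes through the summarization branch, so `fn` wraps `client.summarization` and each result is unpacked by the `x.summary_text` postprocess lambda.

```python
import gradio as gr

# Hypothetical summarization repo; any summarization model routes through
# the same branch and yields an Interface built with cache_examples=False.
demo = gr.load("models/facebook/bart-large-cnn")
demo.launch()
```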


@@ -565,7 +584,7 @@ def from_spaces_blocks(space: str, hf_token: str | None | Literal[False]) -> Blocks:
predict_fns.append(endpoint.make_end_to_end_fn(helper))
else:
predict_fns.append(None)
return gradio.Blocks.from_config(client.config, predict_fns, client.src) # type: ignore
return gr.Blocks.from_config(client.config, predict_fns, client.src) # type: ignore


def from_spaces_interface(
@@ -610,7 +629,7 @@ def fn(*data):

kwargs = dict(config, **kwargs)
kwargs["_api_mode"] = True
interface = gradio.Interface(**kwargs)
interface = gr.Interface(**kwargs)
return interface


@@ -784,7 +803,7 @@ def load_chat(
raise ImportError(
"To use OpenAI API Client, you must install the `openai` package. You can install it with `pip install openai`."
) from e
from gradio.chat_interface import ChatInterface

client = OpenAI(api_key=token, base_url=base_url)
start_message = (
@@ -796,7 +815,7 @@ def open_api(message: str | MultimodalValue, history: list | None) -> str | None:
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
conversation = format_conversation(history, message)
conversation = format_conversation(history, message) # type: ignore
return (
client.chat.completions.create(
model=model,
@@ -812,7 +831,7 @@ def open_api_stream(
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
conversation = format_conversation(history, message)
conversation = format_conversation(history, message) # type: ignore
stream = client.chat.completions.create(
model=model,
messages=conversation, # type: ignore
@@ -839,7 +858,7 @@ def open_api_stream(
open_api_stream if streaming else open_api,
type="messages",
multimodal=bool(file_types),
textbox=gradio.MultimodalTextbox(file_types=supported_extensions)
textbox=gr.MultimodalTextbox(file_types=supported_extensions)
if file_types
else None,
**kwargs,
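
Since the `load_chat` hunk is shown only partially above, here is a hedged usage sketch against an OpenAI-compatible server. The URL, model name, and key are placeholders; the parameter names follow the identifiers visible in the diff (`base_url`, `model`, `token`), but treat the exact signature as an assumption.

```python
import gradio as gr

# Placeholders: point base_url at any OpenAI-compatible endpoint
# (e.g. a local vLLM or llama.cpp server) and name its model.
demo = gr.load_chat(
    base_url="http://localhost:8000/v1",
    model="my-model",
    token="sk-placeholder",
)
demo.launch()
```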