diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index de564fc48a..bd2f7643bc 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -18,6 +18,8 @@
       title: Repository
     - local: guides/search
       title: Search
+    - local: guides/hf_file_system
+      title: HfFileSystem
     - local: guides/inference
       title: Inference
     - local: guides/community
@@ -30,6 +32,8 @@
       title: Manage your Space
     - local: guides/integrations
       title: Integrate a library
+    - local: guides/webhooks_server
+      title: Webhooks server
 - title: "Conceptual guides"
   sections:
     - local: concepts/git_vs_http
@@ -52,6 +56,8 @@
       title: Mixins & serialization methods
     - local: package_reference/inference_api
       title: Inference API
+    - local: package_reference/hf_file_system
+      title: HfFileSystem
     - local: package_reference/utilities
       title: Utilities
     - local: package_reference/community
@@ -62,3 +68,5 @@
       title: Repo Cards and Repo Card Data
     - local: package_reference/space_runtime
       title: Space runtime
+    - local: package_reference/webhooks_server
+      title: Webhooks server
\ No newline at end of file
diff --git a/docs/source/guides/hf_file_system.mdx b/docs/source/guides/hf_file_system.mdx
new file mode 100644
index 0000000000..7d0d5581a3
--- /dev/null
+++ b/docs/source/guides/hf_file_system.mdx
@@ -0,0 +1,105 @@
# Interact with the Hub through the Filesystem API

In addition to [`HfApi`], the `huggingface_hub` library provides [`HfFileSystem`], a pythonic [fsspec-compatible](https://filesystem-spec.readthedocs.io/en/latest/) file interface to the Hugging Face Hub. [`HfFileSystem`] builds on top of [`HfApi`] and offers typical filesystem-style operations like `cp`, `mv`, `ls`, `du`, `glob`, `get_file`, and `put_file`.

## Usage

```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem()

>>> # List all files in a directory
>>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False)
['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']

>>> # List all ".csv" files in a repo
>>> fs.glob("datasets/my-username/my-dataset-repo/**.csv")
['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']

>>> # Read a remote file
>>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f:
...     train_data = f.readlines()

>>> # Read the content of a remote file as a string
>>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev")

>>> # Write a remote file
>>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f:
...     f.write("text,label\n")
...     f.write("Fantastic movie!,good\n")
```

The optional `revision` argument can be passed to run an operation against a specific revision, such as a branch, a tag name, or a commit hash.

Unlike Python's built-in `open`, `fsspec`'s `open` defaults to binary mode, `"rb"`. This means you must explicitly set the mode to `"r"` for reading and `"w"` for writing in text mode. Appending to a file (modes `"a"` and `"ab"`) is not supported yet.

## Integrations

The [`HfFileSystem`] can be used with any library that integrates `fsspec`, provided the URL follows the scheme:

```
hf://[<repo_type_prefix>]<repo_id>[@<revision>]/<path/in/repo>
```

The `repo_type_prefix` is `datasets/` for datasets, `spaces/` for spaces, and models don't need a prefix in the URL.
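For example, such a URL can be handed directly to any fsspec-aware function, since `huggingface_hub` registers the `hf://` protocol with `fsspec`. A small sketch (the repo and file names below are made up):

```python
>>> import fsspec

>>> # "datasets/" prefix, then the repo id, "@dev" revision, and the path inside the repo
>>> with fsspec.open("hf://datasets/my-username/my-dataset-repo@dev/data/train.csv", "r") as f:
...     first_line = f.readline()
```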
Some interesting integrations where [`HfFileSystem`] simplifies interacting with the Hub are listed below:

* Reading/writing a [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame from/to a Hub repository:

  ```python
  >>> import pandas as pd

  >>> # Read a remote CSV file into a dataframe
  >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv")

  >>> # Write a dataframe to a remote CSV file
  >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv")
  ```

  The same workflow can also be used for [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) and [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) DataFrames.

* Querying (remote) Hub files with [DuckDB](https://duckdb.org/docs/guides/python/filesystems):

  ```python
  >>> from huggingface_hub import HfFileSystem
  >>> import duckdb

  >>> fs = HfFileSystem()
  >>> duckdb.register_filesystem(fs)
  >>> # Query a remote file and get the result back as a dataframe
  >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet"
  >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df()
  ```

* Using the Hub as an array store with [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec):

  ```python
  >>> import numpy as np
  >>> import zarr

  >>> embeddings = np.random.randn(50000, 1000).astype("float32")

  >>> # Write an array to a repo
  >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root:
  ...     foo = root.create_group("embeddings")
  ...     foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4')
  ...     foobar[:] = embeddings

  >>> # Read an array from a repo
  >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root:
  ...     first_row = root["embeddings/experiment_0"][0]
  ```

## Authentication

In many cases, you must be logged in with a Hugging Face account to interact with the Hub. Refer to the [Login](../quick-start#login) section of the documentation to learn more about authentication methods on the Hub.

It is also possible to log in programmatically by passing your `token` as an argument to [`HfFileSystem`]:

```python
>>> from huggingface_hub import HfFileSystem
>>> fs = HfFileSystem(token=token)
```

If you log in this way, be careful not to accidentally leak the token when sharing your source code!
diff --git a/docs/source/guides/overview.mdx b/docs/source/guides/overview.mdx
index 96820925a5..dcda8154bd 100644
--- a/docs/source/guides/overview.mdx
+++ b/docs/source/guides/overview.mdx
@@ -42,6 +42,15 @@ Take a look at these guides to learn how to use huggingface_hub to solve real-world problems:
<a href="./hf_file_system">
  <div>
    HfFileSystem
  </div>
  <p>
    How to interact with the Hub through a convenient interface that mimics Python's file interface?
  </p>
</a>
\ No newline at end of file
diff --git a/docs/source/guides/webhooks_server.mdx b/docs/source/guides/webhooks_server.mdx
new file mode 100644
index 0000000000..55219281eb
--- /dev/null
+++ b/docs/source/guides/webhooks_server.mdx
@@ -0,0 +1,195 @@
# Webhooks Server

Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on specific repos or to all repos belonging to particular users/organizations you're interested in following. This guide will explain how to leverage `huggingface_hub` to create a server listening to webhooks and deploy it to a Space. It assumes you are familiar with the concept of webhooks on the Hugging Face Hub. To learn more about webhooks themselves, you can read this [guide](https://huggingface.co/docs/hub/webhooks) first.

The base class that we will use in this guide is [`WebhooksServer`]. It is a class for easily configuring a server that can receive webhooks from the Hugging Face Hub. The server is based on a [Gradio](https://gradio.app/) app. It has a UI to display instructions for you or your users and an API to listen to webhooks.

<Tip>

To see a running example of a webhook server, check out the [Spaces CI Bot](https://huggingface.co/spaces/spaces-ci-bot/webhook). It is a Space that launches ephemeral environments when a PR is opened on a Space.

</Tip>

<Tip warning={true}>

This is an [experimental feature](../package_reference/environment_variables#hfhubdisableexperimentalwarning). This means that we are still working on improving the API. Breaking changes might be introduced in the future without prior notice. Make sure to pin the version of `huggingface_hub` in your requirements.

</Tip>

## Create an endpoint

Implementing a webhook endpoint is as simple as decorating a function. Let's see a first example to explain the main concepts:

```python
# app.py
from huggingface_hub import webhook_endpoint, WebhookPayload

@webhook_endpoint
async def trigger_training(payload: WebhookPayload) -> None:
    if payload.repo.type == "dataset" and payload.event.action == "update":
        # Trigger a training job if a dataset is updated
        ...
```

Save this snippet in a file called `app.py` and run it with `python app.py`. You should see a message like this:

```text
Webhook secret is not defined. This means your webhook endpoints will be open to everyone.
To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization:
	`app = WebhooksServer(webhook_secret='my_secret', ...)`
For more details about webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret.
Running on local URL: http://127.0.0.1:7860
Running on public URL: https://1fadb0f52d8bf825fc.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces

Webhooks are correctly setup and ready to use:
  - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training
Go to https://huggingface.co/settings/webhooks to setup your webhooks.
```

Good job! You just launched a webhook server! Let's break down what happened exactly:

1. By decorating a function with [`webhook_endpoint`], a [`WebhooksServer`] object has been created in the background. As you can see, this server is a Gradio app running on http://127.0.0.1:7860. If you open this URL in your browser, you will see a landing page with instructions about the registered webhooks.
2. A Gradio app is a FastAPI server under the hood.
A new POST route `/webhooks/trigger_training` has been added to it. This is the route that will listen to webhooks and run the `trigger_training` function when triggered. FastAPI will automatically parse the payload and pass it to the function as a [`WebhookPayload`] object. This is a `pydantic` object that contains all the information about the event that triggered the webhook.
3. The Gradio app also opened a tunnel to receive requests from the internet. This is the interesting part: you can configure a Webhook on https://huggingface.co/settings/webhooks pointing to your local machine. This is useful for debugging your webhook server and quickly iterating before deploying it to a Space.
4. Finally, the logs also tell you that your server is currently not secured by a secret. This is not problematic for local debugging but is something to keep in mind for later.

<Tip>

By default, the server is started at the end of your script. If you are running it in a notebook, you can start the server manually by calling `decorated_function.run()`. Since a single server is used, you only have to start the server once even if you have multiple endpoints.

</Tip>

## Configure a Webhook

Now that you have a webhook server running, you want to configure a Webhook to start receiving messages. Go to https://huggingface.co/settings/webhooks, click on "Add a new webhook" and configure your Webhook. Set the target repositories you want to watch and the Webhook URL, here `https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training`.
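Before relying on the Hub, you can also simulate a delivery yourself to check that the endpoint behaves as expected. The sketch below assumes the server is reachable at the local URL printed in the logs above and that no webhook secret is set; the payload contains only the fields required by [`WebhookPayload`], and every value in it is a made-up placeholder:

```python
import requests

# Hypothetical local URL taken from the server logs; adjust to your own.
url = "http://127.0.0.1:7860/webhooks/trigger_training"

# Minimal hand-written payload: only the required fields of `WebhookPayload`.
payload = {
    "event": {"action": "update", "scope": "repo.content"},
    "repo": {
        "id": "000000000000000000000000",  # placeholder ids
        "owner": {"id": "000000000000000000000000"},
        "name": "my-username/my-dataset-repo",
        "private": False,
        "type": "dataset",
        "url": {"web": "https://huggingface.co/datasets/my-username/my-dataset-repo"},
    },
    "webhook": {"id": "000000000000000000000000", "version": 3},
}

response = requests.post(url, json=payload)
print(response.status_code)  # expect 200 if the endpoint ran
```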
And that's it! You can now trigger that webhook by updating the target repository (e.g. push a commit). Check the Activity tab of your Webhook to see the events that have been triggered. Now that you have a working setup, you can test it and quickly iterate. If you modify your code and restart the server, your public URL might change. Make sure to update the webhook configuration on the Hub if needed.

## Deploy to a Space

Now that you have a working webhook server, the goal is to deploy it to a Space. Go to https://huggingface.co/new-space to create a Space. Give it a name, select the Gradio SDK and click on "Create Space". Upload your code to the Space in a file called `app.py`. Your Space will start automatically! For more details about Spaces, please refer to this [guide](https://huggingface.co/docs/hub/spaces-overview).

Your webhook server is now running on a public Space. In most cases, you will want to secure it with a secret. Go to your Space settings > Section "Repository secrets" > "Add a secret". Set the `WEBHOOK_SECRET` environment variable to the value of your choice. Go back to the [Webhooks settings](https://huggingface.co/settings/webhooks) and set the secret in the webhook configuration. Now, only requests with the correct secret will be accepted by your server.

And this is it! Your Space is now ready to receive webhooks from the Hub. Please keep in mind that if you run the Space on free 'cpu-basic' hardware, it will be shut down after 48 hours of inactivity. If you need a permanent Space, you should consider switching to [upgraded hardware](https://huggingface.co/docs/hub/spaces-gpus#hardware-specs).

## Advanced usage

The guide above explained the quickest way to set up a [`WebhooksServer`]. In this section, we will see how to customize it further.

### Multiple endpoints

You can register multiple endpoints on the same server. For example, you might want to have one endpoint to trigger a training job and another one to trigger a model evaluation. You can do this by adding multiple `@webhook_endpoint` decorators:

```python
# app.py
from huggingface_hub import webhook_endpoint, WebhookPayload

@webhook_endpoint
async def trigger_training(payload: WebhookPayload) -> None:
    if payload.repo.type == "dataset" and payload.event.action == "update":
        # Trigger a training job if a dataset is updated
        ...

@webhook_endpoint
async def trigger_evaluation(payload: WebhookPayload) -> None:
    if payload.repo.type == "model" and payload.event.action == "update":
        # Trigger an evaluation job if a model is updated
        ...
```

This will create two endpoints:

```text
(...)
Webhooks are correctly setup and ready to use:
  - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_training
  - POST https://1fadb0f52d8bf825fc.gradio.live/webhooks/trigger_evaluation
```

### Custom server

To get more flexibility, you can also create a [`WebhooksServer`] object directly. This is useful if you want to customize the landing page of your server. You can do this by passing a [Gradio UI](https://gradio.app/docs/#blocks) that will overwrite the default one. For example, you can add instructions for your users or add a form to manually trigger the webhooks. When creating a [`WebhooksServer`], you can register new webhooks using the [`~WebhooksServer.add_webhook`] decorator.
Here is a complete example:

```python
import gradio as gr
from fastapi import Request
from huggingface_hub import WebhooksServer, WebhookPayload

# 1. Define UI
with gr.Blocks() as ui:
    ...

# 2. Create WebhooksServer with custom UI and secret
app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

# 3. Register webhook with explicit name
@app.add_webhook("/say_hello")
async def hello(payload: WebhookPayload):
    return {"message": "hello"}

# 4. Register webhook with implicit name
@app.add_webhook
async def goodbye(payload: WebhookPayload):
    return {"message": "goodbye"}

# 5. Start server (optional)
app.run()
```

1. We define a custom UI using Gradio blocks. This UI will be displayed on the landing page of the server.
2. We create a [`WebhooksServer`] object with a custom UI and a secret. The secret is optional and can be set with the `WEBHOOK_SECRET` environment variable.
3. We register a webhook with an explicit name. This will create an endpoint at `/webhooks/say_hello`.
4. We register a webhook with an implicit name. This will create an endpoint at `/webhooks/goodbye`.
5. We start the server. This is optional as your server will automatically be started at the end of the script.
\ No newline at end of file
diff --git a/docs/source/package_reference/environment_variables.mdx b/docs/source/package_reference/environment_variables.mdx
index eb0a0bd4e2..4a99b2743d 100644
--- a/docs/source/package_reference/environment_variables.mdx
+++ b/docs/source/package_reference/environment_variables.mdx
@@ -115,6 +115,14 @@ to disable this warning.

 For more details, see [cache limitations](../guides/manage-cache#limitations).

+### HF_HUB_DISABLE_EXPERIMENTAL_WARNING
+
+Some features of `huggingface_hub` are experimental. This means you can use them but we do not guarantee they will be
+maintained in the future. In particular, we might update the API or behavior of such features without any deprecation
+cycle. A warning message is triggered when using an experimental feature to warn you about it. If you're comfortable
+debugging any potential issues when using an experimental feature, you can set `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1`
+to disable the warning.
+
+If you are using an experimental feature, please let us know! Your feedback can help us design and improve it.
+
 ### HF_HUB_DISABLE_TELEMETRY

 By default, some data is collected by HF libraries (`transformers`, `datasets`, `gradio`,..) to monitor usage, debug issues and help prioritize features.
diff --git a/docs/source/package_reference/hf_file_system.mdx b/docs/source/package_reference/hf_file_system.mdx
new file mode 100644
index 0000000000..17c9258d75
--- /dev/null
+++ b/docs/source/package_reference/hf_file_system.mdx
@@ -0,0 +1,12 @@
# Filesystem API

The `HfFileSystem` class provides a pythonic file interface to the Hugging Face Hub based on [`fsspec`](https://filesystem-spec.readthedocs.io/en/latest/).

## HfFileSystem

`HfFileSystem` is based on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/), so it is compatible with most of the APIs that it offers. For more details, check out [our guide](../guides/hf_file_system) and fsspec's [API Reference](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem).
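As a quick illustration, generic methods inherited from fsspec's `AbstractFileSystem` work out of the box. A small sketch (the repo name and returned values below are made up):

```python
>>> from huggingface_hub import HfFileSystem

>>> fs = HfFileSystem()
>>> fs.exists("datasets/my-username/my-dataset-repo/data/train.csv")  # inherited from fsspec
True
>>> fs.du("datasets/my-username/my-dataset-repo", total=True)  # total size in bytes
42
```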
[[autodoc]] HfFileSystem
    - __init__
    - resolve_path
    - ls
diff --git a/docs/source/package_reference/webhooks_server.mdx b/docs/source/package_reference/webhooks_server.mdx
new file mode 100644
index 0000000000..9e747eff04
--- /dev/null
+++ b/docs/source/package_reference/webhooks_server.mdx
@@ -0,0 +1,80 @@
# Webhooks Server

Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on specific repos or to all repos belonging to particular users/organizations you're interested in following. To learn more about webhooks on the Hugging Face Hub, you can read the Webhooks [guide](https://huggingface.co/docs/hub/webhooks).

<Tip>

Check out this [guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your webhooks server and deploy it as a Space.

</Tip>

<Tip warning={true}>

This is an experimental feature. This means that we are still working on improving the API. Breaking changes might be introduced in the future without prior notice. Make sure to pin the version of `huggingface_hub` in your requirements. A warning is triggered when you use an experimental feature. You can disable it by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as an environment variable.

</Tip>

## Server

The server is a [Gradio](https://gradio.app/) app. It has a UI to display instructions for you or your users and an API to listen to webhooks. Implementing a webhook endpoint is as simple as decorating a function. You can then debug it by redirecting the Webhooks to your machine (using a Gradio tunnel) before deploying it to a Space.

### WebhooksServer

[[autodoc]] huggingface_hub.WebhooksServer

### @webhook_endpoint

[[autodoc]] huggingface_hub.webhook_endpoint

## Payload

[`WebhookPayload`] is the main data structure that contains the payload from Webhooks. This is a `pydantic` class which makes it very easy to use with FastAPI. If you pass it as a parameter to a webhook endpoint, it will be automatically validated and parsed as a Python object.

For more information about webhooks payload, you can refer to the Webhooks Payload [guide](https://huggingface.co/docs/hub/webhooks#webhook-payloads).
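For example, a typed endpoint can inspect nested fields directly. Optional parts of the payload (such as `discussion`) are `None` when the event does not include them. A short sketch (the endpoint name is arbitrary):

```python
from huggingface_hub import webhook_endpoint, WebhookPayload

@webhook_endpoint
async def on_activity(payload: WebhookPayload) -> None:
    # `discussion` is None for pure repo events (e.g. a commit push)
    if payload.discussion is not None and not payload.discussion.isPullRequest:
        print(f"New activity in discussion #{payload.discussion.num} of {payload.repo.name}")
```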
### WebhookPayload

[[autodoc]] huggingface_hub.WebhookPayload

### WebhookPayloadComment

[[autodoc]] huggingface_hub.WebhookPayloadComment

### WebhookPayloadDiscussion

[[autodoc]] huggingface_hub.WebhookPayloadDiscussion

### WebhookPayloadDiscussionChanges

[[autodoc]] huggingface_hub.WebhookPayloadDiscussionChanges

### WebhookPayloadEvent

[[autodoc]] huggingface_hub.WebhookPayloadEvent

### WebhookPayloadMovedTo

[[autodoc]] huggingface_hub.WebhookPayloadMovedTo

### WebhookPayloadRepo

[[autodoc]] huggingface_hub.WebhookPayloadRepo

### WebhookPayloadUrl

[[autodoc]] huggingface_hub.WebhookPayloadUrl

### WebhookPayloadWebhook

[[autodoc]] huggingface_hub.WebhookPayloadWebhook
diff --git a/setup.cfg b/setup.cfg
index 5d4938d997..9cc27b091c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -13,6 +13,7 @@ known_third_party =
     faiss-cpu
     fastprogress
     fire
+    fsspec
     fugashi
     git
     graphviz
diff --git a/setup.py b/setup.py
index 60cf5afbeb..8e1261a75c 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@ def get_version() -> str:
 install_requires = [
     "filelock",
+    "fsspec",
     "requests",
     "tqdm>=4.42.1",
     "pyyaml>=5.1",
@@ -40,6 +41,7 @@ def get_version() -> str:
 extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"]

+
 extras["testing"] = extras["cli"] + [
     "jedi",
     "Jinja2",
@@ -49,6 +51,7 @@ def get_version() -> str:
     "pytest-xdist",
     "soundfile",
     "Pillow",
+    "gradio",  # to test webhooks
 ]
 # Typing extra dependencies list is duplicated in `.pre-commit-config.yaml`
@@ -87,7 +90,10 @@ def get_version() -> str:
     package_dir={"": "src"},
     packages=find_packages("src"),
     extras_require=extras,
-    entry_points={"console_scripts": ["huggingface-cli=huggingface_hub.commands.huggingface_cli:main"]},
+    entry_points={
+        "console_scripts": ["huggingface-cli=huggingface_hub.commands.huggingface_cli:main"],
+        "fsspec.specs": "hf=huggingface_hub.HfFileSystem",
+    },
     python_requires=">=3.7.0",
     install_requires=install_requires,
     classifiers=[
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 527b1362a3..2045ae66d1 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -70,6 +70,21 @@
         "SpaceRuntime",
         "SpaceStage",
     ],
+    "_webhooks_payload": [
+        "WebhookPayload",
+        "WebhookPayloadComment",
+        "WebhookPayloadDiscussion",
+        "WebhookPayloadDiscussionChanges",
+        "WebhookPayloadEvent",
+        "WebhookPayloadMovedTo",
+        "WebhookPayloadRepo",
+        "WebhookPayloadUrl",
+        "WebhookPayloadWebhook",
+    ],
+    "_webhooks_server": [
+        "WebhooksServer",
+        "webhook_endpoint",
+    ],
     "community": [
         "Discussion",
         "DiscussionComment",
@@ -167,6 +182,11 @@
         "upload_folder",
         "whoami",
     ],
+    "hf_file_system": [
+        "HfFileSystem",
+        "HfFileSystemFile",
+        "HfFileSystemResolvedPath",
+    ],
     "hub_mixin": [
         "ModelHubMixin",
         "PyTorchModelHubMixin",
@@ -334,6 +354,21 @@ def __dir__():
         SpaceRuntime,  # noqa: F401
         SpaceStage,  # noqa: F401
     )
+    from ._webhooks_payload import (
+        WebhookPayload,  # noqa: F401
+        WebhookPayloadComment,  # noqa: F401
+        WebhookPayloadDiscussion,  # noqa: F401
+        WebhookPayloadDiscussionChanges,  # noqa: F401
+        WebhookPayloadEvent,  # noqa: F401
+        WebhookPayloadMovedTo,  # noqa: F401
+        WebhookPayloadRepo,  # noqa: F401
+        WebhookPayloadUrl,  # noqa: F401
+        WebhookPayloadWebhook,  # noqa: F401
+    )
+    from ._webhooks_server import (
+        WebhooksServer,  # noqa: F401
+        webhook_endpoint,  # noqa: F401
+    )
     from .community import (
         Discussion,  # noqa: F401
         DiscussionComment,  # noqa: F401
@@ -431,6 +466,11 @@ def __dir__():
upload_folder, # noqa: F401 whoami, # noqa: F401 ) + from .hf_file_system import ( + HfFileSystem, # noqa: F401 + HfFileSystemFile, # noqa: F401 + HfFileSystemResolvedPath, # noqa: F401 + ) from .hub_mixin import ( ModelHubMixin, # noqa: F401 PyTorchModelHubMixin, # noqa: F401 diff --git a/src/huggingface_hub/_login.py b/src/huggingface_hub/_login.py index b7693b307b..12228b32ae 100644 --- a/src/huggingface_hub/_login.py +++ b/src/huggingface_hub/_login.py @@ -22,6 +22,7 @@ from .hf_api import HfApi from .utils import ( HfFolder, + capture_output, is_google_colab, is_notebook, list_credential_helpers, @@ -180,7 +181,7 @@ def notebook_login() -> None: """ try: import ipywidgets.widgets as widgets # type: ignore - from IPython.display import clear_output, display # type: ignore + from IPython.display import display # type: ignore except ImportError: raise ImportError( "The `notebook_login` function can only be used in a notebook (Jupyter or" @@ -211,8 +212,16 @@ def login_token_event(t): add_to_git_credential = git_checkbox_widget.value # Erase token and clear value to make sure it's not saved in the notebook. token_widget.value = "" - clear_output() - _login(token, add_to_git_credential=add_to_git_credential) + # Hide inputs + login_token_widget.children = [widgets.Label("Connecting...")] + try: + with capture_output() as captured: + _login(token, add_to_git_credential=add_to_git_credential) + message = captured.getvalue() + except Exception as error: + message = str(error) + # Print result (success message or error) + login_token_widget.children = [widgets.Label(line) for line in message.split("\n") if line.strip()] token_finish_button.on_click(login_token_event) @@ -235,13 +244,13 @@ def _login(token: str, add_to_git_credential: bool) -> None: set_git_credential(token) print( "Your token has been saved in your configured git credential helpers" - f" ({','.join(list_credential_helpers())})." + + f" ({','.join(list_credential_helpers())})." ) else: print("Token has not been saved to git credential helper.") HfFolder.save_token(token) - print("Your token has been saved to", HfFolder.path_token) + print(f"Your token has been saved to {HfFolder.path_token}") print("Login successful") diff --git a/src/huggingface_hub/_webhooks_payload.py b/src/huggingface_hub/_webhooks_payload.py new file mode 100644 index 0000000000..43ed693f38 --- /dev/null +++ b/src/huggingface_hub/_webhooks_payload.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains data structures to parse the webhooks payload.""" +from typing import List, Optional + +from .utils import is_gradio_available +from .utils._typing import Literal + + +if not is_gradio_available(): + raise ImportError( + "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio` first." + ) + +from pydantic import BaseModel + + +# This is an adaptation of the ReportV3 interface implemented in moon-landing. 
V0, V1 and V2 have been ignored as they
# are not in use anymore. To be kept in sync when the format is updated in
# https://github.com/huggingface/moon-landing/blob/main/server/lib/HFWebhooks.ts (internal link).


WebhookEvent_T = Literal[
    "create",
    "delete",
    "move",
    "update",
]
RepoChangeEvent_T = Literal[
    "add",
    "move",
    "remove",
    "update",
]
RepoType_T = Literal[
    "dataset",
    "model",
    "space",
]
DiscussionStatus_T = Literal[
    "closed",
    "draft",
    "open",
    "merged",
]
SupportedWebhookVersion = Literal[3]


class ObjectId(BaseModel):
    id: str


class WebhookPayloadUrl(BaseModel):
    web: str
    api: Optional[str]


class WebhookPayloadMovedTo(BaseModel):
    name: str
    owner: ObjectId


class WebhookPayloadWebhook(ObjectId):
    version: SupportedWebhookVersion


class WebhookPayloadEvent(BaseModel):
    action: WebhookEvent_T
    scope: str


class WebhookPayloadDiscussionChanges(BaseModel):
    base: str
    mergeCommitId: Optional[str]


class WebhookPayloadComment(ObjectId):
    author: ObjectId
    hidden: bool
    content: Optional[str]
    url: WebhookPayloadUrl


class WebhookPayloadDiscussion(ObjectId):
    num: int
    author: ObjectId
    url: WebhookPayloadUrl
    title: str
    isPullRequest: bool
    status: DiscussionStatus_T
    changes: Optional[WebhookPayloadDiscussionChanges]
    pinned: Optional[bool]


class WebhookPayloadRepo(ObjectId):
    owner: ObjectId
    head_sha: Optional[str]
    name: str
    private: bool
    subdomain: Optional[str]
    tags: Optional[List[str]]
    type: Literal["dataset", "model", "space"]
    url: WebhookPayloadUrl


class WebhookPayload(BaseModel):
    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    discussion: Optional[WebhookPayloadDiscussion]
    comment: Optional[WebhookPayloadComment]
    webhook: WebhookPayloadWebhook
    movedTo: Optional[WebhookPayloadMovedTo]
diff --git a/src/huggingface_hub/_webhooks_server.py b/src/huggingface_hub/_webhooks_server.py
new file mode 100644
index 0000000000..96ff011864
--- /dev/null
+++ b/src/huggingface_hub/_webhooks_server.py
@@ -0,0 +1,362 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains `WebhooksServer` and `webhook_endpoint` to create a webhook server easily."""
import atexit
import inspect
import os
from functools import wraps
from typing import Callable, Dict, Optional

from .utils import experimental, is_gradio_available


if not is_gradio_available():
    raise ImportError(
        "You must have `gradio` installed to use `WebhooksServer`. Please run `pip install --upgrade gradio` first."
    )


import gradio as gr
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


_global_app: Optional["WebhooksServer"] = None
_is_local = os.getenv("SYSTEM") != "spaces"


@experimental
class WebhooksServer:
    """
    The [`WebhooksServer`] class lets you create an instance of a Gradio app that can receive Hugging Face webhooks.
    These webhooks can be registered using the [`~WebhooksServer.add_webhook`] decorator. Webhook endpoints are added to
    the app as POST endpoints on the FastAPI router. Once all the webhooks are registered, the `run` method has to be
    called to start the app.

    It is recommended to accept [`WebhookPayload`] as the first argument of the webhook function. It is a Pydantic
    model that contains all the information about the webhook event. The data will be parsed automatically for you.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    WebhooksServer and deploy it on a Space.

    <Tip warning={true}>

    `WebhooksServer` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip>

    You must have `gradio` installed to use `WebhooksServer` (`pip install --upgrade gradio`).

    </Tip>

    Args:
        ui (`gradio.Blocks`, optional):
            A Gradio UI instance to be used as the Space landing page. If `None`, a UI displaying instructions
            about the configured webhooks is created.
        webhook_secret (`str`, optional):
            A secret key to verify incoming webhook requests. You can set this value to any secret you want as long as
            you also configure it in your [webhooks settings panel](https://huggingface.co/settings/webhooks). You
            can also set this value as the `WEBHOOK_SECRET` environment variable. If no secret is provided, the
            webhook endpoints are opened without any security.

    Example:

    ```python
    import gradio as gr
    from huggingface_hub import WebhooksServer, WebhookPayload

    with gr.Blocks() as ui:
        ...

    app = WebhooksServer(ui=ui, webhook_secret="my_secret_key")

    @app.add_webhook("/say_hello")
    async def hello(payload: WebhookPayload):
        return {"message": "hello"}

    app.run()
    ```
    """

    def __init__(
        self,
        ui: Optional[gr.Blocks] = None,
        webhook_secret: Optional[str] = None,
    ) -> None:
        self._ui = ui

        self.webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        self.registered_webhooks: Dict[str, Callable] = {}
        _warn_on_empty_secret(self.webhook_secret)

    def add_webhook(self, path: Optional[str] = None) -> Callable:
        """
        Decorator to add a webhook to the [`WebhooksServer`] server.

        Args:
            path (`str`, optional):
                The URL path to register the webhook function. If not provided, the function name will be used as the
                path. In any case, all webhooks are registered under `/webhooks`.

        Raises:
            ValueError: If the provided path is already registered as a webhook.

        Example:
            ```python
            from huggingface_hub import WebhooksServer, WebhookPayload

            app = WebhooksServer()

            @app.add_webhook
            async def trigger_training(payload: WebhookPayload):
                if payload.repo.type == "dataset" and payload.event.action == "update":
                    # Trigger a training job if a dataset is updated
                    ...

            app.run()
            ```
        """
        # Usage: directly as decorator. Example: `@app.add_webhook`
        if callable(path):
            # If path is a function, it means it was used as a decorator without arguments
            return self.add_webhook()(path)

        # Usage: provide a path. Example: `@app.add_webhook(...)`
        @wraps(FastAPI.post)
        def _inner_post(*args, **kwargs):
            func = args[0]
            abs_path = f"/webhooks/{(path or func.__name__).strip('/')}"
            if abs_path in self.registered_webhooks:
                raise ValueError(f"Webhook {abs_path} already exists.")
            self.registered_webhooks[abs_path] = func

        return _inner_post

    def run(self) -> None:
        """Starts the Gradio app with the FastAPI server and registers the webhooks."""
        ui = self._ui or self._get_default_ui()

        # Start Gradio App
        #   - as non-blocking so that webhooks can be added afterwards
        #   - as shared if launched locally (to debug webhooks)
        self.fastapi_app, _, _ = ui.launch(prevent_thread_lock=True, share=_is_local)

        # Register webhooks to FastAPI app
        for path, func in self.registered_webhooks.items():
            # Add secret check if required
            if self.webhook_secret is not None:
                func = _wrap_webhook_to_check_secret(func, webhook_secret=self.webhook_secret)

            # Add route to FastAPI app
            self.fastapi_app.post(path)(func)

        # Print instructions and block main thread
        url = (ui.share_url or ui.local_url).strip("/")
        message = "\nWebhooks are correctly setup and ready to use:"
        message += "\n" + "\n".join(f"  - POST {url}{webhook}" for webhook in self.registered_webhooks)
        message += "\nGo to https://huggingface.co/settings/webhooks to setup your webhooks."
        print(message)

        ui.block_thread()

    def _get_default_ui(self) -> gr.Blocks:
        """Default UI if not provided (lists webhooks and provides basic instructions)."""
        with gr.Blocks() as ui:
            gr.Markdown("# This is an app to process 🤗 Webhooks")
            gr.Markdown(
                "Webhooks are a foundation for MLOps-related features. They allow you to listen for new changes on"
                " specific repos or to all repos belonging to a particular set of users/organizations (not just your"
                " repos, but any repo). Check out this [guide](https://huggingface.co/docs/hub/webhooks) to get to"
                " know more about webhooks on the Hugging Face Hub."
            )
            gr.Markdown(
                f"{len(self.registered_webhooks)} webhook(s) are registered:"
                + "\n\n"
                + "\n ".join(
                    f"- [{webhook_path}]({_get_webhook_doc_url(webhook.__name__, webhook_path)})"
                    for webhook_path, webhook in self.registered_webhooks.items()
                )
            )
            gr.Markdown(
                "Go to https://huggingface.co/settings/webhooks to setup your webhooks."
                + (
                    "\nYour app is running locally. Please look at the logs to check the full URL you need to set."
                    if _is_local
                    else (
                        "\nThis app is running on a Space. You can find the corresponding URL in the options menu"
                        " (top-right) > 'Embed the Space'. The URL looks like 'https://{username}-{repo_name}.hf.space'."
                    )
                )
            )
        return ui


@experimental
def webhook_endpoint(path: Optional[str] = None) -> Callable:
    """Decorator to start a [`WebhooksServer`] and register the decorated function as a webhook endpoint.

    This is a helper to get started quickly. If you need more flexibility (custom landing page or webhook secret),
    you can use [`WebhooksServer`] directly. You can register multiple webhook endpoints (to the same server) by using
    this decorator multiple times.

    Check out the [webhooks guide](../guides/webhooks_server) for a step-by-step tutorial on how to set up your
    server and deploy it on a Space.

    <Tip warning={true}>

    `webhook_endpoint` is experimental. Its API is subject to change in the future.

    </Tip>

    <Tip>

    You must have `gradio` installed to use `webhook_endpoint` (`pip install --upgrade gradio`).
    </Tip>

    Args:
        path (`str`, optional):
            The URL path to register the webhook function. If not provided, the function name will be used as the path.
            In any case, all webhooks are registered under `/webhooks`.

    Examples:
        The default usage is to register a function as a webhook endpoint. The function name will be used as the path.
        The server will be started automatically at exit (i.e. at the end of the script).

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Server is automatically started at the end of the script.
        ```

        Advanced usage: register a function as a webhook endpoint and start the server manually. This is useful if you
        are running it in a notebook.

        ```python
        from huggingface_hub import webhook_endpoint, WebhookPayload

        @webhook_endpoint
        async def trigger_training(payload: WebhookPayload):
            if payload.repo.type == "dataset" and payload.event.action == "update":
                # Trigger a training job if a dataset is updated
                ...

        # Start the server manually
        trigger_training.run()
        ```
    """
    if callable(path):
        # If path is a function, it means it was used as a decorator without arguments
        return webhook_endpoint()(path)

    @wraps(WebhooksServer.add_webhook)
    def _inner(func: Callable) -> Callable:
        app = _get_global_app()
        app.add_webhook(path)(func)
        if len(app.registered_webhooks) == 1:
            # Register `app.run` to run at exit (only once)
            atexit.register(app.run)

        @wraps(app.run)
        def _run_now():
            # Run the app directly (without waiting for atexit)
            atexit.unregister(app.run)
            app.run()

        func.run = _run_now  # type: ignore
        return func

    return _inner


def _get_global_app() -> WebhooksServer:
    global _global_app
    if _global_app is None:
        _global_app = WebhooksServer()
    return _global_app


def _warn_on_empty_secret(webhook_secret: Optional[str]) -> None:
    if webhook_secret is None:
        print("Webhook secret is not defined. This means your webhook endpoints will be open to everyone.")
        print(
            "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
            "\n\t`app = WebhooksServer(webhook_secret='my_secret', ...)`"
        )
        print(
            "For more details about webhook secrets, please refer to"
            " https://huggingface.co/docs/hub/webhooks#webhook-secret."
        )
    else:
        print("Webhook secret is correctly defined.")


def _get_webhook_doc_url(webhook_name: str, webhook_path: str) -> str:
    """Returns the anchor to a given webhook in the docs (experimental)"""
    return "/docs#/default/" + webhook_name + webhook_path.replace("/", "_") + "_post"


def _wrap_webhook_to_check_secret(func: Callable, webhook_secret: str) -> Callable:
    """Wraps a webhook function to check the webhook secret before calling the function.

    This is a hacky way to add the `request` parameter to the function signature. Since FastAPI relies on route
    parameters to inject the values into the function, we need to hack the function signature to retrieve the
    `Request` object (and hence the headers). A far cleaner solution would be to use a middleware. However, since
    `fastapi==0.90.1`, a middleware cannot be added once the app has started. And since the FastAPI app is started by
    Gradio internals (and not by us), we cannot add a middleware.

    This method is called only when a secret has been defined by the user. If a request is sent without the
    "x-webhook-secret", the function will return a 401 error (unauthorized). If the header is sent but is incorrect,
    the function will return a 403 error (forbidden).

    Inspired by https://stackoverflow.com/a/33112180.
    """
    initial_sig = inspect.signature(func)

    @wraps(func)
    async def _protected_func(request: Request, **kwargs):
        request_secret = request.headers.get("x-webhook-secret")
        if request_secret is None:
            return JSONResponse({"error": "x-webhook-secret header not set."}, status_code=401)
        if request_secret != webhook_secret:
            return JSONResponse({"error": "Invalid webhook secret."}, status_code=403)

        # Inject `request` in kwargs if required
        if "request" in initial_sig.parameters:
            kwargs["request"] = request

        # Handle both sync and async routes
        if inspect.iscoroutinefunction(func):
            return await func(**kwargs)
        else:
            return func(**kwargs)

    # Update signature to include request
    if "request" not in initial_sig.parameters:
        _protected_func.__signature__ = initial_sig.replace(  # type: ignore
            parameters=(
                inspect.Parameter(name="request", kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
            )
            + tuple(initial_sig.parameters.values())
        )

    # Return protected route
    return _protected_func
diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py
index 2d2bbd8129..e08b6a5add 100644
--- a/src/huggingface_hub/constants.py
+++ b/src/huggingface_hub/constants.py
@@ -110,6 +110,9 @@ def _as_int(value: Optional[str]) -> Optional[int]:
 # Disable warning on machines that do not support symlinks (e.g. Windows non-developer)
 HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING"))

+# Disable warning when using experimental features
+HF_HUB_DISABLE_EXPERIMENTAL_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_EXPERIMENTAL_WARNING"))
+
 # Disable sending the cached token by default in all HTTP requests to the Hub
 HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN"))
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 89a8a90521..9d594ccebf 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -878,7 +878,13 @@ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
         relative_src = None

     try:
-        _support_symlinks = are_symlinks_supported(os.path.dirname(os.path.commonpath([abs_src, abs_dst])))
+        try:
+            commonpath = os.path.commonpath([abs_src, abs_dst])
+            _support_symlinks = are_symlinks_supported(os.path.dirname(commonpath))
+        except ValueError:
+            # Raised if src and dst are not on the same volume. Symlinks will still work on Linux/macOS.
+            # See https://docs.python.org/3/library/os.path.html#os.path.commonpath
+            _support_symlinks = os.name != "nt"
     except PermissionError:
         # Permission error means src and dst are not in the same volume (e.g. destination path has been provided
         # by the user via `local_dir`. Let's test symlink support there)
@@ -900,7 +906,7 @@ def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
             raise
     elif new_blob:
         logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
-        os.replace(src, dst)
+        shutil.move(src, dst)
     else:
         logger.info(f"Symlink not supported.
Copying file from {abs_src} to {abs_dst}") shutil.copyfile(src, dst) @@ -1140,11 +1146,17 @@ def hf_hub_download( # cross platform transcription of filename, to be used as a local file path. relative_filename = os.path.join(*filename.split("/")) + if os.name == "nt": + if relative_filename.startswith("..\\") or "\\..\\" in relative_filename: + raise ValueError( + f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository" + " owner to rename this file." + ) # if user provides a commit_hash and they already have the file on disk, # shortcut everything. if REGEX_COMMIT_HASH.match(revision): - pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename) + pointer_path = _get_pointer_path(storage_folder, revision, relative_filename) if os.path.exists(pointer_path): if local_dir is not None: return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks) @@ -1239,7 +1251,7 @@ def hf_hub_download( # Return pointer file if exists if commit_hash is not None: - pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename) + pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) if os.path.exists(pointer_path): if local_dir is not None: return _to_local_dir( @@ -1268,7 +1280,7 @@ def hf_hub_download( assert etag is not None, "etag must have been retrieved from server" assert commit_hash is not None, "commit_hash must have been retrieved from server" blob_path = os.path.join(storage_folder, "blobs", etag) - pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename) + pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename) os.makedirs(os.path.dirname(blob_path), exist_ok=True) os.makedirs(os.path.dirname(pointer_path), exist_ok=True) @@ -1554,7 +1566,20 @@ def _chmod_and_replace(src: str, dst: str) -> None: finally: tmp_file.unlink() - os.replace(src, dst) + shutil.move(src, dst) + + +def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str: + # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks + snapshot_path = os.path.join(storage_folder, "snapshots") + pointer_path = os.path.join(snapshot_path, revision, relative_filename) + if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents: + raise ValueError( + "Invalid pointer path: cannot create pointer path in snapshot folder if" + f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and" + f" `relative_filename='{relative_filename}'`." + ) + return pointer_path def _to_local_dir( @@ -1564,7 +1589,14 @@ def _to_local_dir( Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size. """ + # Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks local_dir_filepath = os.path.join(local_dir, relative_filename) + if Path(os.path.abspath(local_dir)) not in Path(os.path.abspath(local_dir_filepath)).parents: + raise ValueError( + f"Cannot copy file '{relative_filename}' to local dir '{local_dir}': file would not be in the local" + " directory." 
        )

    os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
    real_blob_path = os.path.realpath(path)
diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py
index 5e07d07b32..be303fe2f2 100644
--- a/src/huggingface_hub/hf_api.py
+++ b/src/huggingface_hub/hf_api.py
@@ -78,6 +78,7 @@
     filter_repo_objects,
     hf_raise_for_status,
     logging,
+    paginate,
     parse_datetime,
     validate_hf_hub_args,
 )
@@ -85,7 +86,6 @@
     _deprecate_arguments,
     _deprecate_list_output,
 )
-from .utils._pagination import paginate
 from .utils._typing import Literal, TypedDict
 from .utils.endpoint_helpers import (
     AttributeDictionary,
diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py
new file mode 100644
index 0000000000..25235f6183
--- /dev/null
+++ b/src/huggingface_hub/hf_file_system.py
@@ -0,0 +1,439 @@
import itertools
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime
from glob import has_magic
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import quote, unquote

import fsspec
import requests

from ._commit_api import CommitOperationDelete
from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES
from .hf_api import HfApi
from .utils import (
    EntryNotFoundError,
    HFValidationError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    hf_raise_for_status,
    http_backoff,
    paginate,
    parse_datetime,
)


@dataclass
class HfFileSystemResolvedPath:
    """Data structure containing information about a resolved hffs path."""

    repo_type: str
    repo_id: str
    revision: str
    path_in_repo: str

    def unresolve(self) -> str:
        return (
            f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_quote(self.revision)}/{self.path_in_repo}"
            .rstrip("/")
        )


class HfFileSystem(fsspec.AbstractFileSystem):
    """
    Access a remote Hugging Face Hub repository as if it were a local file system.

    Args:
        endpoint (`str`, *optional*):
            The endpoint to use. If not provided, the default one (https://huggingface.co) is used.
        token (`str`, *optional*):
            Authentication token, obtained with the [`HfApi.login`] method. Will default to the stored token.

    Usage:

    ```python
    >>> from huggingface_hub import HfFileSystem

    >>> fs = HfFileSystem()

    >>> # List files
    >>> fs.glob("my-username/my-model/*.bin")
    ['my-username/my-model/pytorch_model.bin']
    >>> fs.ls("datasets/my-username/my-dataset", detail=False)
    ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json']

    >>> # Read/write files
    >>> with fs.open("my-username/my-model/pytorch_model.bin") as f:
    ...     data = f.read()
    >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f:
    ...     f.write(data)
    ```
    """

    root_marker = ""
    protocol = "hf"

    def __init__(
        self,
        *args,
        endpoint: Optional[str] = None,
        token: Optional[str] = None,
        **storage_options,
    ):
        super().__init__(*args, **storage_options)
        self.endpoint = endpoint or ENDPOINT
        self.token = token
        self._api = HfApi(endpoint=endpoint, token=token)
        # Maps (repo_type, repo_id, revision) to a 2-tuple with:
        #  * the 1st element indicating whether the repository and the revision exist
        #  * the 2nd element being the exception raised if the repository or revision doesn't exist
        self._repo_and_revision_exists_cache: Dict[
            Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]]
        ] = {}

    def _repo_and_revision_exist(
        self, repo_type: str, repo_id: str, revision: Optional[str]
    ) -> Tuple[bool, Optional[Exception]]:
        if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache:
            try:
                self._api.repo_info(repo_id, revision=revision, repo_type=repo_type)
            except (RepositoryNotFoundError, HFValidationError) as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e
            except RevisionNotFoundError as e:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
            else:
                self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None
                self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None
        return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)]

    def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath:
        def _align_revision_in_path_with_revision(
            revision_in_path: Optional[str], revision: Optional[str]
        ) -> Optional[str]:
            if revision is not None:
                if revision_in_path is not None and revision_in_path != revision:
                    raise ValueError(
                        f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")'
                        " are not the same."
                    )
            else:
                revision = revision_in_path
            return revision

        path = self._strip_protocol(path)
        if not path:
            # can't list repositories at root
            raise NotImplementedError("Access to repositories lists is not implemented.")
        elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values():
            if "/" not in path:
                # can't list repositories at the repository type level
                raise NotImplementedError("Access to repositories lists is not implemented.")
            repo_type, path = path.split("/", 1)
            repo_type = REPO_TYPES_MAPPING[repo_type]
        else:
            repo_type = REPO_TYPE_MODEL
        if path.count("/") > 0:
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                if "/" in revision_in_path:
                    revision_in_path, path_in_repo = revision_in_path.split("/", 1)
                else:
                    path_in_repo = ""
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    raise FileNotFoundError(path) from err
            else:
                repo_id_with_namespace = "/".join(path.split("/")[:2])
                path_in_repo_with_namespace = "/".join(path.split("/")[2:])
                repo_id_without_namespace = path.split("/")[0]
                path_in_repo_without_namespace = "/".join(path.split("/")[1:])
                repo_id = repo_id_with_namespace
                path_in_repo = path_in_repo_with_namespace
                repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision)
                if not repo_and_revision_exist:
                    if isinstance(err, (RepositoryNotFoundError, HFValidationError)):
                        repo_id = repo_id_without_namespace
                        path_in_repo = path_in_repo_without_namespace
                        repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
                        if not repo_and_revision_exist:
                            raise FileNotFoundError(path) from err
                    else:
                        raise FileNotFoundError(path) from err
        else:
            repo_id = path
            path_in_repo = ""
            if "@" in path:
                repo_id, revision_in_path = path.split("@", 1)
                revision_in_path = unquote(revision_in_path)
                revision = _align_revision_in_path_with_revision(revision_in_path, revision)
            repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision)
            if not repo_and_revision_exist:
                raise NotImplementedError("Access to repositories lists is not implemented.")

        revision = revision if revision is not None else DEFAULT_REVISION
        return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo)

    def invalidate_cache(self, path: Optional[str] = None) -> None:
        if not path:
            self.dircache.clear()
            self._repo_and_revision_exists_cache.clear()
        else:
            path = self.resolve_path(path).unresolve()
            while path:
                self.dircache.pop(path, None)
                path = self._parent(path)

    def _open(
        self,
        path: str,
        mode: str = "rb",
        revision: Optional[str] = None,
        **kwargs,
    ) -> "HfFileSystemFile":
        if mode == "ab":
            raise NotImplementedError("Appending to remote files is not yet supported.")
        return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs)

    def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None:
        resolved_path = self.resolve_path(path, revision=revision)
        self._api.delete_file(
            path_in_repo=resolved_path.path_in_repo,
            repo_id=resolved_path.repo_id,
            token=self.token,
            repo_type=resolved_path.repo_type,
            revision=resolved_path.revision,
            commit_message=kwargs.get("commit_message"),
            commit_description=kwargs.get("commit_description"),
        )
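        # Evict the deleted path (and its parent directories) from the dircache so subsequent `ls` calls are fresh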
self.invalidate_cache(path=resolved_path.unresolve()) + + def rm( + self, + path: str, + recursive: bool = False, + maxdepth: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> None: + resolved_path = self.resolve_path(path, revision=revision) + root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=resolved_path.revision) + paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)] + operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] + commit_message = f"Delete {path} " + commit_message += "recursively " if recursive else "" + commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" + # TODO: use `commit_description` to list all the deleted paths? + self._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + token=self.token, + operations=operations, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def ls( + self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs + ) -> List[Union[str, Dict[str, Any]]]: + """List the contents of a directory.""" + resolved_path = self.resolve_path(path, revision=revision) + revision_in_path = "@" + safe_quote(resolved_path.revision) + has_revision_in_path = revision_in_path in path + path = resolved_path.unresolve() + if path not in self.dircache or refresh: + path_prefix = ( + HfFileSystemResolvedPath( + resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, "" + ).unresolve() + + "/" + ) + tree_path = path + tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision) + try: + tree_item = next(tree_iter) + except EntryNotFoundError: + if "/" in resolved_path.path_in_repo: + tree_path = self._parent(path) + tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision) + else: + raise + else: + tree_iter = itertools.chain([tree_item], tree_iter) + child_infos = [] + for tree_item in tree_iter: + child_info = { + "name": path_prefix + tree_item["path"], + "size": tree_item["size"], + "type": tree_item["type"], + } + if tree_item["type"] == "file": + child_info.update( + { + "blob_id": tree_item["oid"], + "lfs": tree_item.get("lfs"), + "last_modified": parse_datetime(tree_item["lastCommit"]["date"]), + }, + ) + child_infos.append(child_info) + self.dircache[tree_path] = child_infos + out = self._ls_from_cache(path) + if not has_revision_in_path: + out = [{**o, "name": o["name"].replace(revision_in_path, "", 1)} for o in out] + return out if detail else [o["name"] for o in out] + + def _iter_tree(self, path: str, revision: Optional[str] = None): + resolved_path = self.resolve_path(path, revision=revision) + path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip( + "/" + ) + headers = self._api._build_hf_headers() + yield from paginate(path, params={}, headers=headers) + + def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path1 = self.resolve_path(path1, revision=revision) + resolved_path2 = self.resolve_path(path2, revision=revision) + + same_repo = ( + resolved_path1.repo_type == 
resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id + ) + + # TODO: Wait for https://github.com/huggingface/huggingface_hub/issues/1083 to be resolved to simplify this logic + # Fetch the source file info once: it tells us whether a server-side LFS copy is possible and gives us its oid + src_info = self.info(path1, revision=resolved_path1.revision) + if same_repo and src_info["lfs"] is not None: + headers = self._api._build_hf_headers(is_write_action=True) + commit_message = f"Copy {path1} to {path2}" + payload = { + "summary": kwargs.get("commit_message", commit_message), + "description": kwargs.get("commit_description", ""), + "files": [], + "lfsFiles": [ + { + "path": resolved_path2.path_in_repo, + "algo": "sha256", + "oid": src_info["lfs"]["oid"], + } + ], + "deletedFiles": [], + } + r = requests.post( + f"{self.endpoint}/api/{resolved_path1.repo_type}s/{resolved_path1.repo_id}/commit/{safe_quote(resolved_path2.revision)}", + json=payload, + headers=headers, + ) + hf_raise_for_status(r) + else: + with self.open(path1, "rb", revision=resolved_path1.revision) as f: + content = f.read() + commit_message = f"Copy {path1} to {path2}" + self._api.upload_file( + path_or_fileobj=content, + path_in_repo=resolved_path2.path_in_repo, + repo_id=resolved_path2.repo_id, + token=self.token, + repo_type=resolved_path2.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path1.unresolve()) + self.invalidate_cache(path=resolved_path2.unresolve()) + + def modified(self, path: str, **kwargs) -> datetime: + info = self.info(path, **kwargs) + if "last_modified" not in info: + raise IsADirectoryError(path) + return info["last_modified"] + + def info(self, path: str, **kwargs) -> Dict[str, Any]: + resolved_path = self.resolve_path(path) + if not resolved_path.path_in_repo: + revision_in_path = "@" + safe_quote(resolved_path.revision) + has_revision_in_path = revision_in_path in path + name = resolved_path.unresolve() + name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name + return {"name": name, "size": 0, "type": "directory"} + return super().info(path, **kwargs) + + def expand_path( + self, path: Union[str, List[str]], recursive: bool = False, maxdepth: Optional[int] = None, **kwargs + ) -> List[str]: + # The default implementation does not allow passing custom kwargs (e.g., we use these kwargs to propagate the `revision`) + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + if isinstance(path, str): + # forward **kwargs as well, otherwise the `revision` would be dropped on the recursive call + return self.expand_path([path], recursive, maxdepth, **kwargs) + + out = set() + path = [self._strip_protocol(p) for p in path] + for p in path: + if has_magic(p): + bit = set(self.glob(p, **kwargs)) + out |= bit + if recursive: + out |= set(self.expand_path(list(bit), recursive=recursive, maxdepth=maxdepth, **kwargs)) + continue + elif recursive: + rec = set(self.find(p, maxdepth=maxdepth, withdirs=True, detail=False, **kwargs)) + out |= rec + if p not in out and (recursive is False or self.exists(p)): + # should only check once, for the root + out.add(p) + if not out: + raise FileNotFoundError(path) + return sorted(out) + + +class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): + def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): + super().__init__(fs, path, **kwargs) + self.fs: HfFileSystem + self.resolved_path = fs.resolve_path(path, revision=revision) + + def _fetch_range(self, start: int, end: int) ->
bytes: + headers = { + "range": f"bytes={start}-{end - 1}", + **self.fs._api._build_hf_headers(), + } + url = ( + f"{self.fs.endpoint}/{REPO_TYPES_URL_PREFIXES.get(self.resolved_path.repo_type, '') + self.resolved_path.repo_id}/resolve/{safe_quote(self.resolved_path.revision)}/{safe_quote(self.resolved_path.path_in_repo)}" + ) + r = http_backoff("GET", url, headers=headers) + hf_raise_for_status(r) + return r.content + + def _initiate_upload(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) + + def _upload_chunk(self, final: bool = False) -> None: + self.buffer.seek(0) + block = self.buffer.read() + self.temp_file.write(block) + if final: + self.temp_file.close() + self.fs._api.upload_file( + path_or_fileobj=self.temp_file.name, + path_in_repo=self.resolved_path.path_in_repo, + repo_id=self.resolved_path.repo_id, + token=self.fs.token, + repo_type=self.resolved_path.repo_type, + revision=self.resolved_path.revision, + commit_message=self.kwargs.get("commit_message"), + commit_description=self.kwargs.get("commit_description"), + ) + os.remove(self.temp_file.name) + self.fs.invalidate_cache( + path=self.resolved_path.unresolve(), + ) + + +def safe_quote(s: str) -> str: + return quote(s, safe="") diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index afa91a4609..427eabb56c 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -260,7 +260,7 @@ def push_to_hub( Args: repo_id (`str`): - Repository name to which push. + ID of the repository to push to (example: `"username/my-model"`). config (`dict`, *optional*): Configuration object to be saved alongside the model weights. commit_message (`str`, *optional*): diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index d456e20c88..32ea4091e0 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -306,7 +306,7 @@ def push_to_hub_keras( The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the Hub. The model must be compiled and built. repo_id (`str`): - Repository name to which push + ID of the repository to push to (example: `"username/my-model"`). commit_message (`str`, *optional*, defaults to "Add Keras model"): Message to commit while pushing. private (`bool`, *optional*, defaults to `False`): diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index 81850bc507..757995870b 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -1452,7 +1452,7 @@ def wait_for_commands(self): while self.commands_in_progress: if index % 10 == 0: - logger.error( + logger.warning( f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." 
) diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index bfdf881b87..cec24731b5 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -44,11 +44,15 @@ from ._headers import build_hf_headers, get_token_to_send from ._hf_folder import HfFolder from ._http import configure_http_backend, get_session, http_backoff +from ._pagination import paginate from ._paths import filter_repo_objects, IGNORE_GIT_FOLDER_PATTERNS +from ._experimental import experimental from ._runtime import ( dump_environment_info, get_fastai_version, get_fastcore_version, + get_gradio_version, + is_gradio_available, get_graphviz_version, get_hf_hub_version, get_hf_transfer_version, @@ -70,7 +74,7 @@ is_tf_available, is_torch_available, ) -from ._subprocess import run_interactive_subprocess, run_subprocess +from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess from ._validators import ( HFValidationError, smoothly_deprecate_use_auth_token, diff --git a/src/huggingface_hub/utils/_experimental.py b/src/huggingface_hub/utils/_experimental.py new file mode 100644 index 0000000000..3e4ebda89f --- /dev/null +++ b/src/huggingface_hub/utils/_experimental.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to flag a feature as "experimental" in `huggingface_hub`.""" +import warnings +from functools import wraps +from typing import Callable + +from .. import constants + + +def experimental(fn: Callable) -> Callable: + """Decorator to flag a feature as experimental. + + An experimental feature triggers a warning when used, as it might be subject to breaking changes in the future. + Warnings can be disabled by setting the environment variable `HF_HUB_DISABLE_EXPERIMENTAL_WARNING` to `1`. + + Args: + fn (`Callable`): + The function to flag as experimental. + + Returns: + `Callable`: The decorated function. + + Example: + + ```python + >>> from huggingface_hub.utils import experimental + + >>> @experimental + ... def my_function(): + ... print("Hello world!") + + >>> my_function() + UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future. You can disable + this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as an environment variable. + Hello world! + ``` + """ + + @wraps(fn) + def _inner_fn(*args, **kwargs): + if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING: + warnings.warn( + ( + f"'{fn.__name__}' is experimental and might be subject to breaking changes in the future. You can" + " disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as an environment" + " variable."
+ ), + UserWarning, + ) + return fn(*args, **kwargs) + + return _inner_fn diff --git a/src/huggingface_hub/utils/_runtime.py b/src/huggingface_hub/utils/_runtime.py index 6699461de5..560b02933d 100644 --- a/src/huggingface_hub/utils/_runtime.py +++ b/src/huggingface_hub/utils/_runtime.py @@ -35,6 +35,7 @@ _CANDIDATES = { "torch": {"torch"}, "pydot": {"pydot"}, + "gradio": {"gradio"}, "graphviz": {"graphviz"}, "tensorflow": ( "tensorflow", @@ -102,6 +103,15 @@ def get_fastcore_version() -> str: return _get_version("fastcore") + + +# Gradio +def is_gradio_available() -> bool: + return _is_available("gradio") + + +def get_gradio_version() -> str: + return _get_version("gradio") + + # Graphviz def is_graphviz_available() -> bool: return _is_available("graphviz") @@ -254,6 +264,7 @@ def dump_environment_info() -> Dict[str, Any]: info["Pydot"] = get_pydot_version() info["Pillow"] = get_pillow_version() info["hf_transfer"] = get_hf_transfer_version() + info["gradio"] = get_gradio_version() # Environment variables info["ENDPOINT"] = constants.ENDPOINT @@ -264,6 +275,7 @@ info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING + info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN info["HF_HUB_ENABLE_HF_TRANSFER"] = constants.HF_HUB_ENABLE_HF_TRANSFER diff --git a/src/huggingface_hub/utils/_subprocess.py b/src/huggingface_hub/utils/_subprocess.py index c9de3e32f0..5ec7936549 100644 --- a/src/huggingface_hub/utils/_subprocess.py +++ b/src/huggingface_hub/utils/_subprocess.py @@ -16,7 +16,9 @@ """Contains utilities to easily handle subprocesses in `huggingface_hub`.""" import os import subprocess +import sys from contextlib import contextmanager +from io import StringIO from pathlib import Path from typing import IO, Generator, List, Optional, Tuple, Union @@ -26,6 +28,26 @@ logger = get_logger(__name__) + + +@contextmanager +def capture_output() -> Generator[StringIO, None, None]: + """Capture output that is printed to the terminal. + + Taken from https://stackoverflow.com/a/34738440 + + Example: + ```py + >>> with capture_output() as output: + ... print("hello world") + >>> assert output.getvalue() == "hello world\n" + ``` + """ + output = StringIO() + previous_output = sys.stdout + sys.stdout = output + try: + yield output + finally: + # Always restore stdout, even if the wrapped block raises + sys.stdout = previous_output + + def run_subprocess( command: Union[str, List[str]], folder: Optional[Union[str, Path]] = None, diff --git a/tests/conftest.py b/tests/conftest.py index 34fab9168a..6238a46e16 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,6 +70,11 @@ def __getitem__(self, __key: str) -> bool: ) + + +@pytest.fixture(autouse=True) +def disable_experimental_warnings(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(huggingface_hub.constants, "HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True) + + @pytest.fixture def fx_production_space(request: SubRequest) -> Generator[None, None, None]: """Add a `repo_id` attribute referencing a Space repo on the production Hub.
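Putting the pieces above together: the `experimental` decorator warns at call time unless `constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING` is truthy, and the autouse `disable_experimental_warnings` fixture silences it for the whole test suite. A minimal sketch of the expected behavior, assuming `huggingface_hub` is installed with this change applied (`my_tool` is an illustrative name, not part of the diff):

```python
import warnings

from huggingface_hub.utils import experimental


@experimental
def my_tool():
    # Hypothetical user-defined function flagged as experimental
    return 42


# Unless HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1 was set in the environment
# before importing huggingface_hub, the first call emits a UserWarning
# but still executes the wrapped function.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert my_tool() == 42
    assert len(caught) == 1
    assert issubclass(caught[0].category, UserWarning)
```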
diff --git a/tests/test_command_delete_cache.py b/tests/test_command_delete_cache.py index 201bc20f5e..f5d3d90208 100644 --- a/tests/test_command_delete_cache.py +++ b/tests/test_command_delete_cache.py @@ -16,9 +16,9 @@ _manual_review_no_tui, _read_manual_review_tmp_file, ) -from huggingface_hub.utils import SoftTemporaryDirectory +from huggingface_hub.utils import SoftTemporaryDirectory, capture_output -from .testing_utils import capture_output, handle_injection +from .testing_utils import handle_injection class TestDeleteCacheHelpers(unittest.TestCase): diff --git a/tests/test_file_download.py b/tests/test_file_download.py index dad7858c05..1a9b8d84f1 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -32,7 +32,9 @@ from huggingface_hub.file_download import ( _CACHED_NO_EXIST, _create_symlink, + _get_pointer_path, _request_wrapper, + _to_local_dir, cached_download, filename_to_url, get_hf_file_metadata, @@ -772,6 +774,69 @@ def test_hf_hub_download_on_awful_subfolder_and_filename(self): self.assertTrue(local_path.endswith(self.filepath)) + + +@pytest.mark.usefixtures("fx_cache_dir") +class TestHfHubDownloadRelativePaths(unittest.TestCase): + """Regression test for HackerOne report 1928845. + + The issue was that, on Windows, any file outside of the local dir could be overwritten. + + Multiple protections have been added to prevent this: `..\\` in filenames is forbidden on Windows, and the + resolved filepath is always checked to be inside local_dir/snapshot_dir. + """ + + cache_dir: Path + + @classmethod + def setUpClass(cls): + cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) + cls.repo_id = cls.api.create_repo(repo_id=repo_name()).repo_id + cls.api.upload_file(path_or_fileobj=b"content", path_in_repo="..\\ddd", repo_id=cls.repo_id) + cls.api.upload_file(path_or_fileobj=b"content", path_in_repo="folder/..\\..\\..\\file", repo_id=cls.repo_id) + + @classmethod + def tearDownClass(cls) -> None: + cls.api.delete_repo(repo_id=cls.repo_id) + + @xfail_on_windows(reason="Windows paths cannot start with '..\\'.", raises=ValueError) + def test_download_file_in_cache_dir(self) -> None: + hf_hub_download(self.repo_id, "..\\ddd", cache_dir=self.cache_dir) + + @xfail_on_windows(reason="Windows paths cannot start with '..\\'.", raises=ValueError) + def test_download_file_to_local_dir(self) -> None: + with SoftTemporaryDirectory() as local_dir: + hf_hub_download(self.repo_id, "..\\ddd", cache_dir=self.cache_dir, local_dir=local_dir) + + @xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError) + def test_download_folder_file_in_cache_dir(self) -> None: + hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir) + + @xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError) + def test_download_folder_file_to_local_dir(self) -> None: + with SoftTemporaryDirectory() as local_dir: + hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir, local_dir=local_dir) + + def test_get_pointer_path_and_valid_relative_filename(self) -> None: + # Cannot happen because of other protections, but just in case. + self.assertEqual( + _get_pointer_path("path/to/storage", "abcdef", "path/to/file.txt"), + os.path.join("path/to/storage", "snapshots", "abcdef", "path/to/file.txt"), + ) + + def test_get_pointer_path_but_invalid_relative_filename(self) -> None: + # Cannot happen because of other protections, but just in case.
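+ # Use the separator of the current platform so the filename is a genuine traversal attempt on the OS running the test.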
+ relative_filename = "folder\\..\\..\\..\\file.txt" if os.name == "nt" else "folder/../../../file.txt" + with self.assertRaises(ValueError): + _get_pointer_path("path/to/storage", "abcdef", relative_filename) + + def test_to_local_dir_but_invalid_relative_filename(self) -> None: + # Cannot happen because of other protections, but just in case. + relative_filename = "folder\\..\\..\\..\\file.txt" if os.name == "nt" else "folder/../../../file.txt" + with self.assertRaises(ValueError): + _to_local_dir( + "path/to/file_to_copy", "path/to/local/dir", relative_filename=relative_filename, use_symlinks=False + ) + + class CreateSymlinkTest(unittest.TestCase): @unittest.skipIf(os.name == "nt", "No symlinks on Windows") @patch("huggingface_hub.file_download.are_symlinks_supported") diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 4fb23874fa..71a76040a3 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -234,13 +234,9 @@ def test_move_repo_normal_usage(self): repo_id = f"{USER}/{repo_name()}" new_repo_id = f"{USER}/{repo_name()}" - for repo_type in [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]: - self._api.create_repo( - repo_id=repo_id, - repo_type=repo_type, - space_sdk="static" if repo_type == REPO_TYPE_SPACE else None, - ) - # Should raise an error if it fails + # Spaces not tested on staging (error 500) + for repo_type in [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET]: + self._api.create_repo(repo_id=repo_id, repo_type=repo_type) self._api.move_repo(from_id=repo_id, to_id=new_repo_id, repo_type=repo_type) self._api.delete_repo(repo_id=new_repo_id, repo_type=repo_type) @@ -2273,7 +2269,7 @@ def test_pause_and_restart_space(self) -> None: self.assertEqual(runtime_after_pause.stage, SpaceStage.PAUSED) self.api.restart_space(self.repo_id) - time.sleep(0.2) + time.sleep(1.0) runtime_after_restart = self.api.get_space_runtime(self.repo_id) self.assertIn(runtime_after_restart.stage, (SpaceStage.BUILDING, SpaceStage.RUNNING_BUILDING)) diff --git a/tests/test_hf_file_system.py b/tests/test_hf_file_system.py new file mode 100644 index 0000000000..7a40e1402a --- /dev/null +++ b/tests/test_hf_file_system.py @@ -0,0 +1,284 @@ +import datetime +import unittest +from typing import Optional +from unittest.mock import patch + +import fsspec +import pytest + +from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES +from huggingface_hub.hf_file_system import HfFileSystem +from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError + +from .testing_constants import ENDPOINT_STAGING, TOKEN, USER +from .testing_utils import repo_name, retry_endpoint + + +class HfFileSystemTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Register `HfFileSystem` as a `fsspec` filesystem if not already registered.""" + if HfFileSystem.protocol not in fsspec.available_protocols(): + fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) + + def setUp(self): + self.repo_id = f"{USER}/{repo_name()}" + self.repo_type = "dataset" + self.hf_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id + self.hffs = HfFileSystem(endpoint=ENDPOINT_STAGING, token=TOKEN) + self.api = self.hffs._api + + # Create dummy repo + self.api.create_repo(self.repo_id, repo_type=self.repo_type) + self.api.upload_file( + path_or_fileobj=b"dummy binary data on pr", + path_in_repo="data/binary_data_for_pr.bin", + repo_id=self.repo_id, + repo_type=self.repo_type, + create_pr=True, + ) + self.api.upload_file( + path_or_fileobj="dummy text 
data".encode("utf-8"), + path_in_repo="data/text_data.txt", + repo_id=self.repo_id, + repo_type=self.repo_type, + ) + self.api.upload_file( + path_or_fileobj=b"dummy binary data", + path_in_repo="data/binary_data.bin", + repo_id=self.repo_id, + repo_type=self.repo_type, + ) + + def tearDown(self): + self.api.delete_repo(self.repo_id, repo_type=self.repo_type) + + @retry_endpoint + def test_glob(self): + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "/*")), + sorted([self.hf_path + "/.gitattributes", self.hf_path + "/data"]), + ) + + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "/*", revision="main")), + sorted([self.hf_path + "/.gitattributes", self.hf_path + "/data"]), + ) + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "@main" + "/*")), + sorted([self.hf_path + "@main" + "/.gitattributes", self.hf_path + "@main" + "/data"]), + ) + + @retry_endpoint + def test_file_type(self): + self.assertTrue( + self.hffs.isdir(self.hf_path + "/data") and not self.hffs.isdir(self.hf_path + "/.gitattributes") + ) + self.assertTrue( + self.hffs.isfile(self.hf_path + "/data/text_data.txt") and not self.hffs.isfile(self.hf_path + "/data") + ) + + @retry_endpoint + def test_remove_file(self): + self.hffs.rm_file(self.hf_path + "/data/text_data.txt") + self.assertEqual(self.hffs.glob(self.hf_path + "/data/*"), [self.hf_path + "/data/binary_data.bin"]) + + @retry_endpoint + def test_remove_directory(self): + self.hffs.rm(self.hf_path + "/data", recursive=True) + self.assertNotIn(self.hf_path + "/data", self.hffs.ls(self.hf_path)) + + @retry_endpoint + def test_read_file(self): + with self.hffs.open(self.hf_path + "/data/text_data.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data") + + @retry_endpoint + def test_write_file(self): + data = "new text data" + with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "w") as f: + f.write(data) + self.assertIn(self.hf_path + "/data/new_text_data.txt", self.hffs.glob(self.hf_path + "/data/*")) + with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "r") as f: + self.assertEqual(f.read(), data) + + @retry_endpoint + def test_write_file_multiple_chunks(self): + # TODO: try with files between 10 and 50MB (as of 16 March 2023 I was getting 504 errors on hub-ci) + data = "a" * (4 << 20) # 4MB + with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "w") as f: + for _ in range(2): # 8MB in total + f.write(data) + + self.assertIn(self.hf_path + "/data/new_text_data_big.txt", self.hffs.glob(self.hf_path + "/data/*")) + with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "r") as f: + for _ in range(2): + self.assertEqual(f.read(len(data)), data) + + @unittest.skip("Not implemented yet") + @retry_endpoint + def test_append_file(self): + with self.hffs.open(self.hf_path + "/data/text_data.txt", "a") as f: + f.write(" appended text") + + with self.hffs.open(self.hf_path + "/data/text_data.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data appended text") + + @retry_endpoint + def test_copy_file(self): + # Non-LFS file + self.assertIsNone(self.hffs.info(self.hf_path + "/data/text_data.txt")["lfs"]) + self.hffs.cp_file(self.hf_path + "/data/text_data.txt", self.hf_path + "/data/text_data_copy.txt") + with self.hffs.open(self.hf_path + "/data/text_data_copy.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data") + self.assertIsNone(self.hffs.info(self.hf_path + "/data/text_data_copy.txt")["lfs"]) + # LFS file + self.assertIsNotNone(self.hffs.info(self.hf_path + 
"/data/binary_data.bin")["lfs"]) + self.hffs.cp_file(self.hf_path + "/data/binary_data.bin", self.hf_path + "/data/binary_data_copy.bin") + with self.hffs.open(self.hf_path + "/data/binary_data_copy.bin", "rb") as f: + self.assertEqual(f.read(), b"dummy binary data") + self.assertIsNotNone(self.hffs.info(self.hf_path + "/data/binary_data_copy.bin")["lfs"]) + + @retry_endpoint + def test_modified_time(self): + self.assertIsInstance(self.hffs.modified(self.hf_path + "/data/text_data.txt"), datetime.datetime) + # should fail on a non-existing file + with self.assertRaises(FileNotFoundError): + self.hffs.modified(self.hf_path + "/data/not_existing_file.txt") + # should fail on a directory + with self.assertRaises(IsADirectoryError): + self.hffs.modified(self.hf_path + "/data") + + @retry_endpoint + def test_initialize_from_fsspec(self): + fs, _, paths = fsspec.get_fs_token_paths( + f"hf://{self.repo_type}s/{self.repo_id}/data/text_data.txt", + storage_options={ + "endpoint": ENDPOINT_STAGING, + "token": TOKEN, + }, + ) + self.assertIsInstance(fs, HfFileSystem) + self.assertEqual(fs._api.endpoint, ENDPOINT_STAGING) + self.assertEqual(fs.token, TOKEN) + self.assertEqual(paths, [self.hf_path + "/data/text_data.txt"]) + + fs, _, paths = fsspec.get_fs_token_paths(f"hf://{self.repo_id}/data/text_data.txt") + self.assertIsInstance(fs, HfFileSystem) + self.assertEqual(paths, [f"{self.repo_id}/data/text_data.txt"]) + + @retry_endpoint + def test_list_root_directory_no_revision(self): + files = self.hffs.ls(self.hf_path) + self.assertEqual(len(files), 2) + + self.assertEqual(files[0]["type"], "directory") + self.assertEqual(files[0]["size"], 0) + self.assertTrue(files[0]["name"].endswith("/data")) + + self.assertEqual(files[1]["type"], "file") + self.assertGreater(files[1]["size"], 0) # not empty + self.assertTrue(files[1]["name"].endswith("/.gitattributes")) + + @retry_endpoint + def test_list_data_directory_no_revision(self): + files = self.hffs.ls(self.hf_path + "/data") + self.assertEqual(len(files), 2) + + self.assertEqual(files[0]["type"], "file") + self.assertGreater(files[0]["size"], 0) # not empty + self.assertTrue(files[0]["name"].endswith("/data/binary_data.bin")) + self.assertIsNotNone(files[0]["lfs"]) + self.assertIn("oid", files[0]["lfs"]) + self.assertIn("size", files[0]["lfs"]) + self.assertIn("pointerSize", files[0]["lfs"]) + + self.assertEqual(files[1]["type"], "file") + self.assertGreater(files[1]["size"], 0) # not empty + self.assertTrue(files[1]["name"].endswith("/data/text_data.txt")) + self.assertIsNone(files[1]["lfs"]) + + @retry_endpoint + def test_list_data_directory_with_revision(self): + files = self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data") + + for test_name, files in { + "rev_in_path": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data"), + "rev_as_arg": self.hffs.ls(self.hf_path + "/data", revision="refs/pr/1"), + "rev_in_path_and_as_arg": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data", revision="refs/pr/1"), + }.items(): + with self.subTest(test_name): + self.assertEqual(len(files), 1) # only one file in PR + self.assertEqual(files[0]["type"], "file") + self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin")) # PR file + + +@pytest.mark.parametrize("path_in_repo", ["", "foo"]) +@pytest.mark.parametrize( + "root_path,revision,repo_type,repo_id,resolved_revision", + [ + # Parse without namespace + ("gpt2", None, "model", "gpt2", "main"), + ("gpt2", "dev", "model", "gpt2", "dev"), + ("gpt2@dev", None, "model", "gpt2", "dev"), + 
("datasets/squad", None, "dataset", "squad", "main"), + ("datasets/squad", "dev", "dataset", "squad", "dev"), + ("datasets/squad@dev", None, "dataset", "squad", "dev"), + # Parse with namespace + ("username/my_model", None, "model", "username/my_model", "main"), + ("username/my_model", "dev", "model", "username/my_model", "dev"), + ("username/my_model@dev", None, "model", "username/my_model", "dev"), + ("datasets/username/my_dataset", None, "dataset", "username/my_dataset", "main"), + ("datasets/username/my_dataset", "dev", "dataset", "username/my_dataset", "dev"), + ("datasets/username/my_dataset@dev", None, "dataset", "username/my_dataset", "dev"), + # Parse with hf:// protocol + ("hf://gpt2", None, "model", "gpt2", "main"), + ("hf://gpt2", "dev", "model", "gpt2", "dev"), + ("hf://gpt2@dev", None, "model", "gpt2", "dev"), + ("hf://datasets/squad", None, "dataset", "squad", "main"), + ("hf://datasets/squad", "dev", "dataset", "squad", "dev"), + ("hf://datasets/squad@dev", None, "dataset", "squad", "dev"), + ], +) +def test_resolve_path( + root_path: str, + revision: Optional[str], + repo_type: str, + repo_id: str, + resolved_revision: str, + path_in_repo: str, +): + fs = HfFileSystem() + path = root_path + "/" + path_in_repo if path_in_repo else root_path + + def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs): + if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]: + raise RepositoryNotFoundError(repo_id) + if revision is not None and revision not in ["main", "dev"]: + raise RevisionNotFoundError(revision) + + with patch.object(fs._api, "repo_info", mock_repo_info): + resolved_path = fs.resolve_path(path, revision=revision) + assert ( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + resolved_path.path_in_repo, + ) == (repo_type, repo_id, resolved_revision, path_in_repo) + + +def test_resolve_path_with_non_matching_revisions(): + fs = HfFileSystem() + with pytest.raises(ValueError): + fs.resolve_path("gpt2@dev", revision="main") + + +@pytest.mark.parametrize("not_supported_path", ["", "foo", "datasets", "datasets/foo"]) +def test_access_repositories_lists(not_supported_path): + fs = HfFileSystem() + with pytest.raises(NotImplementedError): + fs.ls(not_supported_path) + with pytest.raises(NotImplementedError): + fs.glob(not_supported_path + "/") + with pytest.raises(NotImplementedError): + fs.open(not_supported_path) diff --git a/tests/test_utils_cache.py b/tests/test_utils_cache.py index 9ab80a26e8..429d49c2d6 100644 --- a/tests/test_utils_cache.py +++ b/tests/test_utils_cache.py @@ -9,7 +9,7 @@ from huggingface_hub._snapshot_download import snapshot_download from huggingface_hub.commands.scan_cache import ScanCacheCommand -from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, scan_cache_dir +from huggingface_hub.utils import DeleteCacheStrategy, HFCacheInfo, capture_output, scan_cache_dir from huggingface_hub.utils._cache_manager import ( CacheNotFound, _format_size, @@ -19,7 +19,6 @@ from .testing_constants import TOKEN from .testing_utils import ( - capture_output, rmtree_with_retry, with_production_testing, xfail_on_windows, diff --git a/tests/test_utils_experimental.py b/tests/test_utils_experimental.py new file mode 100644 index 0000000000..ce83266561 --- /dev/null +++ b/tests/test_utils_experimental.py @@ -0,0 +1,26 @@ +import unittest +import warnings +from unittest.mock import patch + +from huggingface_hub.utils import experimental + + +@experimental +def dummy_function(): + return 
"success" + + +class TestExperimentalFlag(unittest.TestCase): + def test_experimental_warning(self): + with patch("huggingface_hub.constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING", False): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.assertEqual(dummy_function(), "success") + self.assertEqual(len(w), 1) + + def test_experimental_no_warning(self): + with patch("huggingface_hub.constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING", True): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.assertEqual(dummy_function(), "success") + self.assertEqual(len(w), 0) diff --git a/tests/test_webhooks_server.py b/tests/test_webhooks_server.py new file mode 100644 index 0000000000..cc9aba40ec --- /dev/null +++ b/tests/test_webhooks_server.py @@ -0,0 +1,188 @@ +import unittest +from unittest.mock import patch + +from fastapi import Request + +from huggingface_hub.utils import capture_output, is_gradio_available + +from .testing_utils import require_webhooks + + +if is_gradio_available(): + import gradio as gr + from fastapi.testclient import TestClient + + import huggingface_hub._webhooks_server + from huggingface_hub import WebhookPayload, WebhooksServer + + +# Taken from https://huggingface.co/docs/hub/webhooks#event +WEBHOOK_PAYLOAD_EXAMPLE = { + "event": {"action": "create", "scope": "discussion"}, + "repo": { + "type": "model", + "name": "gpt2", + "id": "621ffdc036468d709f17434d", + "private": False, + "url": {"web": "https://huggingface.co/gpt2", "api": "https://huggingface.co/api/models/gpt2"}, + "owner": {"id": "628b753283ef59b5be89e937"}, + }, + "discussion": { + "id": "6399f58518721fdd27fc9ca9", + "title": "Update co2 emissions", + "url": { + "web": "https://huggingface.co/gpt2/discussions/19", + "api": "https://huggingface.co/api/models/gpt2/discussions/19", + }, + "status": "open", + "author": {"id": "61d2f90c3c2083e1c08af22d"}, + "num": 19, + "isPullRequest": True, + "changes": {"base": "refs/heads/main"}, + }, + "comment": { + "id": "6399f58518721fdd27fc9caa", + "author": {"id": "61d2f90c3c2083e1c08af22d"}, + "content": "Add co2 emissions information to the model card", + "hidden": False, + "url": {"web": "https://huggingface.co/gpt2/discussions/19#6399f58518721fdd27fc9caa"}, + }, + "webhook": {"id": "6390e855e30d9209411de93b", "version": 3}, +} + + +@require_webhooks +class TestWebhooksServerDontRun(unittest.TestCase): + def test_add_webhook_implicit_path(self): + # Test adding a webhook + app = WebhooksServer() + + @app.add_webhook + async def handler(): + pass + + self.assertIn("/webhooks/handler", app.registered_webhooks) + + def test_add_webhook_explicit_path(self): + # Test adding a webhook + app = WebhooksServer() + + @app.add_webhook(path="/test_webhook") + async def handler(): + pass + + self.assertIn("/webhooks/test_webhook", app.registered_webhooks) # still registered under /webhooks + + def test_add_webhook_twice_should_fail(self): + # Test adding a webhook + app = WebhooksServer() + + @app.add_webhook("my_webhook") + async def test_webhook(): + pass + + # Registering twice the same webhook should raise an error + with self.assertRaises(ValueError): + + @app.add_webhook("my_webhook") + async def test_webhook_2(): + pass + + +@require_webhooks +class TestWebhooksServerRun(unittest.TestCase): + HEADERS_VALID_SECRET = {"x-webhook-secret": "my_webhook_secret"} + HEADERS_WRONG_SECRET = {"x-webhook-secret": "wrong_webhook_secret"} + + def setUp(self) -> None: + with gr.Blocks() as ui: + gr.Markdown("Hello 
World!") + app = WebhooksServer(ui=ui, webhook_secret="my_webhook_secret") + + # Route to check payload parsing + @app.add_webhook + async def test_webhook(payload: WebhookPayload) -> dict: + return {"scope": payload.event.scope} + + # Routes to check secret validation + # Checks all 4 cases (async/sync, with/without request parameter) + @app.add_webhook + async def async_with_request(request: Request) -> dict: + return {"success": True} + + @app.add_webhook + def sync_with_request(request: Request) -> dict: + return {"success": True} + + @app.add_webhook + async def async_no_request() -> dict: + return {"success": True} + + @app.add_webhook + def sync_no_request() -> dict: + return {"success": True} + + # Route to check explicit path + @app.add_webhook(path="/explicit_path") + async def with_explicit_path() -> dict: + return {"success": True} + + self.ui = ui + self.app = app + self.client = self.mocked_run_app() + + def tearDown(self) -> None: + self.ui.server.close() + + def mocked_run_app(self) -> "TestClient": + with patch.object(self.ui, "block_thread"): + # Run without blocking + with patch.object(huggingface_hub._webhooks_server, "_is_local", False): + # Run without tunnel + self.app.run() + return TestClient(self.app.fastapi_app) + + def test_run_print_instructions(self): + """Test that the instructions are printed when running the app.""" + # Test running the app + with capture_output() as output: + self.mocked_run_app() + + instructions = output.getvalue() + self.assertIn("Webhooks are correctly setup and ready to use:", instructions) + self.assertIn("- POST http://127.0.0.1:7860/webhooks/test_webhook", instructions) + + def test_run_parse_payload(self): + """Test that the payload is correctly parsed when running the app.""" + response = self.client.post( + "/webhooks/test_webhook", headers=self.HEADERS_VALID_SECRET, json=WEBHOOK_PAYLOAD_EXAMPLE + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"scope": "discussion"}) + + def test_with_webhook_secret_should_succeed(self): + """Test success if a valid secret is sent.""" + for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]: + with self.subTest(path): + response = self.client.post(f"/webhooks/{path}", headers=self.HEADERS_VALID_SECRET) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"success": True}) + + def test_no_webhook_secret_should_be_unauthorized(self): + """Test failure (401) if no secret is sent.""" + for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]: + with self.subTest(path): + response = self.client.post(f"/webhooks/{path}") + self.assertEqual(response.status_code, 401) + + def test_wrong_webhook_secret_should_be_forbidden(self): + """Test failure (403) if a wrong secret is sent.""" + for path in ["async_with_request", "sync_with_request", "async_no_request", "sync_no_request"]: + with self.subTest(path): + response = self.client.post(f"/webhooks/{path}", headers=self.HEADERS_WRONG_SECRET) + self.assertEqual(response.status_code, 403) + + def test_route_with_explicit_path(self): + """Test that the route with an explicit path is correctly registered.""" + response = self.client.post("/webhooks/explicit_path", headers=self.HEADERS_VALID_SECRET) + self.assertEqual(response.status_code, 200) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 2314f603a5..8035a56867 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -2,23 +2,21 @@ import os import
shutil import stat -import sys import time import unittest import uuid from contextlib import contextmanager from enum import Enum from functools import wraps -from io import StringIO from pathlib import Path -from typing import Callable, Generator, Optional, Type, TypeVar, Union +from typing import Callable, Optional, Type, TypeVar, Union from unittest.mock import Mock, patch import pytest import requests from requests.exceptions import HTTPError -from huggingface_hub.utils import logging +from huggingface_hub.utils import is_gradio_available, logging from tests.testing_constants import ENDPOINT_PRODUCTION, ENDPOINT_PRODUCTION_URL_SCHEME @@ -106,7 +104,7 @@ def parse_int_from_env(key, default=None): def require_git_lfs(test_case): """ - Decorator marking a test that requires git-lfs. + Decorator to mark tests that require git-lfs. git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment variable to a truthy value to run them. @@ -117,6 +115,19 @@ def require_git_lfs(test_case): return test_case + + +def require_webhooks(test_case): + """ + Decorator to mark tests that require the `webhooks` extra (i.e. gradio, fastapi, pydantic). + + These tests are skipped when gradio is not installed, since the webhooks server cannot be instantiated + without it. + """ + if not is_gradio_available(): + return unittest.skip("Skipping webhook tests: gradio is not installed")(test_case) + else: + return test_case + + class RequestWouldHangIndefinitelyError(Exception): pass @@ -321,30 +332,6 @@ def _inner_decorator(test_function: Callable) -> Callable: return _inner_decorator - -@contextmanager -def capture_output() -> Generator[StringIO, None, None]: - """Capture output that is printed to console. - - Especially useful to test CLI commands. - - Taken from https://stackoverflow.com/a/34738440 - - Example: - ```py - class TestHelloWorld(unittest.TestCase): - def test_hello_world(self): - with capture_output() as output: - print("hello world") - self.assertEqual(output.getvalue(), "hello world\n") - ``` - """ - output = StringIO() - previous_output = sys.stdout - sys.stdout = output - yield output - sys.stdout = previous_output - - T = TypeVar("T")
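For reference, a short end-to-end sketch of the `HfFileSystem` operations implemented in this diff (`rm`, `cp_file`, `modified`); `datasets/my-username/my-dataset-repo` is a placeholder repo that is assumed to exist and to contain the listed files:

```python
from huggingface_hub import HfFileSystem

fs = HfFileSystem()  # picks up the locally saved token, if any
repo = "datasets/my-username/my-dataset-repo"

# Server-side copy: within the same repo, an LFS file is copied by committing
# the existing blob's sha256 oid, so its content is never re-uploaded; regular
# files are read back and committed again through `upload_file`.
fs.cp_file(f"{repo}/data/train.csv", f"{repo}/data/train_backup.csv")

# Date of the last commit that touched the file; raises IsADirectoryError on
# a directory and FileNotFoundError on a missing path.
print(fs.modified(f"{repo}/data/train_backup.csv"))

# Recursive delete: all files under the path are removed in a single commit.
fs.rm(f"{repo}/data", recursive=True)
```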