diff --git a/docs/source/guides/upload.mdx b/docs/source/guides/upload.mdx index 19be2c4c61..e76f087db2 100644 --- a/docs/source/guides/upload.mdx +++ b/docs/source/guides/upload.mdx @@ -108,6 +108,54 @@ but before that, all previous logs on the repo on deleted. All of this in a sing ... ) ``` +### Non-blocking upload + +In some cases, you want to push data without blocking your main thread. This is particularly useful to upload logs and +artifacts while training continues. To do so, you can use the `run_as_future` argument in both [`upload_file`] and +[`upload_folder`]. This will return a [`concurrent.futures.Future`](https://docs.python.org/3/library/concurrent.futures.html#future-objects) +object that you can use to check the status of the upload. + +```py +>>> from huggingface_hub import HfApi +>>> api = HfApi() +>>> future = api.upload_folder( # Upload in the background (non-blocking action) +... repo_id="username/my-model", +... folder_path="checkpoints-001", +... run_as_future=True, +... ) +>>> future +Future(...) +>>> future.done() +False +>>> future.result() # Wait for the upload to complete (blocking action) +... +``` + + + +Background jobs are queued when using `run_as_future=True`. This means that you are guaranteed that the jobs will be +executed in the correct order. + + + +Even though background jobs are mostly useful to upload data/create commits, you can queue any method you like using +[`run_as_future`]. For instance, you can use it to create a repo and then upload data to it in the background. The +built-in `run_as_future` argument in upload methods is just a convenient wrapper around it. + +```py +>>> from huggingface_hub import HfApi +>>> api = HfApi() +>>> api.run_as_future(api.create_repo, "username/my-model", exist_ok=True) +Future(...) +>>> api.upload_file( +... repo_id="username/my-model", +... path_in_repo="file.txt", +... path_or_fileobj=b"file content", +... run_as_future=True, +... ) +Future(...) +``` + ### Upload a folder by chunks [`upload_folder`] makes it easy to upload an entire folder to the Hub. However, for large folders (thousands of files or diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 600fc80cb7..579e152315 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -176,6 +176,7 @@ "repo_type_and_id_from_hf_id", "request_space_hardware", "restart_space", + "run_as_future", "set_space_sleep_time", "space_info", "unlike", @@ -462,6 +463,7 @@ def __dir__(): repo_type_and_id_from_hf_id, # noqa: F401 request_space_hardware, # noqa: F401 restart_space, # noqa: F401 + run_as_future, # noqa: F401 set_space_sleep_time, # noqa: F401 space_info, # noqa: F401 unlike, # noqa: F401 diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index dccad33b04..58e0884c87 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -12,16 +12,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
+from __future__ import annotations + +import inspect import json import pprint import re import textwrap import warnings +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field from datetime import datetime +from functools import wraps from itertools import islice from pathlib import Path -from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union, overload from urllib.parse import quote import requests @@ -91,7 +96,7 @@ from .utils._deprecation import ( _deprecate_arguments, ) -from .utils._typing import Literal, TypedDict +from .utils._typing import CallableT, Literal, TypedDict from .utils.endpoint_helpers import ( AttributeDictionary, DatasetFilter, @@ -102,6 +107,8 @@ ) + +R = TypeVar("R") # Return type + USERNAME_PLACEHOLDER = "hf_user" _REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") @@ -788,6 +795,39 @@ class UserLikes: spaces: List[str] +def future_compatible(fn: CallableT) -> CallableT: + """Wrap a method of `HfApi` to handle `run_as_future=True`. + + A method flagged as "future_compatible" will be called in a thread if `run_as_future=True` and return a + `concurrent.futures.Future` instance. Otherwise, it will be called normally and return the result. + """ + sig = inspect.signature(fn) + args_params = list(sig.parameters)[1:] # remove "self" from list + + @wraps(fn) + def _inner(self, *args, **kwargs): + # Get `run_as_future` value if provided (default to False) + if "run_as_future" in kwargs: + run_as_future = kwargs["run_as_future"] + kwargs["run_as_future"] = False # avoid recursion error + else: + run_as_future = False + for param, value in zip(args_params, args): + if param == "run_as_future": + run_as_future = value + break + + # Call the function in a thread if `run_as_future=True` + if run_as_future: + return self.run_as_future(fn, self, *args, **kwargs) + + # Otherwise, call the function normally + return fn(self, *args, **kwargs) + + _inner.is_future_compatible = True # type: ignore + return _inner # type: ignore + + class HfApi: def __init__( self, @@ -827,7 +867,49 @@ def __init__( self.library_name = library_name self.library_version = library_version self.user_agent = user_agent + self._thread_pool: Optional[ThreadPoolExecutor] = None + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + """ + Run a method in the background and return a Future instance. + + The main goal is to run methods without blocking the main thread (e.g. to push data during training). + Background jobs are queued to preserve order but are not run in parallel. If you need to speed up your scripts + by parallelizing lots of calls to the API, you must set up and use your own [ThreadPoolExecutor](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor). + + Note: Most-used methods like [`upload_file`], [`upload_folder`] and [`create_commit`] have a `run_as_future: bool` + argument to directly call them in the background. This is equivalent to calling `api.run_as_future(...)` on them + but less verbose. + + Args: + fn (`Callable`): + The method to run in the background. + *args, **kwargs: + Arguments with which the method will be called. + Return: + [`Future`](https://docs.python.org/3/library/concurrent.futures.html#future-objects): a Future instance to + get the result of the task.
+ + Example: + ```py + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> future = api.run_as_future(api.whoami) # instant + >>> future.done() + False + >>> future.result() # wait until complete and return result + (...) + >>> future.done() + True + ``` + """ + if self._thread_pool is None: + self._thread_pool = ThreadPoolExecutor(max_workers=1) + self._thread_pool + return self._thread_pool.submit(fn, *args, **kwargs) + + @validate_hf_hub_args def whoami(self, token: Optional[str] = None) -> Dict: """ Call HF API to know "whoami". @@ -873,7 +955,9 @@ def _is_valid_token(self, token: str) -> bool: return False def get_model_tags(self) -> ModelTags: - "Gets all valid model tags as a nested namespace object" + """ + List all valid model tags as a nested namespace object + """ path = f"{self.endpoint}/api/models-tags-by-type" r = get_session().get(path) hf_raise_for_status(r) @@ -882,7 +966,7 @@ def get_model_tags(self) -> ModelTags: def get_dataset_tags(self) -> DatasetTags: """ - Gets all valid dataset tags as a nested namespace object. + List all valid dataset tags as a nested namespace object. """ path = f"{self.endpoint}/api/datasets-tags-by-type" r = get_session().get(path) @@ -2404,7 +2488,44 @@ def move_repo( ) raise + @overload + def create_commit( # type: ignore + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[False] = ..., + ) -> CommitInfo: + ... + + @overload + def create_commit( + self, + repo_id: str, + operations: Iterable[CommitOperation], + *, + commit_message: str, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + num_threads: int = 5, + parent_commit: Optional[str] = None, + run_as_future: Literal[True] = ..., + ) -> Future[CommitInfo]: + ... + @validate_hf_hub_args + @future_compatible def create_commit( self, repo_id: str, @@ -2418,7 +2539,8 @@ def create_commit( create_pr: Optional[bool] = None, num_threads: int = 5, parent_commit: Optional[str] = None, - ) -> CommitInfo: + run_as_future: bool = False, + ) -> Union[CommitInfo, Future[CommitInfo]]: """ Creates a commit in the given repo, deleting & uploading files as needed. @@ -2469,6 +2591,10 @@ def create_commit( is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. 
Returns: [`CommitInfo`]: @@ -2900,7 +3026,44 @@ def create_commits_on_pr( return pr.url + @overload + def upload_file( # type: ignore + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + run_as_future: Literal[False] = ..., + ) -> str: + ... + + @overload + def upload_file( + self, + *, + path_or_fileobj: Union[str, Path, bytes, BinaryIO], + path_in_repo: str, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + run_as_future: Literal[True] = ..., + ) -> Future[str]: + ... + @validate_hf_hub_args + @future_compatible def upload_file( self, *, @@ -2914,7 +3077,8 @@ def upload_file( commit_description: Optional[str] = None, create_pr: Optional[bool] = None, parent_commit: Optional[str] = None, - ) -> str: + run_as_future: bool = False, + ) -> Union[str, Future[str]]: """ Upload a local file (up to 50 GB) to the given repo. The upload is done through a HTTP post request, and doesn't require git or git-lfs to be @@ -2955,6 +3119,10 @@ def upload_file( If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be especially useful if the repo is updated / committed to concurrently. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. Returns: @@ -3049,7 +3217,54 @@ def upload_file( # Similar to `hf_hub_url` but it's "blob" instead of "resolve" return f"{self.endpoint}/{repo_id}/blob/{revision}/{path_in_repo}" + @overload + def upload_folder( # type: ignore + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: bool = False, + multi_commits_verbose: bool = False, + run_as_future: Literal[False] = ..., + ) -> str: + ... 
+ + @overload + def upload_folder( + self, + *, + repo_id: str, + folder_path: Union[str, Path], + path_in_repo: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + token: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + multi_commits: bool = False, + multi_commits_verbose: bool = False, + run_as_future: Literal[True] = ..., + ) -> Future[str]: + ... + @validate_hf_hub_args + @future_compatible def upload_folder( self, *, @@ -3068,7 +3283,8 @@ def upload_folder( delete_patterns: Optional[Union[List[str], str]] = None, multi_commits: bool = False, multi_commits_verbose: bool = False, - ): + run_as_future: bool = False, + ) -> Union[str, Future[str]]: """ Upload a local folder to the given repo. The upload is done through a HTTP requests, and doesn't require git or git-lfs to be installed. @@ -3139,6 +3355,10 @@ def upload_folder( If True, changes are pushed to a PR using a multi-commit process. Defaults to `False`. multi_commits_verbose (`bool`): If True and `multi_commits` is used, more information will be displayed to the user. + run_as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `run_as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`. Returns: `str`: A URL to visualize the uploaded folder on the hub @@ -4903,6 +5123,9 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: delete_tag = api.delete_tag get_full_repo_name = api.get_full_repo_name +# Background jobs +run_as_future = api.run_as_future + # Activity API list_liked_repos = api.list_liked_repos like = api.like diff --git a/src/huggingface_hub/utils/_typing.py b/src/huggingface_hub/utils/_typing.py index 812c65ea39..c8885eb1eb 100644 --- a/src/huggingface_hub/utils/_typing.py +++ b/src/huggingface_hub/utils/_typing.py @@ -14,6 +14,7 @@ # limitations under the License. """Handle typing imports based on system compatibility.""" import sys +from typing import Callable, TypeVar if sys.version_info >= (3, 8): @@ -22,3 +23,6 @@ from typing_extensions import Literal, TypedDict # noqa: F401 HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] + +# type hint meaning "function signature not changed by decorator" +CallableT = TypeVar("CallableT", bound=Callable) diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index 5ec3de775f..5dd64fa514 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -18,7 +18,9 @@ import warnings from functools import wraps from itertools import chain -from typing import Any, Callable, Dict, TypeVar +from typing import Any, Dict + +from ._typing import CallableT REPO_ID_REGEX = re.compile( @@ -41,10 +43,6 @@ class HFValidationError(ValueError): """ -# type hint meaning "function signature not changed by decorator" -CallableT = TypeVar("CallableT", bound=Callable) - - def validate_hf_hub_args(fn: CallableT) -> CallableT: """Validate values received as argument for any public method of `huggingface_hub`. 
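For readers unfamiliar with the `CallableT` pattern that this diff centralizes in `utils/_typing.py`: annotating a decorator as `(fn: CallableT) -> CallableT` tells type checkers that the decorated function keeps its original signature, which is what `validate_hf_hub_args` and the new `future_compatible` decorator rely on. A minimal standalone sketch of the pattern (the `log_call` decorator below is a hypothetical illustration, not part of `huggingface_hub`):

```py
from functools import wraps
from typing import Callable, TypeVar

# Same idea as `huggingface_hub.utils._typing.CallableT`:
# "function signature not changed by decorator"
CallableT = TypeVar("CallableT", bound=Callable)


def log_call(fn: CallableT) -> CallableT:
    """Toy decorator: logs the call, then delegates to the wrapped function."""

    @wraps(fn)
    def _inner(*args, **kwargs):
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return _inner  # type: ignore


@log_call
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)  # type checkers still see the original (a: int, b: int) -> int signature
```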
diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index a2454ef833..695f0b108c 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -20,6 +20,7 @@ import types import unittest import warnings +from concurrent.futures import Future from functools import partial from io import BytesIO from pathlib import Path @@ -2384,6 +2385,71 @@ def test_pause_and_restart_space(self) -> None: self.assertIn(runtime_after_restart.stage, (SpaceStage.BUILDING, SpaceStage.RUNNING_BUILDING)) +@pytest.mark.usefixtures("fx_cache_dir") +class TestCommitInBackground(HfApiCommonTest): + cache_dir: Path + + @use_tmp_repo() + def test_commit_to_repo_in_background(self, repo_url: RepoUrl) -> None: + repo_id = repo_url.repo_id + (self.cache_dir / "file.txt").write_text("content") + (self.cache_dir / "lfs.bin").write_text("content") + + t0 = time.time() + upload_future_1 = self._api.upload_file( + path_or_fileobj=b"1", path_in_repo="1.txt", repo_id=repo_id, commit_message="Upload 1", run_as_future=True + ) + upload_future_2 = self._api.upload_file( + path_or_fileobj=b"2", path_in_repo="2.txt", repo_id=repo_id, commit_message="Upload 2", run_as_future=True + ) + upload_future_3 = self._api.upload_folder( + repo_id=repo_id, folder_path=self.cache_dir, commit_message="Upload folder", run_as_future=True + ) + t1 = time.time() + + # all futures are queued instantly + self.assertLessEqual(t1 - t0, 0.01) + + # wait for the last job to complete + upload_future_3.result() + + # all of them are now complete (ran in order) + self.assertTrue(upload_future_1.done()) + self.assertTrue(upload_future_2.done()) + self.assertTrue(upload_future_3.done()) + + # 4 commits, sorted in reverse order of creation + commits = self._api.list_repo_commits(repo_id=repo_id) + self.assertEqual(len(commits), 4) + self.assertEqual(commits[0].title, "Upload folder") + self.assertEqual(commits[1].title, "Upload 2") + self.assertEqual(commits[2].title, "Upload 1") + self.assertEqual(commits[3].title, "initial commit") + + @use_tmp_repo() + def test_run_as_future(self, repo_url: RepoUrl) -> None: + repo_id = repo_url.repo_id + self._api.run_as_future(self._api.like, repo_id) + future_1 = self._api.run_as_future(self._api.model_info, repo_id=repo_id) + self._api.run_as_future(self._api.unlike, repo_id) + future_2 = self._api.run_as_future(self._api.model_info, repo_id=repo_id) + + self.assertIsInstance(future_1, Future) + self.assertIsInstance(future_2, Future) + + # Wait for first info future + info_1 = future_1.result() + self.assertFalse(future_2.done()) + + # Wait for second info future + info_2 = future_2.result() + self.assertTrue(future_2.done()) + + # Like/unlike is correct + self.assertEqual(info_1.likes, 1) + self.assertEqual(info_2.likes, 0) + + class TestSpaceAPIMocked(unittest.TestCase): """ Testing Space hardware requests is resource intensive for the server (need to spawn diff --git a/tests/test_init_lazy_loading.py b/tests/test_init_lazy_loading.py index 8faa320673..9312543128 100644 --- a/tests/test_init_lazy_loading.py +++ b/tests/test_init_lazy_loading.py @@ -27,7 +27,7 @@ def test_autocomplete_on_root_imports(self) -> None: # Assert docstring is find. This means autocomplete can also provide # the help section. 
signature_list = goto_list[0].get_signatures() - self.assertEqual(len(signature_list), 1) + self.assertEqual(len(signature_list), 2) # create_commit has 2 signatures (normal and `run_as_future`) self.assertTrue(signature_list[0].docstring().startswith("create_commit(repo_id: str,")) break else: diff --git a/utils/_legacy_check_future_compatible_signatures.py b/utils/_legacy_check_future_compatible_signatures.py new file mode 100644 index 0000000000..bbafc93598 --- /dev/null +++ b/utils/_legacy_check_future_compatible_signatures.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a tool to add/check the definition of "async" methods of `HfApi` in `huggingface_hub.hf_api.py`. + +WARNING: this is a script kept to help with `@future_compatible` methods of `HfApi` but it is not 100% correct. +Keeping it here for reference but it is not used in the CI/Makefile. + +What is done correctly: +1. Add "as_future" as an argument to the method signature +2. Set Union[T, Future[T]] as the return type in the method signature +3. Document the "as_future" argument in the docstring of the method + +What is NOT done correctly: +1. Generated stubs are grouped at the top of the `HfApi` class. They must be copy-pasted manually (each overload definition must be +just before the method implementation) +2. `#type: ignore` must be adjusted in the first stub (if multiline definition) +""" +import argparse +import inspect +import os +import re +import tempfile +from pathlib import Path +from typing import Callable, NoReturn, Tuple + +import black +from ruff.__main__ import find_ruff_bin + +from huggingface_hub.hf_api import HfApi + + +STUBS_SECTION_TEMPLATE = """ + ### Stubs section start ### + + # This section contains stubs for the methods that are marked as `@future_compatible`. Those methods have a + # different return type depending on the `as_future: bool` value. For better integration with IDEs, we provide + # stubs for both return types. The actual implementation of those methods is written below. + + # WARNING: this section has been generated automatically. Do not modify it manually. If you modify it manually, your + # changes will be overwritten. To re-generate this section, run `make style` (or `python utils/check_future_compatible_signatures.py` + # directly). + + # FAQ: + # 1. Why should we have these? For better type annotations, which help with IDE features like autocompletion. + # 2. Why not a separate `hf_api.pyi` file? Would require re-defining all the existing annotations from `hf_api.py`. + # 3. Why not at the end of the module? Because `@overload` methods must be defined first. + # 4. Why not another solution? I'd be glad to hear one, but this is the least bad option I could find. + # For more details, see https://github.com/huggingface/huggingface_hub/pull/1458 + + + {stubs} + + # WARNING: this section has been generated automatically. Do not modify it manually. If you modify it manually, your + # changes will be overwritten.
To re-generate this section, run `make style` (or `python utils/check_future_compatible_signatures.py` + # directly). + + ### Stubs section end ### +""" + +STUBS_SECTION_TEMPLATE_REGEX = re.compile(r"### Stubs section start ###.*### Stubs section end ###", re.DOTALL) + +AS_FUTURE_SIGNATURE_TEMPLATE = "as_future: bool = False" + +AS_FUTURE_DOCSTRING_TEMPLATE = """ + as_future (`bool`, *optional*): + Whether or not to run this method in the background. Background jobs are run sequentially without + blocking the main thread. Passing `as_future=True` will return a [Future](https://docs.python.org/3/library/concurrent.futures.html#future-objects) + object. Defaults to `False`.""" + +ARGS_DOCSTRING_REGEX = re.compile( """ +^[ ]{8}Args: # Match args section ... +(.*?) # ... everything ... +^[ ]{8}\\S # ... until next section or end of docstring +""", re.MULTILINE | re.IGNORECASE | re.VERBOSE | re.DOTALL, ) + +SIGNATURE_REGEX_FULL = re.compile(r"^\s*def.*?-> (.*?):", re.DOTALL | re.MULTILINE) +SIGNATURE_REGEX_RETURN_TYPE = re.compile(r"-> (.*?):") +SIGNATURE_REGEX_RETURN_TYPE_WITH_FUTURE = re.compile(r"-> Union\[(.*?), (.*?)\]:") + + +HF_API_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "hf_api.py" +HF_API_FILE_CONTENT = HF_API_FILE_PATH.read_text() + + +def generate_future_compatible_method(method: Callable, method_source: str) -> Tuple[str, str]: + # 1. Document `as_future` parameter + if AS_FUTURE_DOCSTRING_TEMPLATE not in method_source: + match = ARGS_DOCSTRING_REGEX.search(method_source) + if match is None: + raise ValueError(f"Could not find `Args` section in docstring of {method}.") + args_docs = match.group(1).strip() + method_source = method_source.replace(args_docs, args_docs + AS_FUTURE_DOCSTRING_TEMPLATE) + + # 2. Update signature + # 2.a. Add `as_future` parameter + if AS_FUTURE_SIGNATURE_TEMPLATE not in method_source: + match = SIGNATURE_REGEX_FULL.search(method_source) + if match is None: + raise ValueError(f"Could not find signature of {method} in source.") + method_source = method_source.replace( + match.group(), match.group().replace(") ->", f" {AS_FUTURE_SIGNATURE_TEMPLATE}) ->"), 1 + ) + + # 2.b. Update return value + if "Future[" not in method_source: + match = SIGNATURE_REGEX_RETURN_TYPE.search(method_source) + if match is None: + raise ValueError(f"Could not find return type of {method} in source.") + base_type = match.group(1).strip() + return_type = f"Union[{base_type}, Future[{base_type}]]" + return_value_replaced = match.group().replace(match.group(1), return_type) + method_source = method_source.replace(match.group(), return_value_replaced) + + # 3. Generate @overload stubs + match = SIGNATURE_REGEX_FULL.search(method_source) + if match is None: + raise ValueError(f"Could not find signature of {method} in source.") + method_sig = match.group() + + match = SIGNATURE_REGEX_RETURN_TYPE_WITH_FUTURE.search(method_sig) + if match is None: + raise ValueError(f"Could not find return type (with Future) of {method} in source.") + no_future_return_type = match.group(1).strip() + with_future_return_type = match.group(2).strip() + + # 3.a. Stub when `as_future=False` + no_future_stub = " @overload\n" + method_sig + no_future_stub = no_future_stub.replace(AS_FUTURE_SIGNATURE_TEMPLATE, "as_future: Literal[False] = ...") + no_future_stub = SIGNATURE_REGEX_RETURN_TYPE.sub(rf"-> {no_future_return_type}:", no_future_stub) + no_future_stub += " # type: ignore\n ..." # only the first stub requires "type: ignore" + + # 3.b.
Stub when `as_future=True` + with_future_stub = " @overload\n" + method_sig + with_future_stub = with_future_stub.replace(AS_FUTURE_SIGNATURE_TEMPLATE, "as_future: Literal[True] = ...") + with_future_stub = SIGNATURE_REGEX_RETURN_TYPE.sub(rf"-> {with_future_return_type}:", with_future_stub) + with_future_stub += "\n ..." + + stubs_source = no_future_stub + "\n\n" + with_future_stub + "\n\n" + + # 4. All good! + return method_source, stubs_source + + +def generate_hf_api_module() -> str: + raw_code = HF_API_FILE_CONTENT + + # Process all Future-compatible methods + all_stubs_source = "" + for _, method in inspect.getmembers(HfApi, predicate=inspect.isfunction): + if not getattr(method, "is_future_compatible", False): + continue + source = inspect.getsource(method) + method_source, stubs_source = generate_future_compatible_method(method, source) + + raw_code = raw_code.replace(source, method_source) + all_stubs_source += "\n\n" + stubs_source + + # Generate code with stubs + generated_code = STUBS_SECTION_TEMPLATE_REGEX.sub(STUBS_SECTION_TEMPLATE.format(stubs=all_stubs_source), raw_code) + + # Format (black+ruff) + return format_generated_code(generated_code) + + +def format_generated_code(code: str) -> str: + """ + Format some code with black+ruff. Black runs on the string directly, but ruff needs a file, so the code is first saved to a temporary file. + """ + # Format with black + code = black.format_file_contents(code, fast=False, mode=black.FileMode(line_length=119)) + + # Format with ruff + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "__init__.py" + filepath.write_text(code) + ruff_bin = find_ruff_bin() + os.spawnv(os.P_WAIT, ruff_bin, ["ruff", str(filepath), "--fix", "--quiet"]) + return filepath.read_text() + + +def check_future_compatible_hf_api(update: bool) -> NoReturn: + """Check that the code defining the threaded version of HfApi is up-to-date.""" + # If the expected content of `hf_api.py` is different, the check fails. If '--update' + # is used, `hf_api.py` is updated before exiting. + expected_content = generate_hf_api_module() + if expected_content != HF_API_FILE_CONTENT: + if update: + with HF_API_FILE_PATH.open("w") as f: + f.write(expected_content) + + print( + "✅ Signature/docstring/annotations for Future-compatible methods have been updated in" + " `./src/huggingface_hub/hf_api.py`.\n Please make sure the changes are accurate and commit them." + ) + exit(0) + else: + print( + "❌ Expected content mismatch for Future-compatible methods in `./src/huggingface_hub/hf_api.py`.\n " + " Please run `make style` or `python utils/check_future_compatible_signatures.py --update`." + ) + exit(1) + + print("✅ All good! (Future-compatible methods)") + exit(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--update", + action="store_true", + help="Whether to override `./src/huggingface_hub/hf_api.py` if a change is detected.", + ) + args = parser.parse_args() + + check_future_compatible_hf_api(update=args.update)
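To put the whole diff in context, here is a minimal usage sketch of the non-blocking API it introduces, assuming a `huggingface_hub` release that contains this change; the repo id and checkpoint folder names are placeholders:

```py
from huggingface_hub import HfApi

api = HfApi()
futures = []
for step in range(3):
    # ... train for a while, then write a checkpoint folder locally ...
    futures.append(
        api.upload_folder(
            repo_id="username/my-model",  # placeholder repo
            folder_path=f"checkpoints-{step:03d}",  # placeholder local folder
            commit_message=f"Checkpoint {step}",
            run_as_future=True,  # queue the upload instead of blocking the training loop
        )
    )

# Jobs run sequentially in a single background thread, in submission order.
for future in futures:
    future.result()  # block only here, until every queued upload has finished
```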