From 6fa66c49d901fd795e8c8925f9a1ae770739cfc8 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 23 Mar 2022 15:15:03 +0100 Subject: [PATCH 01/47] light typing (cherry picked from commit b2c8f9b970c505cdf2c685e645e9e36cc472b0d3) --- src/huggingface_hub/hf_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e5ad61b837..626a23cc2c 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -89,7 +89,9 @@ def _validate_repo_id_deprecation(repo_id, name, organization): return name, organization -def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None): +def repo_type_and_id_from_hf_id( + hf_id: str, hub_url: Optional[str] = None +) -> Tuple[Optional[str], Optional[str], str]: """ Returns the repo type and ID from a huggingface.co URL linking to a repository From 29a8833a572d2761830376085277316d3afa9dbb Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:25:44 +0100 Subject: [PATCH 02/47] remove this seminal comment (cherry picked from commit 12a841a605c94733154f3b22e812c0f5e69ef37b) --- src/huggingface_hub/_snapshot_download.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 5cd0a5fad5..731e629076 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -97,9 +97,6 @@ def snapshot_download( """ - # Note: at some point maybe this format of storage should actually replace - # the flat storage structure we've used so far (initially from allennlp - # if I remember correctly). if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE From da92022541518e92509621435f5f88de3ca15ff8 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:29:50 +0100 Subject: [PATCH 03/47] I don't understand why we don't early return here cc @patrickvonplaten care to take a look? cc @LysandreJik (cherry picked from commit 259ab36f03ab3eed6eeb4fc4984bc259619b442f) --- src/huggingface_hub/_snapshot_download.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 731e629076..a8de9da17e 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -181,12 +181,10 @@ def snapshot_download( " behavior by not taking into account the latest commits." 
) - # find last modified folder + # find last modified folder, and return it storage_folder = max(repo_folders, key=os.path.getmtime) - # get commit sha - repo_id_sha = storage_folder.split(".")[-1] - model_files = os.listdir(storage_folder) + return storage_folder else: # if we have internet connection we retrieve the correct folder name from the huggingface api _api = HfApi() From e4d36861217809d52ee5ff93e8d495df3d796ae7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:30:49 +0100 Subject: [PATCH 04/47] following last commit, unnest this (cherry picked from commit 54957f3f049d887af21dd8f6950873a2823c4247) --- src/huggingface_hub/_snapshot_download.py | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index a8de9da17e..be1595939b 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -185,22 +185,22 @@ def snapshot_download( storage_folder = max(repo_folders, key=os.path.getmtime) return storage_folder - else: - # if we have internet connection we retrieve the correct folder name from the huggingface api - _api = HfApi() - model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token) - storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision) + # if we have internet connection we retrieve the correct folder name from the huggingface api + _api = HfApi() + model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token) + + storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision) - # if passed revision is not identical to the commit sha - # then revision has to be a branch name, e.g. "main" - # in this case make sure that the branch name is included - # cached storage folder name - if revision != model_info.sha: - storage_folder += f".{model_info.sha}" + # if passed revision is not identical to the commit sha + # then revision has to be a branch name, e.g. "main" + # in this case make sure that the branch name is included + # cached storage folder name + if revision != model_info.sha: + storage_folder += f".{model_info.sha}" - repo_id_sha = model_info.sha - model_files = [f.rfilename for f in model_info.siblings] + repo_id_sha = model_info.sha + model_files = [f.rfilename for f in model_info.siblings] allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex From 8e4ccf45dc432b30843e8c5026ba663933cf416b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:35:22 +0100 Subject: [PATCH 05/47] [BIG] This should work for all repo_types not just models! 
(cherry picked from commit 9a3f96ccb2de6663cf4cf2d9a60dd7f415227c1b) --- src/huggingface_hub/_snapshot_download.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index be1595939b..42ae8d99e3 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -24,6 +24,7 @@ def snapshot_download( repo_id: str, *, revision: Optional[str] = None, + repo_type: Optional[str] = None, cache_dir: Union[str, Path, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, @@ -188,7 +189,9 @@ def snapshot_download( # if we have internet connection we retrieve the correct folder name from the huggingface api _api = HfApi() - model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token) + repo_info = _api.repo_info( + repo_id=repo_id, repo_type=repo_type, revision=revision, token=token + ) storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision) @@ -196,30 +199,30 @@ def snapshot_download( # then revision has to be a branch name, e.g. "main" # in this case make sure that the branch name is included # cached storage folder name - if revision != model_info.sha: - storage_folder += f".{model_info.sha}" + if revision != repo_info.sha: + storage_folder += f".{repo_info.sha}" - repo_id_sha = model_info.sha - model_files = [f.rfilename for f in model_info.siblings] + repo_id_sha = repo_info.sha + repo_files = [f.rfilename for f in repo_info.siblings] allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex - for model_file in model_files: + for repo_file in repo_files: # if there's an allowlist, skip download if file does not match any regex if allow_regex is not None and not any( - fnmatch(model_file, r) for r in allow_regex + fnmatch(repo_file, r) for r in allow_regex ): continue # if there's a denylist, skip download if file does matches any regex if ignore_regex is not None and any( - fnmatch(model_file, r) for r in ignore_regex + fnmatch(repo_file, r) for r in ignore_regex ): continue - url = hf_hub_url(repo_id, filename=model_file, revision=repo_id_sha) - relative_filepath = os.path.join(*model_file.split("/")) + url = hf_hub_url(repo_id, filename=repo_file, revision=repo_id_sha) + relative_filepath = os.path.join(*repo_file.split("/")) # Create potential nested dir nested_dirname = os.path.dirname( From 5ab8b7429a38e6b5b625042a4c58f96f63681e76 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:44:05 +0100 Subject: [PATCH 06/47] one more (cherry picked from commit b74871250616c44a2125b26d5de29b1189e82e12) --- src/huggingface_hub/_snapshot_download.py | 132 ++++++++++++---------- 1 file changed, 72 insertions(+), 60 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 42ae8d99e3..66601deebc 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union -from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE +from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES from .file_download import cached_download, hf_hub_url from .hf_api import HfApi, HfFolder from .utils import logging @@ -39,64 +39,71 @@ def snapshot_download( ) -> str: 
"""Download all files of a repo. - Downloads a whole snapshot of a repo's files at the specified revision. This - is useful when you want all files from a repo, because you don't know which - ones you will need a priori. All files are nested inside a folder in order - to keep their actual filename relative to that folder. - - An alternative would be to just clone a repo but this would require that the - user always has git and git-lfs installed, and properly configured. - - Args: - repo_id (`str`): - A user or an organization name and a repo name separated by a `/`. - revision (`str`, *optional*): - An optional Git revision id which can be a branch name, a tag, or a - commit hash. - cache_dir (`str`, `Path`, *optional*): - Path to the folder where cached files are stored. - library_name (`str`, *optional*): - The name of the library to which the object corresponds. - library_version (`str`, *optional*): - The version of the library. - user_agent (`str`, `dict`, *optional*): - The user-agent info in the form of a dictionary or a string. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. - etag_timeout (`float`, *optional*, defaults to `10`): - When fetching ETag, how many seconds to wait for the server to send - data before giving up which is passed to `requests.request`. - resume_download (`bool`, *optional*, defaults to `False): - If `True`, resume a previously interrupted download. - use_auth_token (`str`, `bool`, *optional*): - A token to be used for the download. - - If `True`, the token is read from the HuggingFace config - folder. - - If a string, it's used as the authentication token. - local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, avoid downloading the file and return the path to the - local cached file if it exists. - allow_regex (`list of str`, `str`, *optional*): - If provided, only files matching this regex are downloaded. - ignore_regex (`list of str`, `str`, *optional*): - If provided, files matching this regex are not downloaded. - - Returns: - Local folder path (string) of repo snapshot - - - - Raises the following errors: - - - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if `use_auth_token=True` and the token cannot be found. - - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if - ETag cannot be determined. - - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if some parameter value is invalid - - + Downloads a whole snapshot of a repo's files at the specified revision. This + is useful when you want all files from a repo, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a repo but this would require that the + user always has git and git-lfs installed, and properly configured. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + <<<<<<< HEAD + cache_dir (`str`, `Path`, *optional*): + ======= + repo_type: Set to :obj:`"dataset"` or :obj:`"space"` if downloading + a dataset or space, :obj:`None` or :obj:`"model"` if + downloading a model. Default is :obj:`None`. 
+ cache_dir (``str``, ``Path``, `optional`): + >>>>>>> b748712 (one more) + Path to the folder where cached files are stored. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`str`, `dict`, *optional*): + The user-agent info in the form of a dictionary or a string. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False): + If `True`, resume a previously interrupted download. + use_auth_token (`str`, `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. + - If a string, it's used as the authentication token. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_regex (`list of str`, `str`, *optional*): + If provided, only files matching this regex are downloaded. + ignore_regex (`list of str`, `str`, *optional*): + If provided, files matching this regex are not downloaded. + + Returns: + Local folder path (string) of repo snapshot + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + """ if cache_dir is None: @@ -118,6 +125,11 @@ def snapshot_download( else: token = None + if repo_type is None: + repo_type = "model" + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + # remove all `/` occurrences to correctly convert repo to directory name repo_id_flattened = repo_id.replace("/", REPO_ID_SEPARATOR) @@ -152,7 +164,7 @@ def snapshot_download( if len(repo_folders) == 0: raise ValueError( "Cannot find the requested files in the cached path and outgoing" - " traffic has been disabled. To enable model look-ups and downloads" + " traffic has been disabled. To enable repo look-ups and downloads" " online, set 'local_files_only' to False." 
) From b8376f847c6b4cff17cef72321465449b6c255a1 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:46:39 +0100 Subject: [PATCH 07/47] forgot a repo_type and reorder code (cherry picked from commit 3ef7d79a44087e971e10e35d3b9f5bea3474f297) --- src/huggingface_hub/_snapshot_download.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 66601deebc..5d42583782 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -233,7 +233,6 @@ def snapshot_download( ): continue - url = hf_hub_url(repo_id, filename=repo_file, revision=repo_id_sha) relative_filepath = os.path.join(*repo_file.split("/")) # Create potential nested dir @@ -242,6 +241,10 @@ def snapshot_download( ) os.makedirs(nested_dirname, exist_ok=True) + url = hf_hub_url( + repo_id, filename=repo_file, repo_type=repo_type, revision=repo_id_sha + ) + path = cached_download( url, cache_dir=storage_folder, From 4cb1d631bff500383fe850ad9ebd233426c0e5a4 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:46:55 +0100 Subject: [PATCH 08/47] also rename this cache folder (cherry picked from commit 4c518b861723a6d28d59108403c37edf5208f2fe) --- src/huggingface_hub/_snapshot_download.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 5d42583782..b42de41768 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -131,7 +131,9 @@ def snapshot_download( raise ValueError("Invalid repo type") # remove all `/` occurrences to correctly convert repo to directory name - repo_id_flattened = repo_id.replace("/", REPO_ID_SEPARATOR) + repo_id_flattened = f"{repo_type}s{REPO_ID_SEPARATOR}" + repo_id.replace( + "/", REPO_ID_SEPARATOR + ) # if we have no internet connection we will look for the # last modified folder in the cache From f7cbe1831960d35a0e65632d960cb35e457ec115 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 18:49:04 +0100 Subject: [PATCH 09/47] Use `hf_hub_download`, will be simpler later (cherry picked from commit c7478d58fe62da02625b8ca17796ad1419a048b1) --- src/huggingface_hub/_snapshot_download.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index b42de41768..b0776beeaf 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES -from .file_download import cached_download, hf_hub_url +from .file_download import cached_download, hf_hub_download, hf_hub_url from .hf_api import HfApi, HfFolder from .utils import logging from .utils._deprecation import _deprecate_positional_args @@ -207,14 +207,14 @@ def snapshot_download( repo_id=repo_id, repo_type=repo_type, revision=revision, token=token ) - storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision) + storage_folder = os.path.join(cache_dir, repo_id_flattened) # if passed revision is not identical to the commit sha # then revision has to be a branch name, e.g. 
"main" # in this case make sure that the branch name is included # cached storage folder name - if revision != repo_info.sha: - storage_folder += f".{repo_info.sha}" + # if revision != repo_info.sha: + # storage_folder += f".{repo_info.sha}" repo_id_sha = repo_info.sha repo_files = [f.rfilename for f in repo_info.siblings] @@ -243,12 +243,11 @@ def snapshot_download( ) os.makedirs(nested_dirname, exist_ok=True) - url = hf_hub_url( - repo_id, filename=repo_file, repo_type=repo_type, revision=repo_id_sha - ) - - path = cached_download( - url, + path = hf_hub_download( + repo_id, + filename=repo_file, + repo_type=repo_type, + revision=repo_id_sha, cache_dir=storage_folder, force_filename=relative_filepath, library_name=library_name, From ea94d43dc91efc500b3cd66a7ec3172e4cbb0dc7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 21:11:16 +0100 Subject: [PATCH 10/47] in this new version, `force_filename` does not make sense anymore (cherry picked from commit 9a674bc795d5c8a26aecf5429d391fff92e47e8d) --- src/huggingface_hub/file_download.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 1720036084..c296ee4a1d 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -449,7 +449,6 @@ def cached_download( cache_dir: Union[str, Path, None] = None, user_agent: Union[Dict, str, None] = None, force_download: Optional[bool] = False, - force_filename: Optional[str] = None, proxies: Optional[Dict] = None, etag_timeout: Optional[float] = 10, resume_download: Optional[bool] = False, @@ -478,8 +477,6 @@ def cached_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. - force_filename (`str`, *optional*): - Use this name instead of a generated file name. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. 
@@ -579,9 +576,7 @@ def cached_download( # etag is None pass - filename = ( - force_filename if force_filename is not None else url_to_filename(url, etag) - ) + filename = url_to_filename(url, etag) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) @@ -599,11 +594,7 @@ def cached_download( ) if not file.endswith(".json") and not file.endswith(".lock") ] - if ( - len(matching_files) > 0 - and not force_download - and force_filename is None - ): + if len(matching_files) > 0 and not force_download: return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, @@ -679,12 +670,11 @@ def _resumable_file_manager() -> "io.BufferedWriter": logger.info("storing %s in cache at %s", url, cache_path) os.replace(temp_file.name, cache_path) - if force_filename is None: - logger.info("creating metadata file for %s", cache_path) - meta = {"url": url, "etag": etag} - meta_path = cache_path + ".json" - with open(meta_path, "w") as meta_file: - json.dump(meta, meta_file) + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) return cache_path From 7b92719a05983f61a9797941b65c2cd4dc4cc99b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 21:12:58 +0100 Subject: [PATCH 11/47] Just inline everything inside `hf_hub_download` for now (cherry picked from commit ee49f8f57ba4e7e66f237df8f64c804862fe3ee8) --- src/huggingface_hub/file_download.py | 177 ++++++++++++++++++++++++--- 1 file changed, 163 insertions(+), 14 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index c296ee4a1d..91912eda7c 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -765,17 +765,166 @@ def hf_hub_download( repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision ) - return cached_download( - url, - library_name=library_name, - library_version=library_version, - cache_dir=cache_dir, - user_agent=user_agent, - force_download=force_download, - force_filename=force_filename, - proxies=proxies, - etag_timeout=etag_timeout, - resume_download=resume_download, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - ) + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + os.makedirs(cache_dir, exist_ok=True) + + headers = { + "user-agent": http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + } + if isinstance(use_auth_token, str): + headers["authorization"] = f"Bearer {use_auth_token}" + elif use_auth_token: + token = HfFolder.get_token() + if token is None: + raise EnvironmentError( + "You specified use_auth_token=True, but a huggingface token was not found." + ) + headers["authorization"] = f"Bearer {token}" + + url_to_download = url + etag = None + if not local_files_only: + try: + r = _request_with_retry( + method="HEAD", + url=url, + headers=headers, + allow_redirects=False, + proxies=proxies, + timeout=etag_timeout, + ) + r.raise_for_status() + etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") + # We favor a custom header indicating the etag of the linked resource, and + # we fallback to the regular etag header. + # If we don't have any of those, raise an error. 
+ if etag is None: + raise OSError( + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + ) + # In case of a redirect, + # save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + if 300 <= r.status_code <= 399: + url_to_download = r.headers["Location"] + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + OfflineModeIsEnabled, + ): + # Otherwise, our Internet connection is down. + # etag is None + pass + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # etag is None == we don't have a connection or we passed local_files_only. + # try to get the last downloaded one + if etag is None: + if os.path.exists(cache_path) and not force_download: + return cache_path + else: + matching_files = [ + file + for file in fnmatch.filter( + os.listdir(cache_dir), filename.split(".")[0] + ".*" + ) + if not file.endswith(".json") and not file.endswith(".lock") + ] + if len(matching_files) > 0 and not force_download: + return os.path.join(cache_dir, matching_files[-1]) + else: + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the cached path and outgoing traffic has been" + " disabled. To enable model look-ups and downloads online, set 'local_files_only'" + " to False." + ) + else: + raise ValueError( + "Connection error, and we cannot find the requested files in the cached path." + " Please try again or make sure your Internet connection is on." + ) + + # From now on, etag is not None. + if os.path.exists(cache_path) and not force_download: + return cache_path + + # Prevent parallel downloads of the same file with a lock. + lock_path = cache_path + ".lock" + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it is an extended path by using the "\\?\" prefix. + if os.name == "nt" and len(os.path.abspath(lock_path)) > 255: + lock_path = "\\\\?\\" + os.path.abspath(lock_path) + + if os.name == "nt" and len(os.path.abspath(cache_path)) > 255: + cache_path = "\\\\?\\" + os.path.abspath(cache_path) + + with FileLock(lock_path): + + # If the download just completed while the lock was activated. + if os.path.exists(cache_path) and not force_download: + # Even if returning early like here, the lock will be released. + return cache_path + + if resume_download: + incomplete_path = cache_path + ".incomplete" + + @contextmanager + def _resumable_file_manager() -> "io.BufferedWriter": + with open(incomplete_path, "ab") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
+ with temp_file_manager() as temp_file: + logger.info("downloading %s to %s", url, temp_file.name) + + http_get( + url_to_download, + temp_file, + proxies=proxies, + resume_size=resume_size, + headers=headers, + ) + + logger.info("storing %s in cache at %s", url, cache_path) + os.replace(temp_file.name, cache_path) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) + + return cache_path From d29c90ef0c18e60061e0777f2f5ea4cf69a8c530 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 25 Mar 2022 22:21:11 +0100 Subject: [PATCH 12/47] Big prototype! it works! :tada: (cherry picked from commit 7fe19ec66a2c5a7386a956cb9b65616cb209608a) --- example.py | 21 ++++ src/huggingface_hub/_snapshot_download.py | 26 ++--- src/huggingface_hub/constants.py | 7 ++ src/huggingface_hub/file_download.py | 125 +++++++++++++++------- 4 files changed, 122 insertions(+), 57 deletions(-) create mode 100644 example.py diff --git a/example.py b/example.py new file mode 100644 index 0000000000..6dcdbd9ae4 --- /dev/null +++ b/example.py @@ -0,0 +1,21 @@ +import torch +from huggingface_hub.file_download import hf_hub_download + +OLDER_REVISION = "bbc77c8132af1cc5cf678da3f1ddf2de43606d48" + +hf_hub_download("julien-c/EsperBERTo-small", filename="README.md") + +hf_hub_download("julien-c/EsperBERTo-small", filename="pytorch_model.bin") + +hf_hub_download( + "julien-c/EsperBERTo-small", filename="README.md", revision=OLDER_REVISION +) + +weights_file = hf_hub_download( + "julien-c/EsperBERTo-small", filename="pytorch_model.bin", revision=OLDER_REVISION +) + +w = torch.load(weights_file, map_location=torch.device("cpu")) +### Yay it works! + +print() diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index b0776beeaf..610abea88a 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -4,18 +4,18 @@ from pathlib import Path from typing import Dict, List, Optional, Union -from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES +from .constants import ( + DEFAULT_REVISION, + HUGGINGFACE_HUB_CACHE, + REPO_ID_SEPARATOR, + REPO_TYPES, +) from .file_download import cached_download, hf_hub_download, hf_hub_url from .hf_api import HfApi, HfFolder from .utils import logging from .utils._deprecation import _deprecate_positional_args -REPO_ID_SEPARATOR = "--" -# ^ this substring is not allowed in repo_ids on hf.co -# and is the canonical one we use for serialization of repo ids elsewhere. 
- - logger = logging.get_logger(__name__) @@ -235,21 +235,12 @@ def snapshot_download( ): continue - relative_filepath = os.path.join(*repo_file.split("/")) - - # Create potential nested dir - nested_dirname = os.path.dirname( - os.path.join(storage_folder, relative_filepath) - ) - os.makedirs(nested_dirname, exist_ok=True) - - path = hf_hub_download( + _ = hf_hub_download( repo_id, filename=repo_file, repo_type=repo_type, revision=repo_id_sha, cache_dir=storage_folder, - force_filename=relative_filepath, library_name=library_name, library_version=library_version, user_agent=user_agent, @@ -260,7 +251,4 @@ def snapshot_download( local_files_only=local_files_only, ) - if os.path.exists(path + ".lock"): - os.remove(path + ".lock") - return storage_folder diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 91f930d853..728f45ac54 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -29,6 +29,13 @@ ) HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" +HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" +HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" + +REPO_ID_SEPARATOR = "--" +# ^ this substring is not allowed in repo_ids on hf.co +# and is the canonical one we use for serialization of repo ids elsewhere. + REPO_TYPE_DATASET = "dataset" REPO_TYPE_SPACE = "space" diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 91912eda7c..bd7670401d 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -23,7 +23,10 @@ from .constants import ( DEFAULT_REVISION, HUGGINGFACE_CO_URL_TEMPLATE, + HUGGINGFACE_HEADER_X_LINKED_ETAG, + HUGGINGFACE_HEADER_X_REPO_COMMIT, HUGGINGFACE_HUB_CACHE, + REPO_ID_SEPARATOR, REPO_TYPES, REPO_TYPES_URL_PREFIXES, ) @@ -679,6 +682,10 @@ def _resumable_file_manager() -> "io.BufferedWriter": return cache_path +def normalize_etag(etag: str) -> str: + return etag.strip('"') + + @_deprecate_positional_args def hf_hub_download( repo_id: str, @@ -692,7 +699,6 @@ def hf_hub_download( cache_dir: Union[str, Path, None] = None, user_agent: Union[Dict, str, None] = None, force_download: Optional[bool] = False, - force_filename: Optional[str] = None, proxies: Optional[Dict] = None, etag_timeout: Optional[float] = 10, resume_download: Optional[bool] = False, @@ -725,8 +731,6 @@ def hf_hub_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. - force_filename (`str`, *optional*): - Use this name instead of a generated file name. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. 
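A note on the `normalize_etag` helper added above: blob files in the new layout are named after the bare ETag, and hf.co guarantees the strong form `ETag: "<hash>"`, so stripping the surrounding quotes yields a filesystem-safe name. A quick illustration (the hash value is just an example):

    def normalize_etag(etag: str) -> str:
        # strip the quotes of the strong-form ETag header
        return etag.strip('"')

    assert (
        normalize_etag('"d7edf6bd2a681fb0175f7735299831ee1b22b812"')
        == "d7edf6bd2a681fb0175f7735299831ee1b22b812"
    )
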
@@ -761,16 +765,29 @@ def hf_hub_download( """ - url = hf_hub_url( - repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision - ) - if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE + if revision is None: + revision = DEFAULT_REVISION if isinstance(cache_dir, Path): cache_dir = str(cache_dir) - os.makedirs(cache_dir, exist_ok=True) + if repo_type is None: + repo_type = "model" + if repo_type not in REPO_TYPES: + raise ValueError("Invalid repo type") + + # remove all `/` occurrences to correctly convert repo to directory name + repo_id_flattened = f"{repo_type}s{REPO_ID_SEPARATOR}" + repo_id.replace( + "/", REPO_ID_SEPARATOR + ) + storage_folder = os.path.join(cache_dir, repo_id_flattened) + + os.makedirs(storage_folder, exist_ok=True) + + url = hf_hub_url( + repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision + ) headers = { "user-agent": http_user_agent( @@ -785,12 +802,14 @@ def hf_hub_download( token = HfFolder.get_token() if token is None: raise EnvironmentError( - "You specified use_auth_token=True, but a huggingface token was not found." + "You specified use_auth_token=True, but a huggingface token was not" + " found." ) headers["authorization"] = f"Bearer {token}" url_to_download = url etag = None + commit_hash = None if not local_files_only: try: r = _request_with_retry( @@ -802,14 +821,24 @@ def hf_hub_download( timeout=etag_timeout, ) r.raise_for_status() - etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag") + commit_hash = r.headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] + if commit_hash is None: + raise OSError( + "Distant resource does not seem to be the huggingface hub (missing" + " commit header)." + ) + etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get( + "ETag" + ) # We favor a custom header indicating the etag of the linked resource, and # we fallback to the regular etag header. # If we don't have any of those, raise an error. if etag is None: raise OSError( - "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." + "Distant resource does not have an ETag, we won't be able to" + " reliably ensure reproducibility." ) + etag = normalize_etag(etag) # In case of a redirect, # save an extra redirect on the request.get call, # and ensure we download the exact atomic version even if it changed @@ -828,11 +857,6 @@ def hf_hub_download( # etag is None pass - filename = url_to_filename(url, etag) - - # get cache path to put the file - cache_path = os.path.join(cache_dir, filename) - # etag is None == we don't have a connection or we passed local_files_only. # try to get the last downloaded one if etag is None: @@ -854,40 +878,62 @@ def hf_hub_download( # Notify the user about that if local_files_only: raise ValueError( - "Cannot find the requested files in the cached path and outgoing traffic has been" - " disabled. To enable model look-ups and downloads online, set 'local_files_only'" - " to False." + "Cannot find the requested files in the cached path and" + " outgoing traffic has been disabled. To enable model look-ups" + " and downloads online, set 'local_files_only' to False." ) else: raise ValueError( - "Connection error, and we cannot find the requested files in the cached path." - " Please try again or make sure your Internet connection is on." + "Connection error, and we cannot find the requested files in" + " the cached path. Please try again or make sure your Internet" + " connection is on." ) - # From now on, etag is not None. 
- if os.path.exists(cache_path) and not force_download: - return cache_path + # From now on, etag and commit_hash are not None. + blob_path = os.path.join(storage_folder, "blobs", etag) + pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, filename) + + os.makedirs(os.path.dirname(blob_path), exist_ok=True) + os.makedirs(os.path.dirname(pointer_path), exist_ok=True) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. + if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) + + if os.path.exists(pointer_path) and not force_download: + return pointer_path + + if os.path.exists(blob_path) and not force_download: + # we have the blob already, but not the pointer + logger.info("creating pointer to %s from %s", blob_path, pointer_path) + os.symlink(blob_path, pointer_path) + # TODO(should we try to do relative instead of absolute?) + return pointer_path # Prevent parallel downloads of the same file with a lock. - lock_path = cache_path + ".lock" + lock_path = blob_path + ".lock" # Some Windows versions do not allow for paths longer than 255 characters. # In this case, we must specify it is an extended path by using the "\\?\" prefix. if os.name == "nt" and len(os.path.abspath(lock_path)) > 255: lock_path = "\\\\?\\" + os.path.abspath(lock_path) - if os.name == "nt" and len(os.path.abspath(cache_path)) > 255: - cache_path = "\\\\?\\" + os.path.abspath(cache_path) + if os.name == "nt" and len(os.path.abspath(blob_path)) > 255: + blob_path = "\\\\?\\" + os.path.abspath(blob_path) with FileLock(lock_path): # If the download just completed while the lock was activated. - if os.path.exists(cache_path) and not force_download: + if os.path.exists(pointer_path) and not force_download: # Even if returning early like here, the lock will be released. - return cache_path + return pointer_path if resume_download: - incomplete_path = cache_path + ".incomplete" + incomplete_path = blob_path + ".incomplete" @contextmanager def _resumable_file_manager() -> "io.BufferedWriter": @@ -918,13 +964,16 @@ def _resumable_file_manager() -> "io.BufferedWriter": headers=headers, ) - logger.info("storing %s in cache at %s", url, cache_path) - os.replace(temp_file.name, cache_path) + logger.info("storing %s in cache at %s", url, blob_path) + os.replace(temp_file.name, blob_path) - logger.info("creating metadata file for %s", cache_path) - meta = {"url": url, "etag": etag} - meta_path = cache_path + ".json" - with open(meta_path, "w") as meta_file: - json.dump(meta, meta_file) + logger.info("creating pointer to %s from %s", blob_path, pointer_path) + os.symlink(blob_path, pointer_path) + # TODO(should we try to do relative instead of absolute?) - return cache_path + try: + os.remove(lock_path) + except OSError: + pass + + return pointer_path From 02480fd027a059ab55a04ac797a183b2f41f6bf2 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 4 May 2022 17:09:36 +0200 Subject: [PATCH 13/47] wip wip --- example.py | 2 +- setup.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/example.py b/example.py index 6dcdbd9ae4..77cd6c34ca 100644 --- a/example.py +++ b/example.py @@ -16,6 +16,6 @@ ) w = torch.load(weights_file, map_location=torch.device("cpu")) -### Yay it works! +# Yay it works! 
just loaded a torch file from a symlink print() diff --git a/setup.py b/setup.py index 748f39d14b..0ba23a6bd9 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,8 @@ def get_version() -> str: author="Hugging Face, Inc.", author_email="julien@huggingface.co", description=( - "Client library to download and publish models on the huggingface.co hub" + "Client library to download and publish models, datasets and other repos on the" + " huggingface.co hub" ), long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", From f81885d55d718a0e8acb5435cf95e7a10f4ef45d Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 4 May 2022 17:17:44 +0200 Subject: [PATCH 14/47] do not touch `cached_download` --- src/huggingface_hub/file_download.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index bd7670401d..e1808345cd 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -452,6 +452,7 @@ def cached_download( cache_dir: Union[str, Path, None] = None, user_agent: Union[Dict, str, None] = None, force_download: Optional[bool] = False, + force_filename: Optional[str] = None, proxies: Optional[Dict] = None, etag_timeout: Optional[float] = 10, resume_download: Optional[bool] = False, @@ -480,6 +481,8 @@ def cached_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. + force_filename (`str`, *optional*): + Use this name instead of a generated file name. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. 
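For context while reading this hunk: `cached_download` keeps the legacy flat layout, in which cache entries are opaque files named from the URL and ETag by `url_to_filename`. Roughly, as a simplified sketch (not the exact implementation):

    from hashlib import sha256

    def url_to_filename(url: str, etag: str = None) -> str:
        # legacy flat cache: one opaque file per (url, etag) pair,
        # plus a sidecar .json metadata file next to it
        filename = sha256(url.encode("utf-8")).hexdigest()
        if etag:
            filename += "." + sha256(etag.encode("utf-8")).hexdigest()
        return filename

This is why `force_filename` still makes sense here: the generated names are unreadable, and the old `snapshot_download` relied on overriding them with real filenames.
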
@@ -579,7 +582,9 @@ def cached_download( # etag is None pass - filename = url_to_filename(url, etag) + filename = ( + force_filename if force_filename is not None else url_to_filename(url, etag) + ) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) @@ -597,7 +602,11 @@ def cached_download( ) if not file.endswith(".json") and not file.endswith(".lock") ] - if len(matching_files) > 0 and not force_download: + if ( + len(matching_files) > 0 + and not force_download + and force_filename is None + ): return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, @@ -673,11 +682,12 @@ def _resumable_file_manager() -> "io.BufferedWriter": logger.info("storing %s in cache at %s", url, cache_path) os.replace(temp_file.name, cache_path) - logger.info("creating metadata file for %s", cache_path) - meta = {"url": url, "etag": etag} - meta_path = cache_path + ".json" - with open(meta_path, "w") as meta_file: - json.dump(meta, meta_file) + if force_filename is None: + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + json.dump(meta, meta_file) return cache_path From 7fbd7a4523f7c958ac2607f0992bef6fc902f679 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 4 May 2022 17:21:58 +0200 Subject: [PATCH 15/47] Prompt user to upgrade to `hf_hub_download` --- src/huggingface_hub/file_download.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index e1808345cd..cbb853c2c8 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -11,6 +11,7 @@ from hashlib import sha256 from pathlib import Path from typing import BinaryIO, Dict, Optional, Tuple, Union +import warnings import packaging.version from tqdm.auto import tqdm @@ -517,6 +518,11 @@ def cached_download( """ + warnings.warn( + "`cached_download` is the legacy way to download files from the HF hub, please" + " consider upgrading to `hf_hub_download`", + FutureWarning, + ) if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if isinstance(cache_dir, Path): From d1e47a109f3b5ec8c67c3a1b60491728dccc8105 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 4 May 2022 17:29:10 +0200 Subject: [PATCH 16/47] Add a `legacy_cache_layout=True` to preserve old behavior, just in case --- src/huggingface_hub/file_download.py | 30 +++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index cbb853c2c8..f86f0724ef 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -6,12 +6,12 @@ import sys import tempfile import time +import warnings from contextlib import contextmanager from functools import partial from hashlib import sha256 from pathlib import Path from typing import BinaryIO, Dict, Optional, Tuple, Union -import warnings import packaging.version from tqdm.auto import tqdm @@ -720,6 +720,7 @@ def hf_hub_download( resume_download: Optional[bool] = False, use_auth_token: Union[bool, str, None] = None, local_files_only: Optional[bool] = False, + legacy_cache_layout: Optional[bool] = False, ): """Download a given file if it's not already present in the local cache. 
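The new flag gives callers an explicit opt-out of the new layout; a minimal usage sketch (repo and filename borrowed from example.py earlier in this series):

    from huggingface_hub.file_download import hf_hub_download

    # Deprecated escape hatch: behaves like hf_hub_url + cached_download,
    # producing the old flat cache instead of blobs/snapshots/refs.
    path = hf_hub_download(
        "julien-c/EsperBERTo-small",
        filename="README.md",
        legacy_cache_layout=True,
    )
    print(path)
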
@@ -763,6 +764,10 @@ def hf_hub_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url` + then `cached_download`. This is deprecated as the new cache layout is + more powerful. Returns: Local path (string) of file or if networking is off, last version of @@ -781,6 +786,29 @@ def hf_hub_download( """ + if legacy_cache_layout: + url = hf_hub_url( + repo_id, + filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + ) + + return cached_download( + url, + library_name=library_name, + library_version=library_version, + cache_dir=cache_dir, + user_agent=user_agent, + force_download=force_download, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + ) + if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if revision is None: From fc930b6226e843c8168fc29d3bd7e10e8a041263 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 14:31:08 +0200 Subject: [PATCH 17/47] Create `relative symlinks` + add some doc --- src/huggingface_hub/file_download.py | 49 ++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index f86f0724ef..4cee7d3c12 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -264,6 +264,10 @@ def filename_to_url(filename, cache_dir=None) -> Tuple[str, str]: Return the url and etag (which may be `None`) stored for `filename`. Raise `EnvironmentError` if `filename` or its stored metadata do not exist. """ + warnings.warn( + "`filename_to_url` uses the legacy way cache file layout", + FutureWarning, + ) if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if isinstance(cache_dir, Path): @@ -698,10 +702,43 @@ def _resumable_file_manager() -> "io.BufferedWriter": return cache_path -def normalize_etag(etag: str) -> str: +def _normalize_etag(etag: str) -> str: + """Normalize ETag HTTP header, so it can be used to create nice filepaths. + + The HTTP spec allows two forms of ETag: + ETag: W/"" + ETag: "" + + The hf.co hub guarantees to only send the second form. + + Args: + etag (str): HTTP header + + Returns: + str: string that can be used as a nice directory name. + """ return etag.strip('"') +def _create_relative_symlink(src: str, dst: str) -> None: + """Create a symbolic link named dst pointing to src as a relative path to dst. + + The relative part is mostly because it seems more elegant to the author. 
+ + The result layout looks something like + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + """ + relative_src = os.path.relpath(src, start=os.path.dirname(dst)) + try: + os.remove(dst) + except OSError: + pass + os.symlink(relative_src, dst) + + @_deprecate_positional_args def hf_hub_download( repo_id: str, @@ -868,7 +905,7 @@ def hf_hub_download( commit_hash = r.headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] if commit_hash is None: raise OSError( - "Distant resource does not seem to be the huggingface hub (missing" + "Distant resource does not seem to be on huggingface.co (missing" " commit header)." ) etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get( @@ -882,7 +919,7 @@ def hf_hub_download( "Distant resource does not have an ETag, we won't be able to" " reliably ensure reproducibility." ) - etag = normalize_etag(etag) + etag = _normalize_etag(etag) # In case of a redirect, # save an extra redirect on the request.get call, # and ensure we download the exact atomic version even if it changed @@ -954,8 +991,7 @@ def hf_hub_download( if os.path.exists(blob_path) and not force_download: # we have the blob already, but not the pointer logger.info("creating pointer to %s from %s", blob_path, pointer_path) - os.symlink(blob_path, pointer_path) - # TODO(should we try to do relative instead of absolute?) + _create_relative_symlink(blob_path, pointer_path) return pointer_path # Prevent parallel downloads of the same file with a lock. @@ -1012,8 +1048,7 @@ def _resumable_file_manager() -> "io.BufferedWriter": os.replace(temp_file.name, blob_path) logger.info("creating pointer to %s from %s", blob_path, pointer_path) - os.symlink(blob_path, pointer_path) - # TODO(should we try to do relative instead of absolute?) + _create_relative_symlink(blob_path, pointer_path) try: os.remove(lock_path) From f5420ed91b9afb165e6bf14a62122a9488705d9c Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 14:51:00 +0200 Subject: [PATCH 18/47] Fix behavior when no network --- src/huggingface_hub/file_download.py | 71 +++++++++++++++++----------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 4cee7d3c12..fe4daa3f5b 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -3,6 +3,7 @@ import io import json import os +import re import sys import tempfile import time @@ -146,6 +147,9 @@ def get_fastcore_version(): return _fastcore_version +REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{5,40}$") + + @_deprecate_positional_args def hf_hub_url( repo_id: str, @@ -939,36 +943,47 @@ def hf_hub_download( pass # etag is None == we don't have a connection or we passed local_files_only. - # try to get the last downloaded one + # try to get the last downloaded one from the specified revision. + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". if etag is None: - if os.path.exists(cache_path) and not force_download: - return cache_path + # In those cases, we cannot force download. + if force_download: + raise ValueError( + "We have no connection or you passed local_files_only, so" + " force_download is not an accepted option." 
+ ) + if REGEX_COMMIT_HASH.match(revision): + pointer_path = os.path.join(storage_folder, "snapshots", revision, filename) + if os.path.exists(pointer_path): + return pointer_path else: - matching_files = [ - file - for file in fnmatch.filter( - os.listdir(cache_dir), filename.split(".")[0] + ".*" - ) - if not file.endswith(".json") and not file.endswith(".lock") - ] - if len(matching_files) > 0 and not force_download: - return os.path.join(cache_dir, matching_files[-1]) - else: - # If files cannot be found and local_files_only=True, - # the models might've been found if local_files_only=False - # Notify the user about that - if local_files_only: - raise ValueError( - "Cannot find the requested files in the cached path and" - " outgoing traffic has been disabled. To enable model look-ups" - " and downloads online, set 'local_files_only' to False." - ) - else: - raise ValueError( - "Connection error, and we cannot find the requested files in" - " the cached path. Please try again or make sure your Internet" - " connection is on." - ) + ref_path = os.path.join(storage_folder, "refs", revision) + with open(ref_path) as f: + commit_hash = f.read() + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash, filename + ) + if os.path.exists(pointer_path): + return pointer_path + + # If we couldn't find an appropriate file on disk, + # raise an error. + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the disk cache and" + " outgoing traffic has been disabled. To enable hf.co look-ups" + " and downloads online, set 'local_files_only' to False." + ) + else: + raise ValueError( + "Connection error, and we cannot find the requested files in" + " the disk cache. Please try again or make sure your Internet" + " connection is on." + ) # From now on, etag and commit_hash are not None. 
blob_path = os.path.join(storage_folder, "blobs", etag) From 9f1e0f6f01728cc46f7e3876dbc45e7792e8d919 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 14:59:35 +0200 Subject: [PATCH 19/47] This test now is legacy --- tests/test_file_download.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 05499e380a..98fd055d56 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -154,12 +154,13 @@ def test_dataset_lfs_object(self): (url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'), ) - def test_hf_hub_download(self): + def test_hf_hub_download_legacy(self): filepath = hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT, force_download=True, + legacy_cache_layout=True, ) metadata = filename_to_url(filepath) self.assertEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') From 973913e093114bafaa5f369ed9e7129dc22ddcee Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 15:08:07 +0200 Subject: [PATCH 20/47] Fix-ish conflict-ish --- src/huggingface_hub/_snapshot_download.py | 126 +++++++++++----------- 1 file changed, 61 insertions(+), 65 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 610abea88a..5fd4342192 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -39,71 +39,67 @@ def snapshot_download( ) -> str: """Download all files of a repo. - Downloads a whole snapshot of a repo's files at the specified revision. This - is useful when you want all files from a repo, because you don't know which - ones you will need a priori. All files are nested inside a folder in order - to keep their actual filename relative to that folder. - - An alternative would be to just clone a repo but this would require that the - user always has git and git-lfs installed, and properly configured. - - Args: - repo_id (`str`): - A user or an organization name and a repo name separated by a `/`. - revision (`str`, *optional*): - An optional Git revision id which can be a branch name, a tag, or a - commit hash. - <<<<<<< HEAD - cache_dir (`str`, `Path`, *optional*): - ======= - repo_type: Set to :obj:`"dataset"` or :obj:`"space"` if downloading - a dataset or space, :obj:`None` or :obj:`"model"` if - downloading a model. Default is :obj:`None`. - cache_dir (``str``, ``Path``, `optional`): - >>>>>>> b748712 (one more) - Path to the folder where cached files are stored. - library_name (`str`, *optional*): - The name of the library to which the object corresponds. - library_version (`str`, *optional*): - The version of the library. - user_agent (`str`, `dict`, *optional*): - The user-agent info in the form of a dictionary or a string. - proxies (`dict`, *optional*): - Dictionary mapping protocol to the URL of the proxy passed to - `requests.request`. - etag_timeout (`float`, *optional*, defaults to `10`): - When fetching ETag, how many seconds to wait for the server to send - data before giving up which is passed to `requests.request`. - resume_download (`bool`, *optional*, defaults to `False): - If `True`, resume a previously interrupted download. - use_auth_token (`str`, `bool`, *optional*): - A token to be used for the download. - - If `True`, the token is read from the HuggingFace config - folder. - - If a string, it's used as the authentication token. 
- local_files_only (`bool`, *optional*, defaults to `False`): - If `True`, avoid downloading the file and return the path to the - local cached file if it exists. - allow_regex (`list of str`, `str`, *optional*): - If provided, only files matching this regex are downloaded. - ignore_regex (`list of str`, `str`, *optional*): - If provided, files matching this regex are not downloaded. - - Returns: - Local folder path (string) of repo snapshot - - - - Raises the following errors: - - - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if `use_auth_token=True` and the token cannot be found. - - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if - ETag cannot be determined. - - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if some parameter value is invalid - - + Downloads a whole snapshot of a repo's files at the specified revision. This + is useful when you want all files from a repo, because you don't know which + ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. + + An alternative would be to just clone a repo but this would require that the + user always has git and git-lfs installed, and properly configured. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or space, + `None` or `"model"` if uploading to a model. Default is `None`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + library_name (`str`, *optional*): + The name of the library to which the object corresponds. + library_version (`str`, *optional*): + The version of the library. + user_agent (`str`, `dict`, *optional*): + The user-agent info in the form of a dictionary or a string. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False): + If `True`, resume a previously interrupted download. + use_auth_token (`str`, `bool`, *optional*): + A token to be used for the download. + - If `True`, the token is read from the HuggingFace config + folder. + - If a string, it's used as the authentication token. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_regex (`list of str`, `str`, *optional*): + If provided, only files matching this regex are downloaded. + ignore_regex (`list of str`, `str`, *optional*): + If provided, files matching this regex are not downloaded. + + Returns: + Local folder path (string) of repo snapshot + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `use_auth_token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. 
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + """ if cache_dir is None: From e9fb4d40f8458c13b9016ad410d757c5b60346de Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 15:09:14 +0200 Subject: [PATCH 21/47] minimize diff --- src/huggingface_hub/_snapshot_download.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 5fd4342192..82f94ac4c1 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -75,7 +75,7 @@ def snapshot_download( use_auth_token (`str`, `bool`, *optional*): A token to be used for the download. - If `True`, the token is read from the HuggingFace config - folder. + folder. - If a string, it's used as the authentication token. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the @@ -93,11 +93,11 @@ def snapshot_download( Raises the following errors: - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) - if `use_auth_token=True` and the token cannot be found. + if `use_auth_token=True` and the token cannot be found. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if - ETag cannot be determined. + ETag cannot be determined. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) - if some parameter value is invalid + if some parameter value is invalid """ From adaca46d524225dda418daa6f4a1827b00f7e9f6 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 15:32:11 +0200 Subject: [PATCH 22/47] refactor `repo_folder_name` --- src/huggingface_hub/file_download.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index fe4daa3f5b..57915b7415 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -743,6 +743,24 @@ def _create_relative_symlink(src: str, dst: str) -> None: os.symlink(relative_src, dst) +def repo_folder_name( + *, + repo_id: str, + repo_type: str, +) -> str: + """Return a serialized version of a hf.co repo name and type, safe for disk storage + as a single non-nested folder. 
+ + Example: models--julien-c--EsperBERTo-small + """ + # remove all `/` occurrences to correctly convert repo to directory name + parts = [ + f"{repo_type}s", + *repo_id.split("/"), + ] + return REPO_ID_SEPARATOR.join(parts) + + @_deprecate_positional_args def hf_hub_download( repo_id: str, @@ -862,11 +880,9 @@ def hf_hub_download( if repo_type not in REPO_TYPES: raise ValueError("Invalid repo type") - # remove all `/` occurrences to correctly convert repo to directory name - repo_id_flattened = f"{repo_type}s{REPO_ID_SEPARATOR}" + repo_id.replace( - "/", REPO_ID_SEPARATOR + storage_folder = os.path.join( + cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) ) - storage_folder = os.path.join(cache_dir, repo_id_flattened) os.makedirs(storage_folder, exist_ok=True) From fff7e9f7af4cb52ecc148f1286aed93bceba77fc Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 15:53:25 +0200 Subject: [PATCH 23/47] windows support + shortcut if user passes a commit hash --- src/huggingface_hub/file_download.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 57915b7415..5728c9e955 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -883,9 +883,20 @@ def hf_hub_download( storage_folder = os.path.join( cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) ) - os.makedirs(storage_folder, exist_ok=True) + # cross platform transcription of filename, to be used a local file path. + relative_filename = os.path.join(*filename.split("/")) + + # if user provides a commit_hash and they already have the file on disk, + # shortcut everything. + if REGEX_COMMIT_HASH.match(revision): + pointer_path = os.path.join( + storage_folder, "snapshots", revision, relative_filename + ) + if os.path.exists(pointer_path): + return pointer_path + url = hf_hub_url( repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision ) @@ -970,7 +981,9 @@ def hf_hub_download( " force_download is not an accepted option." ) if REGEX_COMMIT_HASH.match(revision): - pointer_path = os.path.join(storage_folder, "snapshots", revision, filename) + pointer_path = os.path.join( + storage_folder, "snapshots", revision, relative_filename + ) if os.path.exists(pointer_path): return pointer_path else: @@ -978,7 +991,7 @@ def hf_hub_download( with open(ref_path) as f: commit_hash = f.read() pointer_path = os.path.join( - storage_folder, "snapshots", commit_hash, filename + storage_folder, "snapshots", commit_hash, relative_filename ) if os.path.exists(pointer_path): return pointer_path @@ -1003,7 +1016,9 @@ def hf_hub_download( # From now on, etag and commit_hash are not None. 
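An aside on what these new paths look like in practice: the flat folder naming and the commit-hash shortcut above can be exercised in isolation. A minimal sketch follows; the `--` separator matches `REPO_ID_SEPARATOR` in `huggingface_hub.constants`, the repo id, filename, and hash are illustrative, and note that at this point in the series `REGEX_COMMIT_HASH` still accepts 5-40 hex characters (patch 31 later pins it to exactly 40, which is the form used here).

```python
import os
import re

REPO_ID_SEPARATOR = "--"  # same value as huggingface_hub.constants.REPO_ID_SEPARATOR
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")  # full-hash form from patch 31


def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
    # "julien-c/EsperBERTo-small" + "model" -> "models--julien-c--EsperBERTo-small"
    return REPO_ID_SEPARATOR.join([f"{repo_type}s", *repo_id.split("/")])


cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
storage_folder = os.path.join(
    cache_dir, repo_folder_name(repo_id="julien-c/EsperBERTo-small", repo_type="model")
)
print(storage_folder)

# The shortcut: a revision that is a commit hash fully determines the pointer
# path, so an already-downloaded file can be returned without a network call.
revision = "bbc77c8132af1cc5cf678da3f1ddf2de43606d48"
filename = "config.json"
if REGEX_COMMIT_HASH.match(revision):
    relative_filename = os.path.join(*filename.split("/"))  # cross-platform
    pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename)
    print(pointer_path)
```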
blob_path = os.path.join(storage_folder, "blobs", etag) - pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, filename) + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash, relative_filename + ) os.makedirs(os.path.dirname(blob_path), exist_ok=True) os.makedirs(os.path.dirname(pointer_path), exist_ok=True) From 06c0ca157cc49695c2d8cffded6ab6fc8a3a601b Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 16:03:45 +0200 Subject: [PATCH 24/47] Rewrite `snapshot_download` and make it more robust --- src/huggingface_hub/_snapshot_download.py | 166 ++++++++-------------- src/huggingface_hub/file_download.py | 2 +- 2 files changed, 63 insertions(+), 105 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 82f94ac4c1..845cc5c44d 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -1,16 +1,10 @@ import os from fnmatch import fnmatch -from glob import glob from pathlib import Path from typing import Dict, List, Optional, Union -from .constants import ( - DEFAULT_REVISION, - HUGGINGFACE_HUB_CACHE, - REPO_ID_SEPARATOR, - REPO_TYPES, -) -from .file_download import cached_download, hf_hub_download, hf_hub_url +from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES +from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name from .hf_api import HfApi, HfFolder from .utils import logging from .utils._deprecation import _deprecate_positional_args @@ -19,6 +13,31 @@ logger = logging.get_logger(__name__) +def _filter_repo_files( + *repo_files: List[str], + allow_regex: Optional[Union[List[str], str]] = None, + ignore_regex: Optional[Union[List[str], str]] = None, +) -> List[str]: + allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex + ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex + filtered_files = [] + for repo_file in repo_files: + # if there's an allowlist, skip download if file does not match any regex + if allow_regex is not None and not any( + fnmatch(repo_file, r) for r in allow_regex + ): + continue + + # if there's a denylist, skip download if file does matches any regex + if ignore_regex is not None and any( + fnmatch(repo_file, r) for r in ignore_regex + ): + continue + + filtered_files.append(repo_file) + return filtered_files + + @_deprecate_positional_args def snapshot_download( repo_id: str, @@ -126,116 +145,56 @@ def snapshot_download( if repo_type not in REPO_TYPES: raise ValueError("Invalid repo type") - # remove all `/` occurrences to correctly convert repo to directory name - repo_id_flattened = f"{repo_type}s{REPO_ID_SEPARATOR}" + repo_id.replace( - "/", REPO_ID_SEPARATOR + storage_folder = os.path.join( + cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) ) - # if we have no internet connection we will look for the - # last modified folder in the cache + # if we have no internet connection we will look for an + # appropriate folder in the cache + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". if local_files_only: - # possible repos have / prefix - repo_folders_prefix = os.path.join(cache_dir, repo_id_flattened) - - # list all possible folders that can correspond to the repo_id - # and are of the format .. - # now let's list all cached repos that have to be included in the revision. 
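A note on the `_filter_repo_files` helper added above: despite the `allow_regex`/`ignore_regex` names, matching is done with `fnmatch`, so the patterns are shell-style globs rather than regular expressions (and, as patch 25 below corrects, `repo_files` is meant to be a keyword argument rather than a star-arg). A standalone sketch of the same logic, with illustrative file names:

```python
from fnmatch import fnmatch
from typing import List, Optional, Union


def filter_repo_files(
    repo_files: List[str],
    allow_regex: Optional[Union[List[str], str]] = None,
    ignore_regex: Optional[Union[List[str], str]] = None,
) -> List[str]:
    # Mirrors `_filter_repo_files`: glob-style allowlist, then denylist.
    allow = [allow_regex] if isinstance(allow_regex, str) else allow_regex
    ignore = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex
    filtered = []
    for f in repo_files:
        if allow is not None and not any(fnmatch(f, a) for a in allow):
            continue  # allowlist present and no pattern matched
        if ignore is not None and any(fnmatch(f, i) for i in ignore):
            continue  # denylist present and a pattern matched
        filtered.append(f)
    return filtered


files = ["config.json", "pytorch_model.bin", "tf_model.h5", "README.md"]
print(filter_repo_files(files, allow_regex="*.json"))            # ['config.json']
print(filter_repo_files(files, ignore_regex=["*.h5", "*.bin"]))  # ['config.json', 'README.md']
```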
- # There are 3 cases that we have to consider. - - # 1) cached repos of format .{revision}. - # -> in this case {revision} has to be a branch - repo_folders_branch = glob(repo_folders_prefix + "." + revision + ".*") - - # 2) cached repos of format .{revision} - # -> in this case {revision} has to be a commit sha - repo_folders_commit_only = glob(repo_folders_prefix + "." + revision) - - # 3) cached repos of format ..{revision} - # -> in this case {revision} also has to be a commit sha - repo_folders_branch_commit = glob(repo_folders_prefix + ".*." + revision) - - # combine all possible fetched cached repos - repo_folders = ( - repo_folders_branch + repo_folders_commit_only + repo_folders_branch_commit + if REGEX_COMMIT_HASH.match(revision): + snapshot_folder = os.path.join(storage_folder, "snapshots", revision) + if os.path.exists(snapshot_folder): + return snapshot_folder + else: + ref_path = os.path.join(storage_folder, "refs", revision) + with open(ref_path) as f: + commit_hash = f.read() + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): + return snapshot_folder + + raise ValueError( + "Cannot find an appropriate cached folder for the specified revision on the" + " local disk and outgoing traffic has been disabled. To enable repo" + " look-ups and downloads online, set 'local_files_only' to False." ) - if len(repo_folders) == 0: - raise ValueError( - "Cannot find the requested files in the cached path and outgoing" - " traffic has been disabled. To enable repo look-ups and downloads" - " online, set 'local_files_only' to False." - ) - - # check if repo id was previously cached from a commit sha revision - # and passed {revision} is not a commit sha - # in this case snapshotting repos locally might lead to unexpected - # behavior the user should be warned about - - # get all folders that were cached with just a sha commit revision - all_repo_folders_from_sha = set(glob(repo_folders_prefix + ".*")) - set( - glob(repo_folders_prefix + ".*.*") - ) - # 1) is there any repo id that was previously cached from a commit sha? - has_a_sha_revision_been_cached = len(all_repo_folders_from_sha) > 0 - # 2) is the passed {revision} is a branch - is_revision_a_branch = ( - len(repo_folders_commit_only + repo_folders_branch_commit) == 0 - ) - - if has_a_sha_revision_been_cached and is_revision_a_branch: - # -> in this case let's warn the user - logger.warn( - f"The repo {repo_id} was previously downloaded from a commit hash" - " revision and has created the following cached directories" - f" {all_repo_folders_from_sha}. In this case, trying to load a repo" - f" from the branch {revision} in offline mode might lead to unexpected" - " behavior by not taking into account the latest commits." 
- ) - - # find last modified folder, and return it - storage_folder = max(repo_folders, key=os.path.getmtime) - - return storage_folder - # if we have internet connection we retrieve the correct folder name from the huggingface api _api = HfApi() repo_info = _api.repo_info( repo_id=repo_id, repo_type=repo_type, revision=revision, token=token ) + filtered_repo_files = _filter_repo_files( + repo_files=[f.rfilename for f in repo_info.siblings], + allow_regex=allow_regex, + ignore_regex=ignore_regex, + ) + commit_hash = repo_info.sha + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) - storage_folder = os.path.join(cache_dir, repo_id_flattened) - - # if passed revision is not identical to the commit sha - # then revision has to be a branch name, e.g. "main" - # in this case make sure that the branch name is included - # cached storage folder name - # if revision != repo_info.sha: - # storage_folder += f".{repo_info.sha}" - - repo_id_sha = repo_info.sha - repo_files = [f.rfilename for f in repo_info.siblings] - - allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex - ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex - - for repo_file in repo_files: - # if there's an allowlist, skip download if file does not match any regex - if allow_regex is not None and not any( - fnmatch(repo_file, r) for r in allow_regex - ): - continue - - # if there's a denylist, skip download if file does matches any regex - if ignore_regex is not None and any( - fnmatch(repo_file, r) for r in ignore_regex - ): - continue + # we pass the commit_hash to hf_hub_download + # so no network call happens if we already + # have the file locally. + for repo_file in filtered_repo_files: _ = hf_hub_download( repo_id, filename=repo_file, repo_type=repo_type, - revision=repo_id_sha, + revision=commit_hash, cache_dir=storage_folder, library_name=library_name, library_version=library_version, @@ -244,7 +203,6 @@ def snapshot_download( etag_timeout=etag_timeout, resume_download=resume_download, use_auth_token=use_auth_token, - local_files_only=local_files_only, ) - return storage_folder + return snapshot_folder diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 5728c9e955..896acfb29f 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -885,7 +885,7 @@ def hf_hub_download( ) os.makedirs(storage_folder, exist_ok=True) - # cross platform transcription of filename, to be used a local file path. + # cross platform transcription of filename, to be used as a local file path. 
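Stepping back from the diff for a moment: the `refs` files this rewrite relies on are plain text files that map a branch or tag name to the commit hash it currently points to. A toy, self-contained illustration of the lookup performed by the offline path (every path and hash here is fabricated for the example):

```python
import os
import tempfile

# Build a toy cache layout: refs/main holds the hash, snapshots/<hash>/ holds files.
storage_folder = tempfile.mkdtemp()
commit_hash = "0" * 40  # stand-in for a real 40-character commit sha

os.makedirs(os.path.join(storage_folder, "refs"))
os.makedirs(os.path.join(storage_folder, "snapshots", commit_hash))
with open(os.path.join(storage_folder, "refs", "main"), "w") as f:
    f.write(commit_hash)

# Offline resolution: read the ref, then point into `snapshots`.
with open(os.path.join(storage_folder, "refs", "main")) as f:
    resolved = f.read()
snapshot_folder = os.path.join(storage_folder, "snapshots", resolved)
assert os.path.isdir(snapshot_folder)
print(snapshot_folder)
```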
relative_filename = os.path.join(*filename.split("/")) # if user provides a commit_hash and they already have the file on disk, From e3e2485b4a2585ab04203479601ac1bc8af902c7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 16:10:25 +0200 Subject: [PATCH 25/47] OOops --- src/huggingface_hub/_snapshot_download.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 845cc5c44d..42c0a1ae0c 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -14,7 +14,8 @@ def _filter_repo_files( - *repo_files: List[str], + *, + repo_files: List[str], allow_regex: Optional[Union[List[str], str]] = None, ignore_regex: Optional[Union[List[str], str]] = None, ) -> List[str]: From 70516833d7a57de852e824aed833cfe8646dab7a Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 16:17:17 +0200 Subject: [PATCH 26/47] Create example-transformers-tf.py --- example-transformers-tf.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 example-transformers-tf.py diff --git a/example-transformers-tf.py b/example-transformers-tf.py new file mode 100644 index 0000000000..ea050012a4 --- /dev/null +++ b/example-transformers-tf.py @@ -0,0 +1,16 @@ +from huggingface_hub.snapshot_download import snapshot_download +from huggingface_hub.utils.logging import set_verbosity_debug + +set_verbosity_debug() + +DISTILBERT = "distilbert-base-uncased" + +folder_path = snapshot_download( + repo_id=DISTILBERT, + repo_type="model", +) + +print("loading TF model from", folder_path) + + +print() From 84ff68df69e728d70029df0b96f89b578b91ad64 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 5 May 2022 16:33:32 +0200 Subject: [PATCH 27/47] Fix + add a way more complete example (running on Ubuntu) --- example-transformers-tf.py | 8 +++++++- src/huggingface_hub/_snapshot_download.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/example-transformers-tf.py b/example-transformers-tf.py index ea050012a4..a7fceb9cb2 100644 --- a/example-transformers-tf.py +++ b/example-transformers-tf.py @@ -1,5 +1,7 @@ from huggingface_hub.snapshot_download import snapshot_download from huggingface_hub.utils.logging import set_verbosity_debug +from transformers import AutoModelForMaskedLM, TFAutoModelForMaskedLM + set_verbosity_debug() @@ -10,7 +12,11 @@ repo_type="model", ) -print("loading TF model from", folder_path) +print("The whole model repo has been saved to", folder_path) + +pt_model = AutoModelForMaskedLM.from_pretrained(folder_path) +tf_model = TFAutoModelForMaskedLM.from_pretrained(folder_path) +# Yay it works! 
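The filtering arguments shown earlier in the series combine naturally with this example: a variant could skip the PyTorch weights so a TF-only workflow does not pull `pytorch_model.bin` as well. A hedged sketch, not part of the patch itself; the patterns are illustrative and are matched as fnmatch-style globs despite the parameter name:

```python
from huggingface_hub.snapshot_download import snapshot_download

# Download the repo snapshot minus PyTorch/Flax weight files.
folder_path = snapshot_download(
    repo_id="distilbert-base-uncased",
    repo_type="model",
    ignore_regex=["*.bin", "*.msgpack"],
)
print("Filtered snapshot saved to", folder_path)
```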
print() diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 42c0a1ae0c..f4612171b3 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -196,7 +196,7 @@ def snapshot_download( filename=repo_file, repo_type=repo_type, revision=commit_hash, - cache_dir=storage_folder, + cache_dir=cache_dir, library_name=library_name, library_version=library_version, user_agent=user_agent, From 11974d4840fab1e758c3eb92d97c4d5add6c8a1c Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 6 May 2022 11:14:02 +0200 Subject: [PATCH 28/47] Apply suggestions from code review Co-authored-by: Lysandre Debut Co-authored-by: Patrick von Platen --- src/huggingface_hub/_snapshot_download.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index f4612171b3..a4eaad8f51 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -144,7 +144,7 @@ def snapshot_download( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError("Invalid repo type") + raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") storage_folder = os.path.join( cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) @@ -168,7 +168,7 @@ def snapshot_download( return snapshot_folder raise ValueError( - "Cannot find an appropriate cached folder for the specified revision on the" + "Cannot find an appropriate cached snapshot folder for the specified revision on the" " local disk and outgoing traffic has been disabled. To enable repo" " look-ups and downloads online, set 'local_files_only' to False." ) @@ -185,6 +185,11 @@ def snapshot_download( ) commit_hash = repo_info.sha snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) # we pass the commit_hash to hf_hub_download # so no network call happens if we already From 73cdb4672923c550d5e38bcc1ea2ba89dbab4589 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 9 May 2022 14:26:22 +0200 Subject: [PATCH 29/47] Update src/huggingface_hub/file_download.py Co-authored-by: Lysandre Debut --- src/huggingface_hub/file_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 896acfb29f..ba698615ea 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -719,7 +719,7 @@ def _normalize_etag(etag: str) -> str: etag (str): HTTP header Returns: - str: string that can be used as a nice directory name. + `str`: string that can be used as a nice directory name. 
""" return etag.strip('"') From 9f8b8769bcfaf677ce765ee5764285c63ace9ea7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 9 May 2022 14:26:28 +0200 Subject: [PATCH 30/47] Update src/huggingface_hub/file_download.py Co-authored-by: Lysandre Debut --- src/huggingface_hub/file_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index ba698615ea..22595388e2 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -716,7 +716,7 @@ def _normalize_etag(etag: str) -> str: The hf.co hub guarantees to only send the second form. Args: - etag (str): HTTP header + etag (`str`): HTTP header Returns: `str`: string that can be used as a nice directory name. From 2c5c94b2e41226d10fa0353701a374c2ffd9294e Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 9 May 2022 14:45:34 +0200 Subject: [PATCH 31/47] Only allow full revision hashes otherwise the `revision != commit_hash` test is not reliable --- src/huggingface_hub/file_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 22595388e2..8ad1b95404 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -147,7 +147,7 @@ def get_fastcore_version(): return _fastcore_version -REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{5,40}$") +REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$") @_deprecate_positional_args From 35fff445e5c9de815786566cad5bf6b302717b01 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 9 May 2022 14:47:05 +0200 Subject: [PATCH 32/47] add a little bit more doc + consistency --- src/huggingface_hub/_snapshot_download.py | 15 ++++++++--- src/huggingface_hub/file_download.py | 31 ++++++++++++++++++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index a4eaad8f51..726f705ea7 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -144,7 +144,10 @@ def snapshot_download( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") + raise ValueError( + f"Invalid repo type: {repo_type}. Accepted repo types are:" + f" {str(REPO_TYPES)}" + ) storage_folder = os.path.join( cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) @@ -168,9 +171,10 @@ def snapshot_download( return snapshot_folder raise ValueError( - "Cannot find an appropriate cached snapshot folder for the specified revision on the" - " local disk and outgoing traffic has been disabled. To enable repo" - " look-ups and downloads online, set 'local_files_only' to False." + "Cannot find an appropriate cached snapshot folder for the specified" + " revision on the local disk and outgoing traffic has been disabled. To" + " enable repo look-ups and downloads online, set 'local_files_only' to" + " False." ) # if we have internet connection we retrieve the correct folder name from the huggingface api @@ -185,6 +189,9 @@ def snapshot_download( ) commit_hash = repo_info.sha snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. 
if revision != commit_hash: ref_path = os.path.join(storage_folder, "refs", revision) os.makedirs(os.path.dirname(ref_path), exist_ok=True) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 8ad1b95404..66d807a241 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -783,6 +783,32 @@ def hf_hub_download( ): """Download a given file if it's not already present in the local cache. + The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. @@ -878,7 +904,10 @@ def hf_hub_download( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError("Invalid repo type") + raise ValueError( + f"Invalid repo type: {repo_type}. Accepted repo types are:" + f" {str(REPO_TYPES)}" + ) storage_folder = os.path.join( cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) From 383022fdab9fe49dcb37c542284f50e1b99d7c3d Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 13 May 2022 14:18:05 -0400 Subject: [PATCH 33/47] Update src/huggingface_hub/snapshot_download.py Co-authored-by: Patrick von Platen --- src/huggingface_hub/_snapshot_download.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 726f705ea7..8f708cf3f5 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -158,7 +158,16 @@ def snapshot_download( # If the specified revision is a commit hash, look inside "snapshots". # If the specified revision is a branch or tag, look inside "refs". 
if local_files_only: - if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + if not REGEX_COMMIT_HASH.match(commit_hash): + # rertieve commit_hash from file + ref_path = os.path.join(storage_folder, "refs", revision) + with open(ref_path) as f: + commit_hash = f.read() + + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): + return snapshot_folder snapshot_folder = os.path.join(storage_folder, "snapshots", revision) if os.path.exists(snapshot_folder): return snapshot_folder From 7112146b700b12758f69036139a5c6e9e7593d90 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 18 May 2022 15:53:30 -0400 Subject: [PATCH 34/47] Update snapshot download --- src/huggingface_hub/_snapshot_download.py | 2 +- src/huggingface_hub/file_download.py | 20 ++++++++------------ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 8f708cf3f5..8da36594a4 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -164,7 +164,7 @@ def snapshot_download( ref_path = os.path.join(storage_folder, "refs", revision) with open(ref_path) as f: commit_hash = f.read() - + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) if os.path.exists(snapshot_folder): return snapshot_folder diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 66d807a241..d95e7fcdb0 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -1009,21 +1009,17 @@ def hf_hub_download( "We have no connection or you passed local_files_only, so" " force_download is not an accepted option." ) - if REGEX_COMMIT_HASH.match(revision): - pointer_path = os.path.join( - storage_folder, "snapshots", revision, relative_filename - ) - if os.path.exists(pointer_path): - return pointer_path - else: + commit_hash = revision + if not REGEX_COMMIT_HASH.match(revision): ref_path = os.path.join(storage_folder, "refs", revision) with open(ref_path) as f: commit_hash = f.read() - pointer_path = os.path.join( - storage_folder, "snapshots", commit_hash, relative_filename - ) - if os.path.exists(pointer_path): - return pointer_path + + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash, relative_filename + ) + if os.path.exists(pointer_path): + return pointer_path # If we couldn't find an appropriate file on disk, # raise an error. 
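Taken together, patches 33 and 34 make both download entry points resolve a revision the same way: branch and tag names go through `refs`, full commit hashes address `snapshots` directly. A condensed sketch of that shared flow, assuming the cache layout introduced in this series (directory contents fabricated for illustration):

```python
import os
import re

REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")


def resolve_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
    """Resolve a revision to a cached file path.

    Branch/tag names are read from a `refs/<revision>` text file; a full
    commit hash skips that lookup. Raises FileNotFoundError if the ref was
    never cached, which is the situation the errors above report.
    """
    commit_hash = revision
    if not REGEX_COMMIT_HASH.match(revision):
        ref_path = os.path.join(storage_folder, "refs", revision)
        with open(ref_path) as f:
            commit_hash = f.read()
    return os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)


# e.g. resolve_pointer_path(storage_folder, "main", "config.json")
# -> <storage_folder>/snapshots/<40-char hash>/config.json
```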
From 8f99d2267471027052a1ebb48b638e0ce7c3d715 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 18 May 2022 19:39:34 -0400 Subject: [PATCH 35/47] First pass on tests --- tests/test_cache_layout.py | 170 +++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 tests/test_cache_layout.py diff --git a/tests/test_cache_layout.py b/tests/test_cache_layout.py new file mode 100644 index 0000000000..4ca8f242a6 --- /dev/null +++ b/tests/test_cache_layout.py @@ -0,0 +1,170 @@ +import os +import tempfile +import time +import unittest + +from huggingface_hub import hf_hub_download +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import logging + +from .testing_utils import with_production_testing + + +logger = logging.get_logger(__name__) +MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout" + + +@with_production_testing +class CacheFileLayout(unittest.TestCase): + def test_file_downloaded_in_cache(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(HUGGINGFACE_HUB_CACHE, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + + expected_reference = "main" + + # Only reference should be `main`. + self.assertListEqual(refs, [expected_reference]) + + with open(os.path.join(expected_path, "refs", expected_reference)) as f: + snapshot_name = f.readline().strip() + + # The `main` reference should point to the only snapshot we have downloaded + self.assertListEqual(snapshots, [snapshot_name]) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name) + snapshot_content = os.listdir(snapshot_path) + + # Only a single file in the snapshot + self.assertEqual(len(snapshot_content), 1) + + snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0]) + + # The snapshot content should link to a blob + self.assertTrue(os.path.islink(snapshot_content_path)) + + resolved_blob_relative = os.readlink(snapshot_content_path) + resolved_blob_absolute = os.path.normpath( + os.path.join(snapshot_path, resolved_blob_relative) + ) + + with open(resolved_blob_absolute) as f: + blob_contents = f.readline().strip() + + # The contents of the file should be 'File 0'. + self.assertEqual(blob_contents, "File 0") + + def test_file_downloaded_in_cache_with_revision(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + + expected_reference = "file-2" + + # Only reference should be `file-2`. 
+ self.assertListEqual(refs, [expected_reference]) + + with open(os.path.join(expected_path, "refs", expected_reference)) as f: + snapshot_name = f.readline().strip() + + # The `main` reference should point to the only snapshot we have downloaded + self.assertListEqual(snapshots, [snapshot_name]) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name) + snapshot_content = os.listdir(snapshot_path) + + # Only a single file in the snapshot + self.assertEqual(len(snapshot_content), 1) + + snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0]) + + # The snapshot content should link to a blob + self.assertTrue(os.path.islink(snapshot_content_path)) + + resolved_blob_relative = os.readlink(snapshot_content_path) + resolved_blob_absolute = os.path.normpath( + os.path.join(snapshot_path, resolved_blob_relative) + ) + + with open(resolved_blob_absolute) as f: + blob_contents = f.readline().strip() + + # The contents of the file should be 'File 0'. + self.assertEqual(blob_contents, "File 0") + + def test_file_download_happens_once(self): + # Tests that a file is only downloaded once if it's not updated. + + with tempfile.TemporaryDirectory() as cache: + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_0 = os.path.getmtime(path) + + time.sleep(2) + + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_1 = os.path.getmtime(path) + + self.assertEqual(creation_time_0, creation_time_1) + + def test_file_download_happens_once_intra_revision(self): + # Tests that a file is only downloaded once if it's not updated, even across different revisions. + + with tempfile.TemporaryDirectory() as cache: + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_0 = os.path.getmtime(path) + + time.sleep(2) + + path = hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + creation_time_1 = os.path.getmtime(path) + + self.assertEqual(creation_time_0, creation_time_1) + + def test_multiple_refs_for_same_file(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + refs.sort() + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["file-2", "main"]) + + def get_file_contents(path): + with open(path) as f: + content = f.read() + + return content + + refs_contents = [ + get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs + ] + refs_contents.sort() + + # snapshots directory should contain two snapshots + self.assertListEqual(refs_contents, snapshots) + + # snapshots_paths = [os.path.join(expected_path, s) for s in snapshots] From 01e3a61d32cb3e0b42da8dc3e63edb870383a600 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 20 May 2022 16:51:18 -0400 Subject: [PATCH 36/47] Wrap up tests --- tests/test_cache_layout.py | 239 +++++++++++++++++++++++++++++++++++-- 1 file changed, 227 insertions(+), 12 deletions(-) diff --git a/tests/test_cache_layout.py b/tests/test_cache_layout.py index 4ca8f242a6..fd9a39d56d 100644 --- a/tests/test_cache_layout.py +++ 
b/tests/test_cache_layout.py @@ -2,11 +2,19 @@ import tempfile import time import unittest - -from huggingface_hub import hf_hub_download -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from io import BytesIO + +from huggingface_hub import ( + HfApi, + create_repo, + delete_repo, + hf_hub_download, + snapshot_download, + upload_file, +) from huggingface_hub.utils import logging +from .testing_constants import ENDPOINT_STAGING, TOKEN, USER from .testing_utils import with_production_testing @@ -14,14 +22,21 @@ MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout" +def get_file_contents(path): + with open(path) as f: + content = f.read() + + return content + + @with_production_testing -class CacheFileLayout(unittest.TestCase): +class CacheFileLayoutHfHubDownload(unittest.TestCase): def test_file_downloaded_in_cache(self): with tempfile.TemporaryDirectory() as cache: hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' - expected_path = os.path.join(HUGGINGFACE_HUB_CACHE, expected_directory_name) + expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) snapshots = os.listdir(os.path.join(expected_path, "snapshots")) @@ -153,12 +168,6 @@ def test_multiple_refs_for_same_file(self): # Directory should contain two revisions self.assertListEqual(refs, ["file-2", "main"]) - def get_file_contents(path): - with open(path) as f: - content = f.read() - - return content - refs_contents = [ get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs ] @@ -167,4 +176,210 @@ def get_file_contents(path): # snapshots directory should contain two snapshots self.assertListEqual(refs_contents, snapshots) - # snapshots_paths = [os.path.join(expected_path, s) for s in snapshots] + snapshot_links = [ + os.readlink( + os.path.join(expected_path, "snapshots", filename, "file_0.txt") + ) + for filename in snapshots + ] + + # All snapshot links should point to the same file. 
+ self.assertEqual(*snapshot_links) + + +@with_production_testing +class CacheFileLayoutSnapshotDownload(unittest.TestCase): + def test_file_downloaded_in_cache(self): + with tempfile.TemporaryDirectory() as cache: + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["main"]) + + ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # snapshots directory should contain two snapshots + self.assertListEqual([ref_content], snapshots) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshots[0]) + + files_in_snapshot = os.listdir(snapshot_path) + + snapshot_links = [ + os.readlink(os.path.join(snapshot_path, filename)) + for filename in files_in_snapshot + ] + + resolved_snapshot_links = [ + os.path.normpath(os.path.join(snapshot_path, link)) + for link in snapshot_links + ] + + self.assertTrue(all([os.path.isfile(l) for l in resolved_snapshot_links])) + + def test_file_downloaded_in_cache_several_revisions(self): + with tempfile.TemporaryDirectory() as cache: + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-3") + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-2") + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + refs.sort() + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["file-2", "file-3"]) + + refs_content = [ + get_file_contents(os.path.join(expected_path, "refs", ref)) + for ref in refs + ] + refs_content.sort() + + # snapshots directory should contain two snapshots + self.assertListEqual(refs_content, snapshots) + + snapshots_paths = [ + os.path.join(expected_path, "snapshots", s) for s in snapshots + ] + + files_in_snapshots = {s: os.listdir(s) for s in snapshots_paths} + links_in_snapshots = { + k: [os.readlink(os.path.join(k, _v)) for _v in v] + for k, v in files_in_snapshots.items() + } + + resolved_snapshots_links = { + k: [os.path.normpath(os.path.join(k, link)) for link in v] + for k, v in links_in_snapshots.items() + } + + all_links = [b for a in resolved_snapshots_links.values() for b in a] + all_unique_links = set(all_links) + + # [ 100] . 
+ # ├── [ 140] blobs + # │ ├── [ 7] 4475433e279a71203927cbe80125208a3b5db560 + # │ ├── [ 7] 50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # │ ├── [ 7] 80146afc836c60e70ba67933fec439ab05b478f6 + # │ ├── [ 7] 8cf9e18f080becb674b31c21642538269fe886a4 + # │ └── [1.1K] ac481c8eb05e4d2496fbe076a38a7b4835dd733d + # ├── [ 80] refs + # │ ├── [ 40] file-2 + # │ └── [ 40] file-3 + # └── [ 80] snapshots + # ├── [ 120] 5e23cb3ae7f904919a442e1b27dcddae6c6bc292 + # │ ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 + # │ ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # │ ├── [ 52] file_2.txt -> ../../blobs/4475433e279a71203927cbe80125208a3b5db560 + # │ └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d + # └── [ 120] 78aa2ebdb60bba086496a8792ba506e58e587b4c + # ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 + # ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # ├── [ 52] file_3.txt -> ../../blobs/8cf9e18f080becb674b31c21642538269fe886a4 + # └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d + + # Across the two revisions, there should be 8 total links + self.assertEqual(len(all_links), 8) + + # Across the two revisions, there should only be 5 unique files. + self.assertEqual(len(all_unique_links), 5) + + +class ReferenceUpdates(unittest.TestCase): + _api = HfApi(endpoint=ENDPOINT_STAGING) + + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. + """ + cls._token = TOKEN + cls._api.set_access_token(TOKEN) + + def test_update_reference(self): + repo_id = f"{USER}/hfh-cache-layout" + create_repo(repo_id, token=self._token, exist_ok=True) + + try: + upload_file( + path_or_fileobj=BytesIO(b"Some string"), + path_in_repo="file.txt", + repo_id=repo_id, + token=self._token, + ) + + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(repo_id, "file.txt", cache_dir=cache) + + expected_directory_name = f'models--{repo_id.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + + # Directory should contain two revisions + self.assertListEqual(refs, ["main"]) + + initial_ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # Upload a new file on the same branch + upload_file( + path_or_fileobj=BytesIO(b"Some new string"), + path_in_repo="file.txt", + repo_id=repo_id, + token=self._token, + ) + + hf_hub_download(repo_id, "file.txt", cache_dir=cache) + + final_ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # The `main` reference should point to two different, but existing snapshots which contain + # a 'file.txt' + self.assertNotEqual(initial_ref_content, final_ref_content) + self.assertTrue( + os.path.isdir( + os.path.join(expected_path, "snapshots", initial_ref_content) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + expected_path, "snapshots", initial_ref_content, "file.txt" + ) + ) + ) + self.assertTrue( + os.path.isdir( + os.path.join(expected_path, "snapshots", final_ref_content) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + expected_path, "snapshots", final_ref_content, "file.txt" + ) + ) + ) + except Exception: + raise + finally: + delete_repo(repo_id, token=self._token) From 2d54dd1763b9ca1d0a21fee7b6676a06dc397336 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 24 May 2022 20:57:54 +0200 
Subject: [PATCH 37/47] :wolf: Fix for bug reported by @thomwolf see https://github.com/huggingface/huggingface_hub/pull/801#issuecomment-1134576435 --- src/huggingface_hub/file_download.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index d95e7fcdb0..01418b34fb 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -901,6 +901,9 @@ def hf_hub_download( if isinstance(cache_dir, Path): cache_dir = str(cache_dir) + if subfolder is not None: + filename = f"{subfolder}/{filename}" + if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: @@ -926,9 +929,7 @@ def hf_hub_download( if os.path.exists(pointer_path): return pointer_path - url = hf_hub_url( - repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision - ) + url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision) headers = { "user-agent": http_user_agent( From 8f71ad6681cf88fd1af36f4cb5d67e01f2eee33d Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 24 May 2022 16:57:33 -0400 Subject: [PATCH 38/47] Special case for Windows --- example-transformers-tf.py | 22 ---------------------- example.py | 21 --------------------- src/huggingface_hub/file_download.py | 16 ++++++++++++++-- 3 files changed, 14 insertions(+), 45 deletions(-) delete mode 100644 example-transformers-tf.py delete mode 100644 example.py diff --git a/example-transformers-tf.py b/example-transformers-tf.py deleted file mode 100644 index a7fceb9cb2..0000000000 --- a/example-transformers-tf.py +++ /dev/null @@ -1,22 +0,0 @@ -from huggingface_hub.snapshot_download import snapshot_download -from huggingface_hub.utils.logging import set_verbosity_debug -from transformers import AutoModelForMaskedLM, TFAutoModelForMaskedLM - - -set_verbosity_debug() - -DISTILBERT = "distilbert-base-uncased" - -folder_path = snapshot_download( - repo_id=DISTILBERT, - repo_type="model", -) - - -print("The whole model repo has been saved to", folder_path) - -pt_model = AutoModelForMaskedLM.from_pretrained(folder_path) -tf_model = TFAutoModelForMaskedLM.from_pretrained(folder_path) -# Yay it works! - -print() diff --git a/example.py b/example.py deleted file mode 100644 index 77cd6c34ca..0000000000 --- a/example.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from huggingface_hub.file_download import hf_hub_download - -OLDER_REVISION = "bbc77c8132af1cc5cf678da3f1ddf2de43606d48" - -hf_hub_download("julien-c/EsperBERTo-small", filename="README.md") - -hf_hub_download("julien-c/EsperBERTo-small", filename="pytorch_model.bin") - -hf_hub_download( - "julien-c/EsperBERTo-small", filename="README.md", revision=OLDER_REVISION -) - -weights_file = hf_hub_download( - "julien-c/EsperBERTo-small", filename="pytorch_model.bin", revision=OLDER_REVISION -) - -w = torch.load(weights_file, map_location=torch.device("cpu")) -# Yay it works! 
 just loaded a torch file from a symlink
-
-print()
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 01418b34fb..a7325ae110 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -33,7 +33,7 @@
     REPO_TYPES_URL_PREFIXES,
 )
 from .hf_api import HfFolder
-from .utils import logging
+from .utils import logging, run_subprocess
 from .utils._deprecation import _deprecate_positional_args
 
 
@@ -740,7 +740,19 @@ def _create_relative_symlink(src: str, dst: str) -> None:
         os.remove(dst)
     except OSError:
         pass
-    os.symlink(relative_src, dst)
+    try:
+        os.symlink(relative_src, dst)
+    except OSError:
+        # Likely running on Windows
+        if os.name == "nt":
+            raise OSError(
+                "Windows requires Developer Mode to be activated, or to run Python as "
+                "an administrator, in order to create symlinks.\nIn order to "
+                "activate Developer Mode, see this article: "
+                "https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development"
+            )
+        else:
+            raise
 
 
 def repo_folder_name(

From 442b5b17ab461f166f639e7b85638d3b3022c4ca Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Tue, 24 May 2022 19:16:29 -0400
Subject: [PATCH 39/47] Address comments and docs

---
 .../package_reference/file_download.mdx | 81 +++++++++++++++++++
 src/huggingface_hub/file_download.py    | 12 ++-
 tests/test_cache_layout.py              | 11 ++-
 3 files changed, 100 insertions(+), 4 deletions(-)

diff --git a/docs/source/package_reference/file_download.mdx b/docs/source/package_reference/file_download.mdx
index 79d1063ce4..d8862869bd 100644
--- a/docs/source/package_reference/file_download.mdx
+++ b/docs/source/package_reference/file_download.mdx
@@ -8,3 +8,84 @@
 
 [[autodoc]] huggingface_hub.hf_hub_url
 
+## Caching
+
+The methods displayed above are designed to work with a caching system that prevents re-downloading files.
+The caching system was updated in v0.8.0 to allow a shared directory structure and file sharing across
+libraries that depend on the hub.
+
+The caching system is designed as follows:
+
+```
+<CACHE_DIR>
+├─ <MODELS>
+├─ <DATASETS>
+├─ <SPACES>
+```
+
+Models, datasets and spaces share a common root. Each of these repositories contains the namespace
+(organization, username) if it exists, alongside the repository name:
+
+```
+<CACHE_DIR>
+├─ models--julien-c--EsperBERTo-small
+├─ models--lysandrejik--arxiv-nlp
+├─ models--bert-base-cased
+├─ datasets--glue
+├─ datasets--huggingface--DataMeasurementsFiles
+├─ spaces--dalle-mini--dalle-mini
+```
+
+It is within these folders that all files will now be downloaded from the hub. Caching ensures that
+a file isn't downloaded twice if it already exists and wasn't updated; but if it was updated,
+and you're asking for the latest file, then it will download the latest file (while keeping
+the previous file intact in case you need it again).
+
+In order to achieve this, all folders contain the same skeleton:
+
+```
+<CACHE_DIR>
+├─ datasets--glue
+│  ├─ refs
+│  ├─ blobs
+│  ├─ snapshots
+...
+```
+
+Each folder is designed to contain the following:
+
+### Refs
+
+The `refs` folder contains files which indicate the latest revision of the given reference. For example,
+if we have previously fetched a file from the `main` branch of a repository, the `refs`
+folder will contain a file named `main`, which will itself contain the commit identifier of the current head.
+
+If the latest commit of `main` had `aaaaaa` as identifier, then it would contain `aaaaaa`.
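As a concrete complement to the description above, a small sketch that enumerates the refs of one cached repo; it assumes only the layout documented in this section, and the path in the trailing comment is illustrative:

```python
import os


def list_cached_refs(repo_storage_folder: str) -> dict:
    """Map every cached ref name to the commit hash it points to,
    e.g. {"main": "aaaaaa"} in the scenario described above."""
    refs_dir = os.path.join(repo_storage_folder, "refs")
    refs = {}
    if os.path.isdir(refs_dir):
        for name in os.listdir(refs_dir):
            with open(os.path.join(refs_dir, name)) as f:
                refs[name] = f.read().strip()
    return refs


# Illustrative call against the layout documented above:
# list_cached_refs("~/.cache/huggingface/hub/models--julien-c--EsperBERTo-small")
```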
+
+If that same branch gets updated with a new commit that has `bbbbbb` as an identifier, then
+redownloading a file from that reference will update the `refs/main` file to contain `bbbbbb`.
+
+### Blobs
+
+The `blobs` folder contains the actual files that we have downloaded. The name of each file is their hash.
+
+### Snapshots
+
+The `snapshots` folder contains symlinks to the blobs mentioned above. It is itself made up of several folders:
+one per known revision!
+
+In the explanation above, we had initially fetched a file from the `aaaaaa` revision, before fetching a file from
+the `bbbbbb` revision. In this situation, we would now have two folders in the `snapshots` folder: `aaaaaa`
+and `bbbbbb`.
+
+In each of these folders live symlinks that have the names of the files that we have downloaded. For example,
+if we had downloaded the `README.md` file at revision `aaaaaa`, we would have the following path:
+
+```
+<CACHE_DIR>/<REPO_NAME>/snapshots/aaaaaa/README.md
+```
+
+That `README.md` file is actually a symlink linking to the blob that has the hash of the file.
+
+Creating the skeleton this way opens up the mechanism for file sharing: if the same file was fetched in
+revision `bbbbbb`, it would have the same hash and the file would not need to be redownloaded.
\ No newline at end of file
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index a7325ae110..3dd3b45cf2 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -33,7 +33,7 @@
     REPO_TYPES_URL_PREFIXES,
 )
 from .hf_api import HfFolder
-from .utils import logging, run_subprocess
+from .utils import logging
 from .utils._deprecation import _deprecate_positional_args
 
 
@@ -786,6 +786,7 @@ def hf_hub_download(
     cache_dir: Union[str, Path, None] = None,
     user_agent: Union[Dict, str, None] = None,
     force_download: Optional[bool] = False,
+    force_filename: Optional[str] = None,
     proxies: Optional[Dict] = None,
     etag_timeout: Optional[float] = 10,
     resume_download: Optional[bool] = False,
@@ -883,6 +884,14 @@ def hf_hub_download(
 
     """
 
+    if force_filename is not None:
+        warnings.warn(
+            "The `force_filename` parameter is deprecated as a new caching system, "
+            "which keeps the filenames as they are on the Hub, is now in place.",
+            FutureWarning,
+        )
+        legacy_cache_layout = True
+
     if legacy_cache_layout:
         url = hf_hub_url(
             repo_id,
@@ -899,6 +908,7 @@ def hf_hub_download(
             cache_dir=cache_dir,
             user_agent=user_agent,
             force_download=force_download,
+            force_filename=force_filename,
             proxies=proxies,
             etag_timeout=etag_timeout,
             resume_download=resume_download,
diff --git a/tests/test_cache_layout.py b/tests/test_cache_layout.py
index fd9a39d56d..442fa6a5af 100644
--- a/tests/test_cache_layout.py
+++ b/tests/test_cache_layout.py
@@ -2,6 +2,7 @@
 import tempfile
 import time
 import unittest
+import uuid
 from io import BytesIO
 
 from huggingface_hub import (
@@ -22,6 +23,10 @@
 MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout"
 
 
+def repo_name(id=uuid.uuid4().hex[:6]):
+    return "repo-{0}-{1}".format(id, int(time.time() * 10e3))
+
+
 def get_file_contents(path):
     with open(path) as f:
         content = f.read()
@@ -69,7 +74,7 @@ def test_file_downloaded_in_cache(self):
             )
 
             with open(resolved_blob_absolute) as f:
-                blob_contents = f.readline().strip()
+                blob_contents = f.read().strip()
 
             # The contents of the file should be 'File 0'.
self.assertEqual(blob_contents, "File 0") @@ -92,7 +97,7 @@ def test_file_downloaded_in_cache_with_revision(self): self.assertListEqual(refs, [expected_reference]) with open(os.path.join(expected_path, "refs", expected_reference)) as f: - snapshot_name = f.readline().strip() + snapshot_name = f.read().strip() # The `main` reference should point to the only snapshot we have downloaded self.assertListEqual(snapshots, [snapshot_name]) @@ -312,7 +317,7 @@ def setUpClass(cls): cls._api.set_access_token(TOKEN) def test_update_reference(self): - repo_id = f"{USER}/hfh-cache-layout" + repo_id = f"{USER}/{repo_name()}" create_repo(repo_id, token=self._token, exist_ok=True) try: From e1ff4eb90283459a74ec9b52b5f78b815680e540 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 25 May 2022 08:09:48 -0400 Subject: [PATCH 40/47] Clean up with ternary cc @julien-c --- src/huggingface_hub/_snapshot_download.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 8da36594a4..04b75fca73 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -158,26 +158,20 @@ def snapshot_download( # If the specified revision is a commit hash, look inside "snapshots". # If the specified revision is a branch or tag, look inside "refs". if local_files_only: - commit_hash = revision - if not REGEX_COMMIT_HASH.match(commit_hash): - # rertieve commit_hash from file + + def resolve_ref(revision) -> str: + # retrieve commit_hash from file ref_path = os.path.join(storage_folder, "refs", revision) with open(ref_path) as f: - commit_hash = f.read() + return f.read() + commit_hash = ( + revision if REGEX_COMMIT_HASH.match(revision) else resolve_ref(revision) + ) snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): return snapshot_folder - snapshot_folder = os.path.join(storage_folder, "snapshots", revision) - if os.path.exists(snapshot_folder): - return snapshot_folder - else: - ref_path = os.path.join(storage_folder, "refs", revision) - with open(ref_path) as f: - commit_hash = f.read() - snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) - if os.path.exists(snapshot_folder): - return snapshot_folder raise ValueError( "Cannot find an appropriate cached snapshot folder for the specified" From e69ef4fd05622074ce5c5d989aea4b59963b28b5 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 25 May 2022 08:30:29 -0400 Subject: [PATCH 41/47] Add argument to `cached_download` --- src/huggingface_hub/file_download.py | 18 ++++++++++++----- tests/test_file_download.py | 30 +++++++++++++++++----------- tests/test_hf_api.py | 16 +++++++++++---- tests/test_snapshot_download.py | 8 -------- 4 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 3dd3b45cf2..6a763e93c3 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -467,6 +467,7 @@ def cached_download( resume_download: Optional[bool] = False, use_auth_token: Union[bool, str, None] = None, local_files_only: Optional[bool] = False, + legacy_cache_layout: Optional[bool] = False, ) -> Optional[str]: # pragma: no cover """ Download from a given URL and cache it if it's not already present in the @@ -508,6 +509,11 @@ def cached_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid 
downloading the file and return the path to the local cached file if it exists.
+        legacy_cache_layout (`bool`, *optional*, defaults to `False`):
+            Set this parameter to `True` to keep using the old cache layout.
+            Explicitly setting it to `True` will not raise any warning when
+            using `cached_download`. We recommend using `hf_hub_download` to
+            take advantage of the new cache.
 
     Returns:
         Local path (string) of file or if networking is off, last version of
@@ -526,11 +532,13 @@
 
     """
 
-    warnings.warn(
-        "`cached_download` is the legacy way to download files from the HF hub, please"
-        " consider upgrading to `hf_hub_download`",
-        FutureWarning,
-    )
+    if not legacy_cache_layout:
+        warnings.warn(
+            "`cached_download` is the legacy way to download files from the HF hub,"
+            " please consider upgrading to `hf_hub_download`",
+            FutureWarning,
+        )
+
     if cache_dir is None:
         cache_dir = HUGGINGFACE_HUB_CACHE
     if isinstance(cache_dir, Path):
diff --git a/tests/test_file_download.py b/tests/test_file_download.py
index 98fd055d56..4cf20c1c78 100644
--- a/tests/test_file_download.py
+++ b/tests/test_file_download.py
@@ -57,7 +57,7 @@ class CachedDownloadTests(unittest.TestCase):
     def test_bogus_url(self):
         url = "https://bogus"
         with self.assertRaisesRegex(ValueError, "Connection error"):
-            _ = cached_download(url)
+            _ = cached_download(url, legacy_cache_layout=True)
 
     def test_no_connection(self):
         invalid_url = hf_hub_url(
@@ -68,20 +68,26 @@ def test_no_connection(self):
         valid_url = hf_hub_url(
             DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT
         )
-        self.assertIsNotNone(cached_download(valid_url, force_download=True))
+        self.assertIsNotNone(
+            cached_download(valid_url, force_download=True, legacy_cache_layout=True)
+        )
 
         for offline_mode in OfflineSimulationMode:
             with offline(mode=offline_mode):
                 with self.assertRaisesRegex(ValueError, "Connection error"):
-                    _ = cached_download(invalid_url)
+                    _ = cached_download(invalid_url, legacy_cache_layout=True)
                 with self.assertRaisesRegex(ValueError, "Connection error"):
-                    _ = cached_download(valid_url, force_download=True)
-                self.assertIsNotNone(cached_download(valid_url))
+                    _ = cached_download(
+                        valid_url, force_download=True, legacy_cache_layout=True
+                    )
+                self.assertIsNotNone(
+                    cached_download(valid_url, legacy_cache_layout=True)
+                )
 
     def test_file_not_found(self):
         # Valid revision (None) but missing file.
url = hf_hub_url(DUMMY_MODEL_ID, filename="missing.bin") with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): - _ = cached_download(url) + _ = cached_download(url, legacy_cache_layout=True) def test_revision_not_found(self): # Valid file but missing revision @@ -91,13 +97,13 @@ def test_revision_not_found(self): revision=DUMMY_MODEL_ID_REVISION_INVALID, ) with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): - _ = cached_download(url) + _ = cached_download(url, legacy_cache_layout=True) def test_standard_object(self): url = hf_hub_url( DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT ) - filepath = cached_download(url, force_download=True) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')) @@ -108,7 +114,7 @@ def test_standard_object_rev(self): filename=CONFIG_NAME, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') # Caution: check that the etag is *not* equal to the one from `test_standard_object` @@ -117,7 +123,7 @@ def test_lfs_object(self): url = hf_hub_url( DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT ) - filepath = cached_download(url, force_download=True) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"')) @@ -136,7 +142,7 @@ def test_dataset_standard_object_rev(self): ) self.assertEqual(url, url2) # now let's download - filepath = cached_download(url, force_download=True) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') @@ -147,7 +153,7 @@ def test_dataset_lfs_object(self): repo_type=REPO_TYPE_DATASET, revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath) self.assertEqual( metadata, diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 64e48b3486..c64cf6c4f2 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -511,7 +511,9 @@ def test_upload_file_path(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -538,7 +540,9 @@ def test_upload_file_fileobj(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -565,7 +569,9 @@ def test_upload_file_bytesio(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: 
content = downloaded_file.read() self.assertEqual(content, filecontent.getvalue().decode()) @@ -648,7 +654,9 @@ def test_upload_buffer(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index b41b438a3c..23d2947238 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -275,14 +275,6 @@ def test_download_model_local_only_multiple(self): cache_dir=tmpdirname, ) - # now load from cache and make sure warning to be raised - with self.assertWarns(Warning): - snapshot_download( - f"{USER}/{REPO_NAME}", - cache_dir=tmpdirname, - local_files_only=True, - ) - # cache multiple commits and make sure correct commit is taken with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it From 0d06b5577e6b434e3b3dee10db2fc7afa37de532 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 25 May 2022 08:43:44 -0400 Subject: [PATCH 42/47] Opt-in for filename_to-url --- src/huggingface_hub/file_download.py | 26 +++++++++++++++++++++----- tests/test_file_download.py | 10 +++++----- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 6a763e93c3..86dd9df54e 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -263,15 +263,31 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str: return filename -def filename_to_url(filename, cache_dir=None) -> Tuple[str, str]: +def filename_to_url( + filename, + cache_dir: Optional[str] = None, + legacy_cache_layout: Optional[bool] = False, +) -> Tuple[str, str]: """ Return the url and etag (which may be `None`) stored for `filename`. Raise `EnvironmentError` if `filename` or its stored metadata do not exist. + + Args: + filename (`str`): + The name of the file + cache_dir (`str`, *optional*): + The cache directory to use instead of the default one. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url` + then `cached_download`. This is deprecated as the new cache layout is + more powerful. 
""" - warnings.warn( - "`filename_to_url` uses the legacy way cache file layout", - FutureWarning, - ) + if not legacy_cache_layout: + warnings.warn( + "`filename_to_url` uses the legacy way cache file layout", + FutureWarning, + ) + if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if isinstance(cache_dir, Path): diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 4cf20c1c78..6461d2250b 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -104,7 +104,7 @@ def test_standard_object(self): DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT ) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')) def test_standard_object_rev(self): @@ -115,7 +115,7 @@ def test_standard_object_rev(self): revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, ) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') # Caution: check that the etag is *not* equal to the one from `test_standard_object` @@ -124,7 +124,7 @@ def test_lfs_object(self): DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT ) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"')) def test_dataset_standard_object_rev(self): @@ -143,7 +143,7 @@ def test_dataset_standard_object_rev(self): self.assertEqual(url, url2) # now let's download filepath = cached_download(url, force_download=True, legacy_cache_layout=True) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') def test_dataset_lfs_object(self): @@ -154,7 +154,7 @@ def test_dataset_lfs_object(self): revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT, ) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual( metadata, (url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'), From d6bb92e118b294df3c2b30adef679d01f52c32c7 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 25 May 2022 08:54:51 -0400 Subject: [PATCH 43/47] Opt-in for filename_to-url --- tests/test_file_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 6461d2250b..071478c0a8 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -168,5 +168,5 @@ def test_hf_hub_download_legacy(self): force_download=True, legacy_cache_layout=True, ) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') From 425964272a2e3918f913eb9a68c1af87bb2af746 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 25 May 2022 09:31:45 -0400 Subject: [PATCH 44/47] Pass the flag --- src/huggingface_hub/file_download.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 86dd9df54e..807543c96b 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -938,6 +938,7 @@ def hf_hub_download(
             resume_download=resume_download,
             use_auth_token=use_auth_token,
             local_files_only=local_files_only,
+            legacy_cache_layout=legacy_cache_layout,
         )
 
     if cache_dir is None:
From a8348ab6de6f82efb9287f9220e5843de6073bde Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Wed, 25 May 2022 10:11:02 -0400
Subject: [PATCH 45/47] Update docs/source/package_reference/file_download.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 docs/source/package_reference/file_download.mdx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/package_reference/file_download.mdx b/docs/source/package_reference/file_download.mdx
index d8862869bd..a6bdd491fd 100644
--- a/docs/source/package_reference/file_download.mdx
+++ b/docs/source/package_reference/file_download.mdx
@@ -60,7 +60,7 @@ The `refs` folder contains files which indicate the latest revision of the give
 if we have previously fetched a file from the `main` branch of a repository, the `refs` folder will contain a
 file named `main`, which will itself contain the commit identifier of the current head.
 
-If the latest commit of `main` had `aaaaaa` as identifier, then it would contain `aaaaaa`.
+If the latest commit of `main` has `aaaaaa` as identifier, then it will contain `aaaaaa`.
 
 If that same branch gets updated with a new commit that has `bbbbbb` as an identifier, then
 redownloading a file from that reference will update the `refs/main` file to contain `bbbbbb`.
From 8c599ee123ef41b7585d4369091800f09dd18fbf Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Wed, 25 May 2022 10:21:48 -0400
Subject: [PATCH 46/47] Update src/huggingface_hub/file_download.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/huggingface_hub/file_download.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 807543c96b..1a78f0e7fe 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -887,7 +887,7 @@ def hf_hub_download(
         If `True`, avoid downloading the file and return the path to the local
         cached file if it exists.
     legacy_cache_layout (`bool`, *optional*, defaults to `False`):
-        If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url`
+        If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`]
         then `cached_download`. This is deprecated as the new cache layout is
         more powerful.
 
From 3dfc00c6d457246c7a97564f4ef92588b4aaa507 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Wed, 25 May 2022 10:22:23 -0400
Subject: [PATCH 47/47] Address review comments

---
 .../package_reference/file_download.mdx       | 27 ++++++++++++++++++-
 src/huggingface_hub/_snapshot_download.py     |  9 +++----
 src/huggingface_hub/file_download.py          | 12 +++-------
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/docs/source/package_reference/file_download.mdx b/docs/source/package_reference/file_download.mdx
index a6bdd491fd..1a05dfdf9d 100644
--- a/docs/source/package_reference/file_download.mdx
+++ b/docs/source/package_reference/file_download.mdx
@@ -23,6 +23,9 @@ The caching system is designed as follows:
 ├─ <SPACES>
 ```
 
+The `<CACHE_DIR>` is usually your user's home directory. However, it is customizable with the
+`cache_dir` argument on all methods, or by specifying the `HF_HOME` environment variable.
+
 Models, datasets and spaces share a common root. Each of these repositories contains the namespace
 (organization, username) if it exists, alongside the repository name:
 
@@ -88,4 +91,26 @@ if we had downloaded the `README.md` file at revision `aaaaaa`, we would have th
 
 That `README.md` file is actually a symlink pointing to the blob whose name is the hash of the file.
 
 Creating the skeleton this way opens up the mechanism of file sharing: if the same file was fetched in
-revision `bbbbbb`, it would have the same hash and the file would not need to be redownloaded.
\ No newline at end of file
+revision `bbbbbb`, it would have the same hash and the file would not need to be redownloaded.
+
+### In practice
+
+In practice, it should look like the following tree in your cache:
+
+```
+  [  96]  .
+  └── [ 160]  models--julien-c--EsperBERTo-small
+      ├── [ 160]  blobs
+      │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+      │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
+      │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
+      ├── [  96]  refs
+      │   └── [  40]  main
+      └── [ 128]  snapshots
+          ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
+          │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
+          │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+          └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
+              ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
+              └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
+```
\ No newline at end of file
diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py
index 04b75fca73..df65e4f2f9 100644
--- a/src/huggingface_hub/_snapshot_download.py
+++ b/src/huggingface_hub/_snapshot_download.py
@@ -159,15 +159,14 @@ def snapshot_download(
     # If the specified revision is a branch or tag, look inside "refs".
     if local_files_only:
 
-        def resolve_ref(revision) -> str:
+        if REGEX_COMMIT_HASH.match(revision):
+            commit_hash = revision
+        else:
             # retrieve commit_hash from file
             ref_path = os.path.join(storage_folder, "refs", revision)
             with open(ref_path) as f:
-                return f.read()
+                commit_hash = f.read()
 
-        commit_hash = (
-            revision if REGEX_COMMIT_HASH.match(revision) else resolve_ref(revision)
-        )
         snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
 
         if os.path.exists(snapshot_folder):
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index 1a78f0e7fe..c3e602f8d0 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -779,21 +779,14 @@ def _create_relative_symlink(src: str, dst: str) -> None:
         raise
 
 
-def repo_folder_name(
-    *,
-    repo_id: str,
-    repo_type: str,
-) -> str:
+def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
     """Return a serialized version of a hf.co repo name and type, safe for disk storage
     as a single non-nested folder.
Example: models--julien-c--EsperBERTo-small """ # remove all `/` occurrences to correctly convert repo to directory name - parts = [ - f"{repo_type}s", - *repo_id.split("/"), - ] + parts = [f"{repo_type}s", *repo_id.split("/")] return REPO_ID_SEPARATOR.join(parts) @@ -949,6 +942,7 @@ def hf_hub_download( cache_dir = str(cache_dir) if subfolder is not None: + # This is used to create a URL, and not a local path, hence the forward slash. filename = f"{subfolder}/{filename}" if repo_type is None:
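
The patches above converge on a simple offline resolution rule for the new cache layout: a revision that is a
full commit hash maps directly to a `snapshots` folder, while a branch or tag is first resolved through its
`refs` file. Below is a minimal, self-contained sketch of that rule, mirroring the `local_files_only` branch of
`snapshot_download` after PATCH 47. The regex is an assumed stand-in for the library's `REGEX_COMMIT_HASH`
constant, not its actual definition.

```python
import os
import re

# Assumed stand-in for huggingface_hub's REGEX_COMMIT_HASH:
# a full 40-character hexadecimal SHA-1.
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")


def resolve_snapshot_folder(storage_folder: str, revision: str) -> str:
    """Resolve a revision (commit hash, branch or tag) to a cached snapshot folder."""
    if REGEX_COMMIT_HASH.match(revision):
        commit_hash = revision
    else:
        # A branch or tag: the `refs/<revision>` file contains the commit hash.
        ref_path = os.path.join(storage_folder, "refs", revision)
        with open(ref_path) as f:
            commit_hash = f.read()

    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
    if not os.path.exists(snapshot_folder):
        raise ValueError(f"No cached snapshot folder found for revision {revision!r}")
    return snapshot_folder
```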
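
The cache tree in the docs also shows the flattened folder names produced by `repo_folder_name`, tidied in the
last patch. A sketch of that serialization; the `--` separator is inferred from the
`models--julien-c--EsperBERTo-small` example and stands in for the library's `REPO_ID_SEPARATOR` constant.

```python
REPO_ID_SEPARATOR = "--"  # inferred from the examples above


def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
    """Serialize a repo type and ID into a single flat, non-nested folder name."""
    # remove all `/` occurrences to correctly convert the repo ID to a directory name
    parts = [f"{repo_type}s", *repo_id.split("/")]
    return REPO_ID_SEPARATOR.join(parts)


assert (
    repo_folder_name(repo_id="julien-c/EsperBERTo-small", repo_type="model")
    == "models--julien-c--EsperBERTo-small"
)
```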
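
Finally, the `legacy_cache_layout` flag threaded through PATCHes 41 to 44 gives callers an explicit,
warning-free opt-in to the old flat layout. A hypothetical usage sketch: the repo ID and filename are
placeholders rather than a real repository, and only keyword arguments that appear in the patches above are
used.

```python
from huggingface_hub import cached_download, hf_hub_download, hf_hub_url
from huggingface_hub.file_download import filename_to_url

REPO_ID = "some-user/some-repo"  # placeholder repo ID
FILENAME = "config.json"         # placeholder filename

# New cache layout (the default): the returned path lives under
# <cache_dir>/models--some-user--some-repo/snapshots/<commit_hash>/.
path = hf_hub_download(REPO_ID, filename=FILENAME)

# Old flat layout, opted into explicitly: no FutureWarning is raised.
url = hf_hub_url(REPO_ID, filename=FILENAME)
legacy_path = cached_download(url, legacy_cache_layout=True)
url_and_etag = filename_to_url(legacy_path, legacy_cache_layout=True)
```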