Skip to content

Commit

Permalink
Fix relative symlinks in cache (#1390)
Browse files Browse the repository at this point in the history
* fix relative symlinks in cache

* fix test

* fix test on windows
  • Loading branch information
Wauplin committed Mar 13, 2023
1 parent fbc4ccf commit 507703a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 47 deletions.
67 changes: 31 additions & 36 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def are_symlinks_supported(cache_dir: Union[str, Path, None] = None) -> bool:
src_path.touch()
dst_path = Path(tmpdir) / "dummy_file_dst"

# Relative source path as in `_create_relative_symlink``
# Relative source path as in `_create_symlink``
relative_src = os.path.relpath(src_path, start=os.path.dirname(dst_path))
try:
os.symlink(relative_src, dst_path)
Expand Down Expand Up @@ -558,7 +558,7 @@ def cached_download(
token: Union[bool, str, None] = None,
local_files_only: bool = False,
legacy_cache_layout: bool = False,
) -> Optional[str]: # pragma: no cover
) -> str:
"""
Download from a given URL and cache it if it's not already present in the
local cache.
Expand Down Expand Up @@ -819,61 +819,56 @@ def _normalize_etag(etag: Optional[str]) -> Optional[str]:
return etag.strip('"')


def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None:
"""Create a symbolic link named dst pointing to src as a relative path to dst.
The relative part is mostly because it seems more elegant to the author.
def _create_symlink(src: str, dst: str, new_blob: bool = False) -> None:
"""Create a symbolic link named dst pointing to src as an absolute path.
The result layout looks something like
└── [ 128] snapshots
├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
│ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
│ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
│ ├── [ 52] README.md -> /path/to/cache/blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
│ └── [ 76] pytorch_model.bin -> /path/to/cache/blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
If symlinks cannot be created on this platform (most likely to be Windows), the
workaround is to avoid symlinks by having the actual file in `dst`. If it is a new
file (`new_blob=True`), we move it to `dst`. If it is not a new file
(`new_blob=False`), we don't know if the blob file is already referenced elsewhere.
To avoid breaking existing cache, the file is duplicated on the disk.
If symlinks cannot be created on this platform (most likely to be Windows), the workaround is to avoid symlinks by
having the actual file in `dst`. If it is a new file (`new_blob=True`), we move it to `dst`. If it is not a new file
(`new_blob=False`), we don't know if the blob file is already referenced elsewhere. To avoid breaking existing
cache, the file is duplicated on the disk.
In case symlinks are not supported, a warning message is displayed to the user once
when loading `huggingface_hub`. The warning message can be disable with the
`DISABLE_SYMLINKS_WARNING` environment variable.
In case symlinks are not supported, a warning message is displayed to the user once when loading `huggingface_hub`.
The warning message can be disable with the `DISABLE_SYMLINKS_WARNING` environment variable.
"""
try:
os.remove(dst)
except OSError:
pass

abs_src = os.path.abspath(os.path.expanduser(src))
abs_dst = os.path.abspath(os.path.expanduser(dst))

try:
_support_symlinks = are_symlinks_supported(
os.path.dirname(os.path.commonpath([os.path.realpath(src), os.path.realpath(dst)]))
)
_support_symlinks = are_symlinks_supported(os.path.dirname(os.path.commonpath([abs_src, abs_dst])))
except PermissionError:
# Permission error means src and dst are not in the same volume (e.g. destination path has been provided
# by the user via `local_dir`. Let's test symlink support there)
_support_symlinks = are_symlinks_supported(os.path.dirname(dst))
_support_symlinks = are_symlinks_supported(os.path.dirname(abs_dst))

if _support_symlinks:
logger.info(f"Creating pointer from {src} to {dst}")
logger.info(f"Creating pointer from {abs_src} to {abs_dst}")
try:
os.symlink(src, dst)
os.symlink(abs_src, abs_dst)
except FileExistsError:
if os.path.islink(dst) and os.path.realpath(dst) == os.path.realpath(src):
# `dst` already exists and is a symlink to the `src` blob. It is most
# likely that the file has been cached twice concurrently (exactly
# between `os.remove` and `os.symlink`). Do nothing.
if os.path.islink(abs_dst) and os.path.realpath(abs_dst) == os.path.realpath(abs_src):
# `abs_dst` already exists and is a symlink to the `abs_src` blob. It is most likely that the file has
# been cached twice concurrently (exactly between `os.remove` and `os.symlink`). Do nothing.
pass
else:
# Very unlikely to happen. Means a file `dst` has been created exactly
# between `os.remove` and `os.symlink` and is not a symlink to the `src`
# blob file. Raise exception.
# Very unlikely to happen. Means a file `dst` has been created exactly between `os.remove` and
# `os.symlink` and is not a symlink to the `abs_src` blob file. Raise exception.
raise
elif new_blob:
logger.info(f"Symlink not supported. Moving file from {src} to {dst}")
logger.info(f"Symlink not supported. Moving file from {abs_src} to {abs_dst}")
os.replace(src, dst)
else:
logger.info(f"Symlink not supported. Copying file from {src} to {dst}")
logger.info(f"Symlink not supported. Copying file from {abs_src} to {abs_dst}")
shutil.copyfile(src, dst)


Expand Down Expand Up @@ -926,7 +921,7 @@ def hf_hub_download(
token: Union[bool, str, None] = None,
local_files_only: bool = False,
legacy_cache_layout: bool = False,
):
) -> str:
"""Download a given file if it's not already present in the local cache.
The new cache file layout looks like this:
Expand Down Expand Up @@ -1258,7 +1253,7 @@ def hf_hub_download(
if local_dir is not None: # to local dir
return _to_local_dir(blob_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
else: # or in snapshot cache
_create_relative_symlink(blob_path, pointer_path, new_blob=False)
_create_symlink(blob_path, pointer_path, new_blob=False)
return pointer_path

# Prevent parallel downloads of the same file with a lock.
Expand Down Expand Up @@ -1313,7 +1308,7 @@ def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
if local_dir is None:
logger.info(f"Storing {url} in cache at {blob_path}")
_chmod_and_replace(temp_file.name, blob_path)
_create_relative_symlink(blob_path, pointer_path, new_blob=True)
_create_symlink(blob_path, pointer_path, new_blob=True)
else:
local_dir_filepath = os.path.join(local_dir, relative_filename)
os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
Expand All @@ -1325,7 +1320,7 @@ def _resumable_file_manager() -> Generator[io.BufferedWriter, None, None]:
logger.info(f"Storing {url} in cache at {blob_path}")
_chmod_and_replace(temp_file.name, blob_path)
logger.info("Create symlink to local dir")
_create_relative_symlink(blob_path, local_dir_filepath, new_blob=False)
_create_symlink(blob_path, local_dir_filepath, new_blob=False)
elif local_dir_use_symlinks == "auto" and not is_big_file:
logger.info(f"Storing {url} in cache at {blob_path}")
_chmod_and_replace(temp_file.name, blob_path)
Expand Down Expand Up @@ -1544,7 +1539,7 @@ def _to_local_dir(
use_symlinks = os.stat(real_blob_path).st_size > constants.HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD

if use_symlinks:
_create_relative_symlink(real_blob_path, local_dir_filepath, new_blob=False)
_create_symlink(real_blob_path, local_dir_filepath, new_blob=False)
else:
shutil.copyfile(real_blob_path, local_dir_filepath)
return local_dir_filepath
14 changes: 8 additions & 6 deletions src/huggingface_hub/repocard.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,19 @@ def load(
if Path(repo_id_or_path).exists():
card_path = Path(repo_id_or_path)
elif isinstance(repo_id_or_path, str):
card_path = hf_hub_download(
repo_id_or_path,
REPOCARD_NAME,
repo_type=repo_type or cls.repo_type,
token=token,
card_path = Path(
hf_hub_download(
repo_id_or_path,
REPOCARD_NAME,
repo_type=repo_type or cls.repo_type,
token=token,
)
)
else:
raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).")

# Preserve newlines in the existing file.
with Path(card_path).open(mode="r", newline="", encoding="utf-8") as f:
with card_path.open(mode="r", newline="", encoding="utf-8") as f:
return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors)

def validate(self, repo_type: Optional[str] = None):
Expand Down
29 changes: 24 additions & 5 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
import re
import shutil
import stat
import unittest
from pathlib import Path
Expand All @@ -30,7 +31,7 @@
)
from huggingface_hub.file_download import (
_CACHED_NO_EXIST,
_create_relative_symlink,
_create_symlink,
cached_download,
filename_to_url,
get_hf_file_metadata,
Expand Down Expand Up @@ -739,15 +740,15 @@ def test_hf_hub_download_on_awful_subfolder_and_filename(self):
class CreateSymlinkTest(unittest.TestCase):
@unittest.skipIf(os.name == "nt", "No symlinks on Windows")
@patch("huggingface_hub.file_download.are_symlinks_supported")
def test_create_relative_symlink_concurrent_access(self, mock_are_symlinks_supported: Mock) -> None:
def test_create_symlink_concurrent_access(self, mock_are_symlinks_supported: Mock) -> None:
with SoftTemporaryDirectory() as tmpdir:
src = os.path.join(tmpdir, "source")
other = os.path.join(tmpdir, "other")
dst = os.path.join(tmpdir, "destination")

# Normal case: symlink does not exist
mock_are_symlinks_supported.return_value = True
_create_relative_symlink(src, dst)
_create_symlink(src, dst)
self.assertEqual(os.path.realpath(dst), os.path.realpath(src))

# Symlink already exists when it tries to create it (most probably from a
Expand All @@ -757,7 +758,7 @@ def _are_symlinks_supported(cache_dir: str) -> bool:
return True

mock_are_symlinks_supported.side_effect = _are_symlinks_supported
_create_relative_symlink(src, dst)
_create_symlink(src, dst)

# Symlink already exists but pointing to a different source file. This should
# never happen in the context of HF cache system -> raise exception
Expand All @@ -767,7 +768,25 @@ def _are_symlinks_supported(cache_dir: str) -> bool:

mock_are_symlinks_supported.side_effect = _are_symlinks_supported
with self.assertRaises(FileExistsError):
_create_relative_symlink(src, dst)
_create_symlink(src, dst)

def test_create_symlink_relative_src(self) -> None:
"""Regression test for #1388.
See https://github.com/huggingface/huggingface_hub/issues/1388.
"""
# Test dir has to be relative
test_dir = Path(".") / "dir_for_create_symlink_test"
test_dir.mkdir(parents=True, exist_ok=True)
src = Path(test_dir) / "source"
src.touch()
dst = Path(test_dir) / "destination"

_create_symlink(str(src), str(dst))
self.assertTrue(dst.resolve().is_file())
if os.name != "nt":
self.assertEqual(dst.resolve(), src.resolve())
shutil.rmtree(test_dir)


def _recursive_chmod(path: str, mode: int) -> None:
Expand Down

0 comments on commit 507703a

Please sign in to comment.