Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix snapshot download when local_dir is provided. #2592

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@
from .errors import GatedRepoError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
from .utils import (
OfflineModeIsEnabled,
filter_repo_objects,
logging,
validate_hf_hub_args,
)
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
from .utils import tqdm as hf_tqdm


Expand Down Expand Up @@ -191,6 +186,7 @@ def snapshot_download(
# => let's look if we can find the appropriate folder in the cache:
# - if the specified revision is a commit hash, look inside "snapshots".
# - if the specified revision is a branch or tag, look inside "refs".
# => if local_dir is not None, we will return the path to the local folder if it exists.
if repo_info is None:
# Try to get which commit hash corresponds to the specified revision
commit_hash = None
Expand All @@ -210,7 +206,12 @@ def snapshot_download(
# Snapshot folder exists => let's return it
# (but we can't check if all the files are actually there)
return snapshot_folder

# If local_dir is not None, return it if it exists and is not empty
if local_dir is not None:
local_dir = Path(local_dir)
if local_dir.is_dir() and any(local_dir.iterdir()):
logger.warning(f"Returning existing local_dir `{local_dir}`as it exists and is not empty.")
hanouticelina marked this conversation as resolved.
Show resolved Hide resolved
return str(local_dir.resolve())
# If we couldn't find the appropriate folder on disk, raise an error.
if local_files_only:
raise LocalEntryNotFoundError(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,37 @@ def test_download_model_local_only(self):
)
self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision

# Test with local_dir
with SoftTemporaryDirectory() as tmpdir:
# first download folder to local_dir
snapshot_download(self.repo_id, local_dir=tmpdir)
# now load from local_dir
storage_folder = snapshot_download(self.repo_id, local_dir=tmpdir, local_files_only=True)
self.assertTrue(str(tmpdir) in storage_folder) # has expected revision
hanouticelina marked this conversation as resolved.
Show resolved Hide resolved

def test_download_model_offline_mode_not_in_local(self):
    """Check that a populated local_dir is returned for every simulated connection failure."""
    with SoftTemporaryDirectory() as tmpdir:
        # Populate local_dir with a regular (online) download first.
        snapshot_download(self.repo_id, local_dir=tmpdir)
        expected_fragment = str(tmpdir)

        # Each offline simulation mode must fall back to the existing folder.
        for mode in OfflineSimulationMode:
            with offline(mode=mode):
                returned_path = snapshot_download(self.repo_id, local_dir=tmpdir)
                self.assertTrue(expected_fragment in returned_path)

def test_download_model_offline_mode_not_in_local_dir(self):
    """Check that an empty local_dir raises LocalEntryNotFoundError when nothing can be downloaded."""
    # Case 1: network access is explicitly disabled through `local_files_only`.
    with SoftTemporaryDirectory() as empty_dir:
        with self.assertRaises(LocalEntryNotFoundError):
            snapshot_download(self.repo_id, local_dir=empty_dir, local_files_only=True)

    # Case 2: every simulated connection failure, each with a fresh empty folder.
    for mode in OfflineSimulationMode:
        with offline(mode=mode), SoftTemporaryDirectory() as empty_dir:
            with self.assertRaises(LocalEntryNotFoundError):
                snapshot_download(self.repo_id, local_dir=empty_dir)

def test_download_model_offline_mode_not_cached(self):
"""Test when connection error but cache is empty."""
with SoftTemporaryDirectory() as tmpdir:
Expand Down
Loading