Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented CLI download functionality #1617

Merged
merged 16 commits into from
Sep 6, 2023
Merged
157 changes: 157 additions & 0 deletions src/huggingface_hub/commands/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains command to download files from the Hub with the CLI.

Usage:
huggingface-cli download --help

# Download file
huggingface-cli download gpt2 config.json

# Download full space quietly
huggingface-cli download jbilcke-hf/comic-factory --repo-type=space --quiet

# Download with filter
huggingface-cli download gpt2 --allow-patterns="*.safetensors"

# Download from revision
huggingface-cli download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78

# Download with token
huggingface-cli download Wauplin/private-model --token=hf_***

TODO: add --to-local-dir (as `store_true` or as str path?)
"""
import warnings
from argparse import Namespace, _SubParsersAction
from typing import List, Optional

from huggingface_hub._snapshot_download import snapshot_download
from huggingface_hub.commands import BaseHuggingfaceCLICommand
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars


class DownloadCommand(BaseHuggingfaceCLICommand):
@staticmethod
def register_subcommand(parser: _SubParsersAction):
download_parser = parser.add_parser("download", help="Download files from the Hub")
download_parser.add_argument(
"repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)."
)
download_parser.add_argument(
"filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)."
)
download_parser.add_argument(
"--repo-type",
choices=["model", "dataset", "space"],
default="model",
help="Type of repo to download from (e.g. `dataset`).",
)
download_parser.add_argument(
"--revision",
type=str,
help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
)
download_parser.add_argument(
"--include", nargs="*", type=str, help="Glob patterns to match files to download."
)
download_parser.add_argument(
"--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download."
)
download_parser.add_argument(
"--force-download",
action="store_true",
help="If True, the files will be downloaded even if they are already cached.",
)
download_parser.add_argument(
"--cache-dir", type=str, help="Path to the directory where to save the downloaded files."
)
download_parser.add_argument(
"--resume-download", action="store_true", help="If True, resume a previously interrupted download."
)
download_parser.add_argument(
"--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
)
download_parser.add_argument(
"--quiet",
action="store_true",
help="If True, progress bars are disabled and only the path to the download files is printed.",
)
download_parser.set_defaults(func=DownloadCommand)

def __init__(self, args: Namespace) -> None:
self.token = args.token
self.repo_id: str = args.repo_id
self.filenames: List[str] = args.filenames
self.repo_type: str = args.repo_type
self.revision: Optional[str] = args.revision
self.include: Optional[List[str]] = args.include
self.exclude: Optional[List[str]] = args.exclude
self.force_download: bool = args.force_download
self.resume_download: bool = args.resume_download
self.cache_dir: Optional[str] = args.cache_dir
self.quiet: bool = args.quiet

def run(self) -> None:
if self.quiet:
disable_progress_bars()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
print(self._download()) # Print path to downloaded files
enable_progress_bars()
else:
print(self._download()) # Print path to downloaded files

def _download(self) -> str:
# Warns user if patterns are ignored
if len(self.filenames) > 0:
if self.include is not None and len(self.include) > 0:
warnings.warn("Ignoring `--include` since filenames have being explicitly set.")
if self.exclude is not None and len(self.exclude) > 0:
warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.")

# Single file to download: use `hf_hub_download`
if len(self.filenames) == 1:
return hf_hub_download(
repo_id=self.repo_id,
repo_type=self.repo_type,
revision=self.revision,
filename=self.filenames[0],
cache_dir=self.cache_dir,
resume_download=self.resume_download,
force_download=self.force_download,
token=self.token,
)

# Otherwise: use `snapshot_download` to ensure all files comes from same revision
elif len(self.filenames) == 0:
allow_patterns = self.include
ignore_patterns = self.exclude
else:
allow_patterns = self.filenames
ignore_patterns = None

return snapshot_download(
repo_id=self.repo_id,
repo_type=self.repo_type,
revision=self.revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
resume_download=self.resume_download,
force_download=self.force_download,
cache_dir=self.cache_dir,
token=self.token,
)
2 changes: 2 additions & 0 deletions src/huggingface_hub/commands/huggingface_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from argparse import ArgumentParser

from huggingface_hub.commands.delete_cache import DeleteCacheCommand
from huggingface_hub.commands.download import DownloadCommand
from huggingface_hub.commands.env import EnvironmentCommand
from huggingface_hub.commands.lfs import LfsCommands
from huggingface_hub.commands.scan_cache import ScanCacheCommand
Expand All @@ -32,6 +33,7 @@ def main():
LfsCommands.register_subcommand(commands_parser)
ScanCacheCommand.register_subcommand(commands_parser)
DeleteCacheCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)

# Let's go
args = parser.parse_args()
Expand Down
210 changes: 205 additions & 5 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import unittest
from argparse import ArgumentParser
import warnings
from argparse import ArgumentParser, Namespace
from unittest.mock import Mock, patch

from huggingface_hub.commands.delete_cache import DeleteCacheCommand
from huggingface_hub.commands.download import DownloadCommand
from huggingface_hub.commands.scan_cache import ScanCacheCommand
from huggingface_hub.utils import capture_output

from .testing_utils import (
DUMMY_MODEL_ID,
)

class TestCLI(unittest.TestCase):

class TestCacheCommand(unittest.TestCase):
def setUp(self) -> None:
"""
Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`.

TODO: add other subcommands.
Set up scan-cache/delete-cache commands as in `src/huggingface_hub/commands/huggingface_cli.py`.
"""
self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
commands_parser = self.parser.add_subparsers()
Expand Down Expand Up @@ -50,3 +56,197 @@ def test_delete_cache_with_dir(self) -> None:
args = self.parser.parse_args(["delete-cache", "--dir", "something"])
self.assertEqual(args.dir, "something")
self.assertEqual(args.func, DeleteCacheCommand)


class TestDownloadCommand(unittest.TestCase):
def setUp(self) -> None:
"""
Set up CLI as in `src/huggingface_hub/commands/huggingface_cli.py`.
"""
self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
commands_parser = self.parser.add_subparsers()
DownloadCommand.register_subcommand(commands_parser)

def test_download_basic(self) -> None:
"""Test `huggingface-cli download dummy-repo`."""
args = self.parser.parse_args(["download", DUMMY_MODEL_ID])
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertEqual(len(args.filenames), 0)
self.assertEqual(args.repo_type, "model")
self.assertEqual(args.revision, None)
self.assertEqual(args.include, None)
self.assertEqual(args.exclude, None)
self.assertEqual(args.force_download, False)
self.assertEqual(args.cache_dir, None)
self.assertEqual(args.resume_download, False)
self.assertEqual(args.token, None)
self.assertEqual(args.quiet, False)
self.assertEqual(args.func, DownloadCommand)

def test_download_with_all_options(self) -> None:
"""Test `huggingface-cli download dummy-repo` with all options selected."""
args = self.parser.parse_args(
[
"download",
DUMMY_MODEL_ID,
"--repo-type",
"dataset",
"--revision",
"v1.0.0",
"--include",
"*.json",
"*.yaml",
"--exclude",
"*.log",
"*.txt",
"--force-download",
"--cache-dir",
"/tmp",
"--resume-download",
"--token",
"my-token",
"--quiet",
]
)
self.assertEqual(args.repo_id, DUMMY_MODEL_ID)
self.assertEqual(args.repo_type, "dataset")
self.assertEqual(args.revision, "v1.0.0")
self.assertEqual(args.include, ["*.json", "*.yaml"])
self.assertEqual(args.exclude, ["*.log", "*.txt"])
self.assertEqual(args.force_download, True)
self.assertEqual(args.cache_dir, "/tmp")
self.assertEqual(args.resume_download, True)
self.assertEqual(args.token, "my-token")
self.assertEqual(args.quiet, True)
self.assertEqual(args.func, DownloadCommand)

@patch("huggingface_hub.commands.download.hf_hub_download")
def test_download_file_from_revision(self, mock: Mock) -> None:
args = Namespace(
token="hf_****",
repo_id="author/dataset",
filenames=["README.md"],
repo_type="dataset",
revision="refs/pr/1",
include=None,
exclude=None,
force_download=False,
resume_download=False,
cache_dir=None,
quiet=False,
)

# Output path is printed to terminal once run is completed
with capture_output() as output:
DownloadCommand(args).run()
self.assertRegex(output.getvalue(), r"<MagicMock name='hf_hub_download\(\)' id='\d+'>")

mock.assert_called_once_with(
repo_id="author/dataset",
repo_type="dataset",
revision="refs/pr/1",
filename="README.md",
cache_dir=None,
resume_download=False,
force_download=False,
token="hf_****",
)

@patch("huggingface_hub.commands.download.snapshot_download")
def test_download_multiple_files(self, mock: Mock) -> None:
args = Namespace(
token="hf_****",
repo_id="author/model",
filenames=["README.md", "config.json"],
repo_type="model",
revision=None,
include=None,
exclude=None,
force_download=True,
resume_download=True,
cache_dir=None,
quiet=False,
)
DownloadCommand(args).run()

# Use `snapshot_download` to ensure all files comes from same revision
mock.assert_called_once_with(
repo_id="author/model",
repo_type="model",
revision=None,
allow_patterns=["README.md", "config.json"],
ignore_patterns=None,
resume_download=True,
force_download=True,
cache_dir=None,
token="hf_****",
)

@patch("huggingface_hub.commands.download.snapshot_download")
def test_download_with_patterns(self, mock: Mock) -> None:
args = Namespace(
token=None,
repo_id="author/model",
filenames=[],
repo_type="model",
revision=None,
include=["*.json"],
exclude=["data/*"],
force_download=True,
resume_download=True,
cache_dir=None,
quiet=False,
)
DownloadCommand(args).run()

# Use `snapshot_download` to ensure all files comes from same revision
mock.assert_called_once_with(
repo_id="author/model",
repo_type="model",
revision=None,
allow_patterns=["*.json"],
ignore_patterns=["data/*"],
resume_download=True,
force_download=True,
cache_dir=None,
token=None,
)

@patch("huggingface_hub.commands.download.snapshot_download")
def test_download_with_ignored_patterns(self, mock: Mock) -> None:
args = Namespace(
token=None,
repo_id="author/model",
filenames=["README.md", "config.json"],
repo_type="model",
revision=None,
include=["*.json"],
exclude=["data/*"],
force_download=True,
resume_download=True,
cache_dir=None,
quiet=False,
)

with self.assertWarns(UserWarning):
# warns that patterns are ignored
DownloadCommand(args).run()

mock.assert_called_once_with(
repo_id="author/model",
repo_type="model",
revision=None,
allow_patterns=["README.md", "config.json"], # `filenames` has priority over the patterns
ignore_patterns=None, # cleaned up
resume_download=True,
force_download=True,
cache_dir=None,
token=None,
)

# Same but quiet (no warnings)
args.quiet = True
with warnings.catch_warnings():
# Taken from https://docs.pytest.org/en/latest/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests
warnings.simplefilter("error")
DownloadCommand(args).run()