From f5a35a3c4096452f32d75ed3c78aabbc8b2905b5 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 10 Aug 2021 13:10:49 +0200 Subject: [PATCH] Progress bars (#261) * Progress bars * Remove print statement * add log progress to git_pull * Windows compatibility --- src/huggingface_hub/repository.py | 159 ++++++++++++++++++++++++------ 1 file changed, 131 insertions(+), 28 deletions(-) diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index 1b40aaa36b..e4c77f3f5b 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -2,9 +2,14 @@ import os import re import subprocess +import tempfile +import threading +import time from contextlib import contextmanager from pathlib import Path -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Union + +from tqdm.auto import tqdm from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES @@ -107,6 +112,98 @@ def is_git_ignored(filename: Union[str, Path]) -> bool: return is_ignored +@contextmanager +def lfs_log_progress(): + """ + This is a context manager that will log the Git LFS progress of cleaning, smudging, pulling and pushing. + """ + + def output_progress(stopping_event: threading.Event): + """ + To be launched as a separate thread with an event meaning it should stop the tail. + """ + pbars = {} + + def close_pbars(): + for pbar in pbars.values(): + pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"]) + pbar["bar"].refresh() + pbar["bar"].close() + + def tail_file(filename) -> Iterator[str]: + """ + Creates a generator to be iterated through, which will return each line one by one. + Will stop tailing the file if the stopping_event is set. + """ + with open(filename, "r") as file: + current_line = "" + while True: + if stopping_event.is_set(): + close_pbars() + break + + line_bit = file.readline() + if line_bit is not None and not len(line_bit.strip()) == 0: + current_line += line_bit + if current_line.endswith("\n"): + yield current_line + current_line = "" + else: + time.sleep(1) + + # If the file isn't created yet, wait for a few seconds before trying again. + # Can be interrupted with the stopping_event. + while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]): + if stopping_event.is_set(): + close_pbars() + break + + time.sleep(2) + + for line in tail_file(os.environ["GIT_LFS_PROGRESS"]): + state, file_progress, byte_progress, filename = line.split() + description = f"{state.capitalize()} file {filename}" + + current_bytes, total_bytes = byte_progress.split("/") + + current_bytes = int(current_bytes) + total_bytes = int(total_bytes) + + if pbars.get((state, filename)) is None: + pbars[(state, filename)] = { + "bar": tqdm( + desc=description, + initial=current_bytes, + total=total_bytes, + unit="B", + unit_scale=True, + unit_divisor=1024, + ), + "past_bytes": current_bytes, + } + else: + past_bytes = pbars[(state, filename)]["past_bytes"] + pbars[(state, filename)]["bar"].update(current_bytes - past_bytes) + pbars[(state, filename)]["past_bytes"] = current_bytes + + current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "") + + with tempfile.TemporaryDirectory() as tmpdir: + os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress") + + exit_event = threading.Event() + x = threading.Thread(target=output_progress, args=(exit_event,), daemon=True) + x.start() + + try: + yield + finally: + exit_event.set() + x.join() + + os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value + + class Repository: """ Helper class to wrap the git and git-lfs commands. @@ -276,14 +373,15 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non # checks if repository is initialized in a empty repository or in one with files if len(os.listdir(self.local_dir)) == 0: logger.warning(f"Cloning {clean_repo_url} into local empty directory.") - subprocess.run( - ["git", "clone", repo_url, "."], - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - check=True, - encoding="utf-8", - cwd=self.local_dir, - ) + with lfs_log_progress(): + subprocess.run( + f"git lfs clone {repo_url} .".split(), + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + encoding="utf-8", + cwd=self.local_dir, + ) else: # Check if the folder is the root of a git repository in_repository = is_git_repo(self.local_dir) @@ -510,7 +608,7 @@ def lfs_enable_largefiles(self): except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) - def auto_track_large_files(self, pattern="."): + def auto_track_large_files(self, pattern=".") -> List[str]: """ Automatically track large files with git-lfs """ @@ -546,6 +644,8 @@ def auto_track_large_files(self, pattern="."): # Cleanup the .gitattributes if files were deleted self.lfs_untrack(deleted_files) + return files_to_be_staged + def git_pull(self, rebase: Optional[bool] = False): """ git pull @@ -554,14 +654,13 @@ def git_pull(self, rebase: Optional[bool] = False): if rebase: args.append("--rebase") try: - subprocess.run( - args, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - check=True, - encoding="utf-8", - cwd=self.local_dir, - ) + with lfs_log_progress(): + subprocess.run( + args, + check=True, + encoding="utf-8", + cwd=self.local_dir, + ) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -573,7 +672,11 @@ def git_add(self, pattern=".", auto_lfs_track=False): than 10MB with `git-lfs`. """ if auto_lfs_track: - self.auto_track_large_files(pattern) + tracked_files = self.auto_track_large_files(pattern) + if len(tracked_files) > 0: + logger.warning( + "Adding files tracked by Git LFS. This may take a bit of time if the files are large." + ) try: subprocess.run( @@ -613,15 +716,15 @@ def git_push(self) -> str: Returns url to commit on remote repo. """ try: - result = subprocess.run( - "git push".split(), - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - check=True, - encoding="utf-8", - cwd=self.local_dir, - ) - logger.info(result.stdout) + with lfs_log_progress(): + subprocess.run( + "git push".split(), + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=True, + encoding="utf-8", + cwd=self.local_dir, + ) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr)