Skip to content

Commit

Permalink
Progress bars (#261)
Browse files Browse the repository at this point in the history
* Progress bars

* Remove print statement

* add log progress to git_pull

* Windows compatibility
  • Loading branch information
LysandreJik authored Aug 10, 2021
1 parent 4a003e9 commit f5a35a3
Showing 1 changed file with 131 additions and 28 deletions.
159 changes: 131 additions & 28 deletions src/huggingface_hub/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import os
import re
import subprocess
import tempfile
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Union

from tqdm.auto import tqdm

from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES

Expand Down Expand Up @@ -107,6 +112,98 @@ def is_git_ignored(filename: Union[str, Path]) -> bool:
return is_ignored


@contextmanager
def lfs_log_progress():
    """
    Context manager that displays Git LFS transfer progress as tqdm progress bars.

    While the context is active, the ``GIT_LFS_PROGRESS`` environment variable points to a
    file inside a temporary directory, and a background thread tails that file, keeping one
    progress bar per ``(state, filename)`` pair.

    On exit (normal or through an exception) the thread is stopped and joined, every bar is
    filled to its total and closed, and ``GIT_LFS_PROGRESS`` is put back to its previous
    value, or removed if it was not set before entering the context.
    """

    def output_progress(stopping_event: threading.Event, progress_path: str):
        """
        Thread target: tail ``progress_path`` and mirror its content into progress bars
        until ``stopping_event`` is set.
        """
        # Maps (state, filename) -> {"bar": tqdm instance, "past_bytes": last byte count seen}.
        pbars = {}

        def close_pbars():
            # Push every bar to its total before closing so none is left partially drawn.
            for pbar in pbars.values():
                pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"])
                pbar["bar"].refresh()
                pbar["bar"].close()

        def tail_file(filename) -> Iterator[str]:
            """
            Yield each complete (newline-terminated) line appended to ``filename``.
            Stops tailing, after closing the bars, once ``stopping_event`` is set.
            """
            with open(filename, "r") as file:
                current_line = ""
                while True:
                    if stopping_event.is_set():
                        close_pbars()
                        break

                    line_bit = file.readline()
                    if line_bit.strip():
                        # A read may return only part of a line; accumulate until the newline.
                        current_line += line_bit
                        if current_line.endswith("\n"):
                            yield current_line
                            current_line = ""
                    else:
                        # Nothing new yet: poll again in a second, but wake up immediately
                        # if we are asked to stop.
                        stopping_event.wait(1)

        # If the file isn't created yet, wait for a few seconds before trying again.
        # Can be interrupted with the stopping_event.
        while not os.path.exists(progress_path):
            if stopping_event.is_set():
                close_pbars()
                # Return rather than fall through: there is no file to tail.
                return

            stopping_event.wait(2)

        for line in tail_file(progress_path):
            # Four whitespace-separated fields are expected; capping the split keeps a
            # filename that itself contains spaces in one piece.
            fields = line.strip().split(maxsplit=3)
            if len(fields) != 4:
                # Progress display is best-effort: ignore lines we cannot interpret.
                continue

            state, _file_progress, byte_progress, filename = fields
            try:
                current_bytes, total_bytes = (int(part) for part in byte_progress.split("/"))
            except ValueError:
                continue

            key = (state, filename)
            entry = pbars.get(key)
            if entry is None:
                pbars[key] = {
                    "bar": tqdm(
                        desc=f"{state.capitalize()} file {filename}",
                        initial=current_bytes,
                        total=total_bytes,
                        unit="B",
                        unit_scale=True,
                        unit_divisor=1024,
                    ),
                    "past_bytes": current_bytes,
                }
            else:
                # tqdm.update takes an increment, so feed it the delta since the last line.
                entry["bar"].update(current_bytes - entry["past_bytes"])
                entry["past_bytes"] = current_bytes

    # None means "was not set", which is different from "was set to an empty string".
    previous_value = os.environ.get("GIT_LFS_PROGRESS")

    with tempfile.TemporaryDirectory() as tmpdir:
        progress_path = os.path.join(tmpdir, "lfs_progress")
        os.environ["GIT_LFS_PROGRESS"] = progress_path

        exit_event = threading.Event()
        # The path is handed to the thread explicitly so that it never depends on the
        # environment variable still being set when the thread gets to run.
        worker = threading.Thread(
            target=output_progress, args=(exit_event, progress_path), daemon=True
        )
        worker.start()

        try:
            yield
        finally:
            exit_event.set()
            # Join before the temporary directory is cleaned up so the file is closed first.
            worker.join()

            if previous_value is None:
                os.environ.pop("GIT_LFS_PROGRESS", None)
            else:
                os.environ["GIT_LFS_PROGRESS"] = previous_value


class Repository:
"""
Helper class to wrap the git and git-lfs commands.
Expand Down Expand Up @@ -276,14 +373,15 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
# checks if repository is initialized in an empty repository or in one with files
if len(os.listdir(self.local_dir)) == 0:
logger.warning(f"Cloning {clean_repo_url} into local empty directory.")
subprocess.run(
["git", "clone", repo_url, "."],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
with lfs_log_progress():
subprocess.run(
f"git lfs clone {repo_url} .".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
else:
# Check if the folder is the root of a git repository
in_repository = is_git_repo(self.local_dir)
Expand Down Expand Up @@ -510,7 +608,7 @@ def lfs_enable_largefiles(self):
except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)

def auto_track_large_files(self, pattern="."):
def auto_track_large_files(self, pattern=".") -> List[str]:
"""
Automatically track large files with git-lfs
"""
Expand Down Expand Up @@ -546,6 +644,8 @@ def auto_track_large_files(self, pattern="."):
# Cleanup the .gitattributes if files were deleted
self.lfs_untrack(deleted_files)

return files_to_be_staged

def git_pull(self, rebase: Optional[bool] = False):
"""
git pull
Expand All @@ -554,14 +654,13 @@ def git_pull(self, rebase: Optional[bool] = False):
if rebase:
args.append("--rebase")
try:
subprocess.run(
args,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
with lfs_log_progress():
subprocess.run(
args,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)

Expand All @@ -573,7 +672,11 @@ def git_add(self, pattern=".", auto_lfs_track=False):
than 10MB with `git-lfs`.
"""
if auto_lfs_track:
self.auto_track_large_files(pattern)
tracked_files = self.auto_track_large_files(pattern)
if len(tracked_files) > 0:
logger.warning(
"Adding files tracked by Git LFS. This may take a bit of time if the files are large."
)

try:
subprocess.run(
Expand Down Expand Up @@ -613,15 +716,15 @@ def git_push(self) -> str:
Returns url to commit on remote repo.
"""
try:
result = subprocess.run(
"git push".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
logger.info(result.stdout)
with lfs_log_progress():
subprocess.run(
"git push".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=self.local_dir,
)
except subprocess.CalledProcessError as exc:
raise EnvironmentError(exc.stderr)

Expand Down

0 comments on commit f5a35a3

Please sign in to comment.