Skip to content

Commit

Permalink
automatic push in HFSummaryWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed Jun 7, 2023
1 parent c578b18 commit e5477c0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 54 deletions.
11 changes: 10 additions & 1 deletion src/huggingface_hub/_commit_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import time
from concurrent.futures import Future
from dataclasses import dataclass
from io import SEEK_END, SEEK_SET, BytesIO
from pathlib import Path
Expand Down Expand Up @@ -143,11 +144,19 @@ def stop(self) -> None:
def _run_scheduler(self) -> None:
"""Dumb thread waiting between each scheduled push to Hub."""
while True:
self.last_future = self.api.run_as_future(self._push_to_hub)
self.last_future = self.trigger()
time.sleep(self.every * 60)
if self.__stopped:
break

def trigger(self) -> Future[Optional[CommitInfo]]:
"""Trigger a `push_to_hub` and return a future.
This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
immediately, without waiting for the next scheduled commit.
"""
return self.api.run_as_future(self._push_to_hub)

def _push_to_hub(self) -> Optional[CommitInfo]:
if self.__stopped: # If stopped, already scheduled commits are ignored
return None
Expand Down
84 changes: 31 additions & 53 deletions src/huggingface_hub/_tensorboard_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a logger to push training logs to the Hub, using Tensorboard."""
import os
import warnings
from concurrent.futures import Future
from typing import TYPE_CHECKING, List, Optional, Union

from .hf_api import create_repo, upload_folder
from huggingface_hub._commit_scheduler import CommitScheduler

from .utils import experimental, is_tensorboard_available


Expand All @@ -38,7 +36,8 @@ class HFSummaryWriter(SummaryWriter):
Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
issue), the main script will not be interrupted.
issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
minutes (default to every 5 minutes).
<Tip warning={true}>
Expand All @@ -52,6 +51,8 @@ class HFSummaryWriter(SummaryWriter):
logdir (`str`, *optional*):
The directory where the logs will be written. If not specified, a local directory will be created by the
underlying `SummaryWriter` object.
commit_every (`int` or `float`, *optional*):
The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
repo_type (`str`, *optional*):
The type of the repo to which the logs will be pushed. Defaults to "model".
repo_revision (`str`, *optional*):
Expand All @@ -77,15 +78,20 @@ class HFSummaryWriter(SummaryWriter):
```py
>>> from huggingface_hub import HFSummaryWriter
>>> logger = HFSummaryWriter(repo_id="test_hf_logger")
# Logs are automatically pushed every 15 minutes
>>> logger = HFSummaryWriter(repo_id="test_hf_logger", commit_every=15)
>>> logger.add_scalar("a", 1)
>>> logger.add_scalar("b", 2)
>>> logger.push_to_hub()
...
# You can also trigger a push manually
>>> logger.scheduler.trigger()
```
```py
>>> from huggingface_hub import HFSummaryWriter
# Logs are automatically pushed every 5 minutes (default) + when exiting the context manager
>>> with HFSummaryWriter(repo_id="test_hf_logger") as logger:
... logger.add_scalar("a", 1)
... logger.add_scalar("b", 2)
Expand All @@ -106,6 +112,7 @@ def __init__(
repo_id: str,
*,
logdir: Optional[str] = None,
commit_every: Union[int, float] = 5,
repo_type: Optional[str] = None,
repo_revision: Optional[str] = None,
repo_private: bool = False,
Expand All @@ -118,55 +125,26 @@ def __init__(
# Initialize SummaryWriter
super().__init__(logdir=logdir, **kwargs)

# Create repo if doesn't exist
repo_url = create_repo(repo_id=repo_id, repo_type=repo_type, token=token, exist_ok=True, private=repo_private)
self.repo_id = repo_url.repo_id
print(f"Logs will be pushed to {repo_url}")
# Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it.
if not isinstance(self.logdir, str):
raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")

# Set Hub-related attributes
self.repo_type = repo_type
self.repo_revision = repo_revision
self.path_in_repo = path_in_repo
self.token = token
self.repo_allow_patterns = repo_allow_patterns
self.repo_ignore_patterns = repo_ignore_patterns
# Initialize scheduler
self.scheduler = CommitScheduler(
folder_path=self.logdir,
path_in_repo=path_in_repo,
repo_id=repo_id,
repo_type=repo_type,
revision=repo_revision,
private=repo_private,
token=token,
allow_patterns=repo_allow_patterns,
ignore_patterns=repo_ignore_patterns,
every=commit_every,
)

def __exit__(self, exc_type, exc_val, exc_tb):
"""Push to hub in a non-blocking way when exiting the logger's context manager."""
super().__exit__(exc_type, exc_val, exc_tb)
future = self.push_to_hub(commit_message="Closing HFSummaryWriter.")
future = self.scheduler.trigger()
future.result()

def push_to_hub(
self, commit_message: Optional[str] = None, commit_description: Optional[str] = None
) -> Optional[Future[str]]:
"""
Push the logs to the Hub asynchronously.
Args:
commit_message (`str`, *optional*):
The summary / title / first line of the pushed commit. Defaults to "Upload training logs using HFSummaryWriter.".
commit_description (`str`, *optional*):
The description of the pushed commit.
Returns:
`Future[str]`: A future object that will yield the commit url when the upload is complete. Can be used to
check the status of the upload. Returns None if `self.logdir` is an empty directory.
"""
if not os.path.isdir(self.logdir):
warnings.warn(f"Cannot push log to hub: {self.logdir} is not a directory.")
return None

return upload_folder(
repo_id=self.repo_id,
folder_path=self.logdir,
path_in_repo=self.path_in_repo,
commit_message=commit_message or "Upload training logs using HFSummaryWriter.",
commit_description=commit_description,
token=self.token,
repo_type=self.repo_type,
revision=self.repo_revision,
allow_patterns=self.repo_allow_patterns,
ignore_patterns=self.repo_ignore_patterns,
run_as_future=True,
)

0 comments on commit e5477c0

Please sign in to comment.