Skip to content

Commit

Permalink
fix lint (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz authored Mar 8, 2024
1 parent 5791bae commit 7fe2e15
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
38 changes: 38 additions & 0 deletions miditok/midi_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,44 @@ def load_tokens(

return self._convert_sequence_to_tokseq(json_content["ids"])

def save_pretrained(
self,
save_directory: str | Path,
*,
repo_id: str | None = None,
push_to_hub: bool = False,
**push_to_hub_kwargs,
) -> str | None:
"""
Save the tokenizer in local a directory.
Overridden from ``huggingface_hub.ModelHubMixin``.
Since v0.21 this method will automatically save ``self.config`` on after
calling ``self._save_pretrained``, which is unnecessary in our case.
:param save_directory: Path to directory in which the model weights and
configuration will be saved.
:param push_to_hub: Whether to push your model to the Huggingface Hub after
saving it.
:param repo_id: ID of your repository on the Hub. Used only if
`push_to_hub=True`. Will default to the folder name if not provided.
:param push_to_hub_kwargs: Additional key word arguments passed along to the
[`~ModelHubMixin.push_to_hub`] method.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)

# save model weights/files (framework-specific)
self._save_pretrained(save_directory)

# push to the Hub if required
if push_to_hub:
kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, **kwargs)
return None

def _save_pretrained(self, *args, **kwargs) -> None: # noqa: ANN002
# called by `ModelHubMixin.from_pretrained`.
self.save_params(*args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if TYPE_CHECKING:
from pathlib import Path

MAX_NUM_TRIES_HF_PUSH = 5
MAX_NUM_TRIES_HF_PUSH = 3
NUM_SECONDS_RETRY = 8

AUTO_TOKENIZER_CASES = [
Expand Down

0 comments on commit 7fe2e15

Please sign in to comment.