diff --git a/miditok/midi_tokenizer.py b/miditok/midi_tokenizer.py index 9abf2c3a..5140e747 100644 --- a/miditok/midi_tokenizer.py +++ b/miditok/midi_tokenizer.py @@ -2770,6 +2770,44 @@ def load_tokens( return self._convert_sequence_to_tokseq(json_content["ids"]) + def save_pretrained( + self, + save_directory: str | Path, + *, + repo_id: str | None = None, + push_to_hub: bool = False, + **push_to_hub_kwargs, + ) -> str | None: + """ + Save the tokenizer in local a directory. + + Overridden from ``huggingface_hub.ModelHubMixin``. + Since v0.21 this method will automatically save ``self.config`` on after + calling ``self._save_pretrained``, which is unnecessary in our case. + + :param save_directory: Path to directory in which the model weights and + configuration will be saved. + :param push_to_hub: Whether to push your model to the Huggingface Hub after + saving it. + :param repo_id: ID of your repository on the Hub. Used only if + `push_to_hub=True`. Will default to the folder name if not provided. + :param push_to_hub_kwargs: Additional key word arguments passed along to the + [`~ModelHubMixin.push_to_hub`] method. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # save model weights/files (framework-specific) + self._save_pretrained(save_directory) + + # push to the Hub if required + if push_to_hub: + kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input + if repo_id is None: + repo_id = save_directory.name # Defaults to `save_directory` name + return self.push_to_hub(repo_id=repo_id, **kwargs) + return None + def _save_pretrained(self, *args, **kwargs) -> None: # noqa: ANN002 # called by `ModelHubMixin.from_pretrained`. self.save_params(*args, **kwargs) diff --git a/tests/test_hf_hub.py b/tests/test_hf_hub.py index a99eeab1..d0eea820 100644 --- a/tests/test_hf_hub.py +++ b/tests/test_hf_hub.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from pathlib import Path -MAX_NUM_TRIES_HF_PUSH = 5 +MAX_NUM_TRIES_HF_PUSH = 3 NUM_SECONDS_RETRY = 8 AUTO_TOKENIZER_CASES = [