diff --git a/nam/models/__init__.py b/nam/models/__init__.py
index af2e6416..0fa7f849 100644
--- a/nam/models/__init__.py
+++ b/nam/models/__init__.py
@@ -7,7 +7,7 @@
 """
 
 from . import _base  # noqa F401
-from . import _exportable  # noqa F401
+from . import exportable  # noqa F401
 from . import losses  # noqa F401
 from . import wavenet  # noqa F401
 from .base import Model  # noqa F401
diff --git a/nam/models/_base.py b/nam/models/_base.py
index 763353a8..e77e797a 100644
--- a/nam/models/_base.py
+++ b/nam/models/_base.py
@@ -18,7 +18,7 @@
 
 from .._core import InitializableFromConfig
 from ..data import wav_to_tensor
-from ._exportable import Exportable
+from .exportable import Exportable
 
 
 class _Base(nn.Module, InitializableFromConfig, Exportable):
diff --git a/nam/models/_exportable.py b/nam/models/exportable.py
similarity index 97%
rename from nam/models/_exportable.py
rename to nam/models/exportable.py
index 42579941..9042eeb5 100644
--- a/nam/models/_exportable.py
+++ b/nam/models/exportable.py
@@ -39,6 +39,8 @@ class Exportable(abc.ABC):
     Interface for my custon export format for use in the plugin.
     """
 
+    FILE_EXTENSION = ".nam"
+
     def export(
         self,
         outdir: Path,
@@ -66,7 +68,7 @@ def export(
 
         training = self.training
         self.eval()
-        with open(Path(outdir, f"{basename}.nam"), "w") as fp:
+        with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
             json.dump(model_dict, fp)
         if include_snapshot:
             x, y = self._export_input_output()
diff --git a/nam/train/_errors.py b/nam/train/_errors.py
deleted file mode 100644
index b47c56aa..00000000
--- a/nam/train/_errors.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# File: _errors.py
-# Created Date: Saturday April 13th 2024
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-"""
-"What could go wrong?"
-"""
-
-__all__ = ["IncompatibleCheckpointError"]
-
-
-class IncompatibleCheckpointError(RuntimeError):
-    """
-    Raised when model loading fails because the checkpoint didn't match the model
-    or its hyperparameters
-    """
-
-    pass
diff --git a/nam/train/colab.py b/nam/train/colab.py
index 9a6c3a1a..7c2aa1a4 100644
--- a/nam/train/colab.py
+++ b/nam/train/colab.py
@@ -81,7 +81,7 @@ def run(
     fit_cab: bool = False,
 ):
     """
-    :param epochs: How amny epochs we'll train for.
+    :param epochs: How many epochs we'll train for.
     :param delay: How far the output algs the input due to round-trip latency during
         reamping, in samples.
     :param stage_1_channels: The number of channels in the WaveNet's first stage.
diff --git a/nam/train/core.py b/nam/train/core.py
index 13e3810c..e9e0e868 100644
--- a/nam/train/core.py
+++ b/nam/train/core.py
@@ -25,9 +25,9 @@
 
 from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
 from ..models import Model
+from ..models.exportable import Exportable
 from ..models.losses import esr
 from ..util import filter_warnings
-from ._errors import IncompatibleCheckpointError
 from ._version import PROTEUS_VERSION, Version
 
 __all__ = ["train"]
@@ -870,7 +870,6 @@ def _get_configs(
     lr_decay: float,
     batch_size: int,
     fit_cab: bool,
-    checkpoint: Optional[Path] = None,
 ):
     def get_kwargs(data_info: _DataInfo):
         if data_info.major_version == 1:
@@ -960,8 +959,6 @@ def get_kwargs(data_info: _DataInfo):
     if fit_cab:
         model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
         model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
-    if checkpoint:
-        model_config["checkpoint_path"] = checkpoint
 
     if torch.cuda.is_available():
         device_config = {"accelerator": "gpu", "devices": 1}
@@ -1095,15 +1092,56 @@ def __init__(self, *args, **kwargs):
         self.patience = np.inf
 
 
+class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
+    """
+    Extension to model checkpoint to save a .nam file as well as the .ckpt file.
+    """
+
+    _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION
+
+    @classmethod
+    def _get_nam_filepath(cls, filepath: str) -> Path:
+        """
+        Given a .ckpt filepath, figure out a .nam for it.
+        """
+        if not filepath.endswith(cls.FILE_EXTENSION):
+            raise ValueError(
+                f"Checkpoint filepath {filepath} doesn't end in expected extension "
+                f"{cls.FILE_EXTENSION}"
+            )
+        return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)
+
+    def _save_checkpoint(self, trainer: pl.Trainer, filepath: str):
+        # Save the .ckpt:
+        super()._save_checkpoint(trainer, filepath)
+        # Save the .nam:
+        nam_filepath = self._get_nam_filepath(filepath)
+        pl_model: Model = trainer.model
+        nam_model = pl_model.net
+        outdir = nam_filepath.parent
+        # HACK: Assume the extension
+        basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)]
+        nam_model.export(
+            outdir,
+            basename=basename,
+        )
+
+    def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
+        super()._remove_checkpoint(trainer, filepath)
+        nam_path = self._get_nam_filepath(filepath)
+        if nam_path.exists():
+            nam_path.unlink()
+
+
 def _get_callbacks(threshold_esr: Optional[float]):
     callbacks = [
-        pl.callbacks.model_checkpoint.ModelCheckpoint(
+        _ModelCheckpoint(
             filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
             save_top_k=3,
             monitor="val_loss",
             every_n_epochs=1,
         ),
-        pl.callbacks.model_checkpoint.ModelCheckpoint(
+        _ModelCheckpoint(
             filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
         ),
     ]
@@ -1135,7 +1173,6 @@ def train(
     local: bool = False,
     fit_cab: bool = False,
     threshold_esr: Optional[bool] = None,
-    checkpoint: Optional[Path] = None,
 ) -> Optional[Model]:
     """
     :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
@@ -1184,7 +1221,6 @@
         lr_decay,
         batch_size,
         fit_cab,
-        checkpoint=checkpoint,
     )
 
     print("Starting training. It's time to kick ass and chew bubblegum!")
@@ -1193,16 +1229,7 @@
     # * Model is re-instantiated after training anyways.
     # (Hacky) solution: set sample rate in model from dataloader after second
    # instantiation from final checkpoint.
-    try:
-        model = Model.init_from_config(model_config)
-    except RuntimeError as e:
-        if "Error(s) in loading state_dict for Model:" in str(e):
-            raise IncompatibleCheckpointError(
-                "Model initialization failed; the checkpoint used seems to be "
-                f"incompatible.\n\nOriginal error:\n\n{e}"
-            )
-        else:
-            raise e
+    model = Model.init_from_config(model_config)
     train_dataloader, val_dataloader = _get_dataloaders(
         data_config, learning_config, model
     )
@@ -1212,6 +1239,8 @@
             f"{train_dataloader.dataset.sample_rate}, "
             f"{val_dataloader.dataset.sample_rate}"
         )
+    sample_rate = train_dataloader.dataset.sample_rate
+    model.net.sample_rate = sample_rate
 
     trainer = pl.Trainer(
         callbacks=_get_callbacks(threshold_esr),
@@ -1220,8 +1249,7 @@
     )
     # Suppress the PossibleUserWarning about num_workers (Issue 345)
     with filter_warnings("ignore", category=PossibleUserWarning):
-        trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint}
-        trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs)
+        trainer.fit(model, train_dataloader, val_dataloader)
 
     # Go to best checkpoint
     best_checkpoint = trainer.checkpoint_callback.best_model_path
@@ -1232,7 +1260,8 @@
     )
     model.cpu()
     model.eval()
-    model.net.sample_rate = train_dataloader.dataset.sample_rate
+    # HACK set again
+    model.net.sample_rate = sample_rate
 
     def window_kwargs(version: Version):
         if version.major == 1:
diff --git a/nam/train/gui.py b/nam/train/gui.py
index ce977b9d..80bcb961 100644
--- a/nam/train/gui.py
+++ b/nam/train/gui.py
@@ -43,7 +43,6 @@ def _ensure_graceful_shutdowns():
     from nam.models.metadata import GearType, UserMetadata, ToneType
 
     # Ok private access here--this is technically allowed access
-    from nam.train._errors import IncompatibleCheckpointError
     from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
 
     _install_is_valid = True
@@ -67,7 +66,6 @@
 _DEFAULT_DELAY = None
 _DEFAULT_IGNORE_CHECKS = False
 _DEFAULT_THRESHOLD_ESR = None
-_DEFAULT_CHECKPOINT = None
 
 _ADVANCED_OPTIONS_LEFT_WIDTH = 12
 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -84,7 +82,6 @@ class _AdvancedOptions(object):
     :param ignore_checks: Keep going even if a check says that something is wrong.
     :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
         stop.
-    :param checkpoint: If provided, try to restart from this checkpoint.
     """
 
     architecture: core.Architecture
@@ -92,7 +89,6 @@ class _AdvancedOptions(object):
     latency: Optional[int]
     ignore_checks: bool
     threshold_esr: Optional[float]
-    checkpoint: Optional[Path]
 
 
 class _PathType(Enum):
@@ -364,7 +360,6 @@ def __init__(self):
             _DEFAULT_DELAY,
             _DEFAULT_IGNORE_CHECKS,
             _DEFAULT_THRESHOLD_ESR,
-            _DEFAULT_CHECKPOINT,
         )
 
         # Window to edit them:
@@ -487,7 +482,6 @@ def _train(self):
         delay = self.advanced_options.latency
         file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
         threshold_esr = self.advanced_options.threshold_esr
-        checkpoint = self.advanced_options.checkpoint
 
         # Advanced-er options
         # If you're poking around looking for these, then maybe it's time to learn to
@@ -502,36 +496,27 @@ def _train(self):
             print("Now training {}".format(file))
             basename = re.sub(r"\.wav$", "", file.split("/")[-1])
 
-            try:
-                trained_model = core.train(
-                    self._widgets[_GUIWidgets.INPUT_PATH].val,
-                    file,
-                    self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
-                    epochs=num_epochs,
-                    delay=delay,
-                    architecture=architecture,
-                    batch_size=batch_size,
-                    lr=lr,
-                    lr_decay=lr_decay,
-                    seed=seed,
-                    silent=self._checkboxes[
-                        _CheckboxKeys.SILENT_TRAINING
-                    ].variable.get(),
-                    save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
-                    modelname=basename,
-                    ignore_checks=self._checkboxes[
-                        _CheckboxKeys.IGNORE_DATA_CHECKS
-                    ].variable.get(),
-                    local=True,
-                    fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
-                    threshold_esr=threshold_esr,
-                    checkpoint=checkpoint,
-                )
-            except IncompatibleCheckpointError as e:
-                trained_model = None
-                self._wait_while_func(
-                    _BasicModal, "Training failed due to incompatible checkpoint!"
-                )
+            trained_model = core.train(
+                self._widgets[_GUIWidgets.INPUT_PATH].val,
+                file,
+                self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
+                epochs=num_epochs,
+                delay=delay,
+                architecture=architecture,
+                batch_size=batch_size,
+                lr=lr,
+                lr_decay=lr_decay,
+                seed=seed,
+                silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+                save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
+                modelname=basename,
+                ignore_checks=self._checkboxes[
+                    _CheckboxKeys.IGNORE_DATA_CHECKS
+                ].variable.get(),
+                local=True,
+                fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+                threshold_esr=threshold_esr,
+            )
 
             if trained_model is None:
                 print("Model training failed! Skip exporting...")
@@ -755,17 +740,6 @@ def __init__(self, resume_main, parent: _GUI):
             type=_float_or_null,
         )
 
-        # Restart from a checkpoint
-        self._frame_checkpoint = tk.Frame(self._root)
-        self._frame_checkpoint.pack()
-        self._path_button_checkpoint = _ClearablePathButton(
-            self._frame_checkpoint,
-            "Checkpoint",
-            "[Optional] Select a checkpoint (.ckpt file) to restart training from",
-            _PathType.FILE,
-            default=self._parent.advanced_options.checkpoint,
-        )
-
         # "Ok": apply and destory
        self._frame_ok = tk.Frame(self._root)
         self._frame_ok.pack()
@@ -798,10 +772,6 @@ def _apply_and_destroy(self):
         self._parent.advanced_options.threshold_esr = (
             None if threshold_esr == "null" else threshold_esr
         )
-        checkpoint_path = self._path_button_checkpoint.val
-        self._parent.advanced_options.checkpoint = (
-            None if checkpoint_path is None else Path(checkpoint_path)
-        )
 
         self._root.destroy()
         self._resume_main()
diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py
index 65898448..1fa53f60 100644
--- a/tests/test_nam/test_models/test_exportable.py
+++ b/tests/test_nam/test_models/test_exportable.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 
-from nam.models import _exportable
+from nam.models import exportable
 from nam.models import metadata
 
 
@@ -105,7 +105,7 @@ def test_include_snapshot(self, include_snapshot):
 
     @classmethod
     def _get_model(cls):
-        class Model(nn.Module, _exportable.Exportable):
+        class Model(nn.Module, exportable.Exportable):
             def __init__(self):
                 super().__init__()
                 self._scale = nn.Parameter(torch.tensor(0.0))