[FEATURE] Checkpoints save .nam files in addition to .ckpts (#408)
* Update core.py

Extend PyTorch Lightning ModelCheckpoint to save and remove .nam
files alongside the .ckpt files.

* Remove unneeded checkpoint code

* Get sample rate into .nam checkpoints

* Update test_exportable.py
sdatkinson authored May 13, 2024
1 parent 0ee6fd6 commit b71db72
Showing 8 changed files with 79 additions and 96 deletions.
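The heart of the change is in nam/train/core.py: the training callbacks now subclass Lightning's ModelCheckpoint so that every saved checkpoint gets a companion .nam file. A minimal sketch of the pattern, assuming only what this diff shows; the class name PairedCheckpoint is hypothetical, and _save_checkpoint/_remove_checkpoint are private PyTorch Lightning hooks, so the approach depends on Lightning internals staying stable. The actual implementation is the _ModelCheckpoint class further down.

from pathlib import Path

import pytorch_lightning as pl


class PairedCheckpoint(pl.callbacks.ModelCheckpoint):
    # Sketch only: mirror every .ckpt save/remove with a .nam export/delete.
    def _save_checkpoint(self, trainer: pl.Trainer, filepath: str):
        super()._save_checkpoint(trainer, filepath)  # write the .ckpt as usual
        p = Path(filepath)
        # Export the wrapped network next to the checkpoint, extension swapped:
        trainer.model.net.export(p.parent, basename=p.stem)

    def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str):
        super()._remove_checkpoint(trainer, filepath)  # delete the .ckpt as usual
        nam_path = Path(filepath).with_suffix(".nam")
        if nam_path.exists():
            nam_path.unlink()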
2 changes: 1 addition & 1 deletion nam/models/__init__.py
@@ -7,7 +7,7 @@
 """
 
 from . import _base # noqa F401
-from . import _exportable # noqa F401
+from . import exportable # noqa F401
 from . import losses # noqa F401
 from . import wavenet # noqa F401
 from .base import Model # noqa F401
2 changes: 1 addition & 1 deletion nam/models/_base.py
@@ -18,7 +18,7 @@
 
 from .._core import InitializableFromConfig
 from ..data import wav_to_tensor
-from ._exportable import Exportable
+from .exportable import Exportable
 
 
 class _Base(nn.Module, InitializableFromConfig, Exportable):
4 changes: 3 additions & 1 deletion nam/models/_exportable.py → nam/models/exportable.py
@@ -39,6 +39,8 @@ class Exportable(abc.ABC):
     Interface for my custon export format for use in the plugin.
     """
 
+    FILE_EXTENSION = ".nam"
+
     def export(
         self,
         outdir: Path,
@@ -66,7 +68,7 @@ def export(
 
         training = self.training
         self.eval()
-        with open(Path(outdir, f"{basename}.nam"), "w") as fp:
+        with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
             json.dump(model_dict, fp)
         if include_snapshot:
             x, y = self._export_input_output()
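Hoisting the extension into Exportable.FILE_EXTENSION lets the new checkpoint callback reference it rather than hard-coding ".nam". Illustratively, with any concrete Exportable subclass (the model instance and paths here are hypothetical; behavior per export() above):

from pathlib import Path

outdir = Path("exports")
model.export(outdir, basename="my_amp")
# Writes exports/my_amp.nam, a JSON document, because
# Exportable.FILE_EXTENSION == ".nam".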
18 changes: 0 additions & 18 deletions nam/train/_errors.py

This file was deleted.

2 changes: 1 addition & 1 deletion nam/train/colab.py
@@ -81,7 +81,7 @@ def run(
     fit_cab: bool = False,
 ):
     """
-    :param epochs: How amny epochs we'll train for.
+    :param epochs: How many epochs we'll train for.
     :param delay: How far the output algs the input due to round-trip latency during
         reamping, in samples.
     :param stage_1_channels: The number of channels in the WaveNet's first stage.
71 changes: 50 additions & 21 deletions nam/train/core.py
@@ -25,9 +25,9 @@
 
 from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
 from ..models import Model
+from ..models.exportable import Exportable
 from ..models.losses import esr
 from ..util import filter_warnings
-from ._errors import IncompatibleCheckpointError
 from ._version import PROTEUS_VERSION, Version
 
 __all__ = ["train"]
@@ -870,7 +870,6 @@ def _get_configs(
     lr_decay: float,
     batch_size: int,
     fit_cab: bool,
-    checkpoint: Optional[Path] = None,
 ):
     def get_kwargs(data_info: _DataInfo):
         if data_info.major_version == 1:
@@ -960,8 +959,6 @@ def get_kwargs(data_info: _DataInfo):
     if fit_cab:
         model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
         model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
-    if checkpoint:
-        model_config["checkpoint_path"] = checkpoint
 
     if torch.cuda.is_available():
         device_config = {"accelerator": "gpu", "devices": 1}
@@ -1095,15 +1092,56 @@ def __init__(self, *args, **kwargs):
         self.patience = np.inf
 
 
+class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
+    """
+    Extension to model checkpoint to save a .nam file as well as the .ckpt file.
+    """
+
+    _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION
+
+    @classmethod
+    def _get_nam_filepath(cls, filepath: str) -> Path:
+        """
+        Given a .ckpt filepath, figure out a .nam for it.
+        """
+        if not filepath.endswith(cls.FILE_EXTENSION):
+            raise ValueError(
+                f"Checkpoint filepath {filepath} doesn't end in expected extension "
+                f"{cls.FILE_EXTENSION}"
+            )
+        return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)
+
+    def _save_checkpoint(self, trainer: pl.Trainer, filepath: str):
+        # Save the .ckpt:
+        super()._save_checkpoint(trainer, filepath)
+        # Save the .nam:
+        nam_filepath = self._get_nam_filepath(filepath)
+        pl_model: Model = trainer.model
+        nam_model = pl_model.net
+        outdir = nam_filepath.parent
+        # HACK: Assume the extension
+        basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)]
+        nam_model.export(
+            outdir,
+            basename=basename,
+        )
+
+    def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
+        super()._remove_checkpoint(trainer, filepath)
+        nam_path = self._get_nam_filepath(filepath)
+        if nam_path.exists():
+            nam_path.unlink()
+
+
 def _get_callbacks(threshold_esr: Optional[float]):
     callbacks = [
-        pl.callbacks.model_checkpoint.ModelCheckpoint(
+        _ModelCheckpoint(
             filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
             save_top_k=3,
             monitor="val_loss",
             every_n_epochs=1,
         ),
-        pl.callbacks.model_checkpoint.ModelCheckpoint(
+        _ModelCheckpoint(
             filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
         ),
     ]
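For reference, ModelCheckpoint.FILE_EXTENSION is ".ckpt" in PyTorch Lightning, so _get_nam_filepath() amounts to a guarded extension swap. Illustrative paths (hypothetical filenames; behavior per the code above):

_ModelCheckpoint._get_nam_filepath("ckpts/checkpoint_best_0007_350.ckpt")
# -> Path("ckpts/checkpoint_best_0007_350.nam")
_ModelCheckpoint._get_nam_filepath("ckpts/weights.pt")
# -> raises ValueError: doesn't end in expected extension .ckpt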
@@ -1135,7 +1173,6 @@ def train(
     local: bool = False,
     fit_cab: bool = False,
     threshold_esr: Optional[bool] = None,
-    checkpoint: Optional[Path] = None,
 ) -> Optional[Model]:
     """
     :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
@@ -1184,7 +1221,6 @@
         lr_decay,
         batch_size,
         fit_cab,
-        checkpoint=checkpoint,
     )
 
     print("Starting training. It's time to kick ass and chew bubblegum!")
@@ -1193,16 +1229,7 @@
     # * Model is re-instantiated after training anyways.
     # (Hacky) solution: set sample rate in model from dataloader after second
     # instantiation from final checkpoint.
-    try:
-        model = Model.init_from_config(model_config)
-    except RuntimeError as e:
-        if "Error(s) in loading state_dict for Model:" in str(e):
-            raise IncompatibleCheckpointError(
-                "Model initialization failed; the checkpoint used seems to be "
-                f"incompatible.\n\nOriginal error:\n\n{e}"
-            )
-        else:
-            raise e
+    model = Model.init_from_config(model_config)
     train_dataloader, val_dataloader = _get_dataloaders(
         data_config, learning_config, model
     )
@@ -1212,6 +1239,8 @@
             f"{train_dataloader.dataset.sample_rate}, "
            f"{val_dataloader.dataset.sample_rate}"
        )
+    sample_rate = train_dataloader.dataset.sample_rate
+    model.net.sample_rate = sample_rate
 
     trainer = pl.Trainer(
         callbacks=_get_callbacks(threshold_esr),
@@ -1220,8 +1249,7 @@
     )
     # Suppress the PossibleUserWarning about num_workers (Issue 345)
     with filter_warnings("ignore", category=PossibleUserWarning):
-        trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint}
-        trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs)
+        trainer.fit(model, train_dataloader, val_dataloader)
 
     # Go to best checkpoint
     best_checkpoint = trainer.checkpoint_callback.best_model_path
@@ -1232,7 +1260,8 @@
     )
     model.cpu()
     model.eval()
-    model.net.sample_rate = train_dataloader.dataset.sample_rate
+    # HACK set again
+    model.net.sample_rate = sample_rate
 
     def window_kwargs(version: Version):
         if version.major == 1:
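Note the ordering the @@ -1212 hunk establishes: _ModelCheckpoint exports .nam files while training runs, and those exports read the sample rate off model.net, so it must be set before trainer.fit() and then again after the model is re-instantiated from the best checkpoint. Condensed from train() above (not the full function):

sample_rate = train_dataloader.dataset.sample_rate
model.net.sample_rate = sample_rate  # before fit(): in-training .nam exports see it
trainer.fit(model, train_dataloader, val_dataloader)
# ... best checkpoint reloaded, model re-instantiated ...
model.net.sample_rate = sample_rate  # "HACK set again"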
72 changes: 21 additions & 51 deletions nam/train/gui.py
@@ -43,7 +43,6 @@ def _ensure_graceful_shutdowns():
 from nam.models.metadata import GearType, UserMetadata, ToneType
 
 # Ok private access here--this is technically allowed access
-from nam.train._errors import IncompatibleCheckpointError
 from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
 
 _install_is_valid = True
@@ -67,7 +66,6 @@
 _DEFAULT_DELAY = None
 _DEFAULT_IGNORE_CHECKS = False
 _DEFAULT_THRESHOLD_ESR = None
-_DEFAULT_CHECKPOINT = None
 
 _ADVANCED_OPTIONS_LEFT_WIDTH = 12
 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -84,15 +82,13 @@ class _AdvancedOptions(object):
     :param ignore_checks: Keep going even if a check says that something is wrong.
     :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
         stop.
-    :param checkpoint: If provided, try to restart from this checkpoint.
     """
 
     architecture: core.Architecture
     num_epochs: int
     latency: Optional[int]
     ignore_checks: bool
     threshold_esr: Optional[float]
-    checkpoint: Optional[Path]
 
 
 class _PathType(Enum):
@@ -364,7 +360,6 @@ def __init__(self):
             _DEFAULT_DELAY,
             _DEFAULT_IGNORE_CHECKS,
             _DEFAULT_THRESHOLD_ESR,
-            _DEFAULT_CHECKPOINT,
         )
         # Window to edit them:
 
@@ -487,7 +482,6 @@ def _train(self):
         delay = self.advanced_options.latency
         file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
         threshold_esr = self.advanced_options.threshold_esr
-        checkpoint = self.advanced_options.checkpoint
 
         # Advanced-er options
         # If you're poking around looking for these, then maybe it's time to learn to
@@ -502,36 +496,27 @@
             print("Now training {}".format(file))
             basename = re.sub(r"\.wav$", "", file.split("/")[-1])
 
-            try:
-                trained_model = core.train(
-                    self._widgets[_GUIWidgets.INPUT_PATH].val,
-                    file,
-                    self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
-                    epochs=num_epochs,
-                    delay=delay,
-                    architecture=architecture,
-                    batch_size=batch_size,
-                    lr=lr,
-                    lr_decay=lr_decay,
-                    seed=seed,
-                    silent=self._checkboxes[
-                        _CheckboxKeys.SILENT_TRAINING
-                    ].variable.get(),
-                    save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
-                    modelname=basename,
-                    ignore_checks=self._checkboxes[
-                        _CheckboxKeys.IGNORE_DATA_CHECKS
-                    ].variable.get(),
-                    local=True,
-                    fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
-                    threshold_esr=threshold_esr,
-                    checkpoint=checkpoint,
-                )
-            except IncompatibleCheckpointError as e:
-                trained_model = None
-                self._wait_while_func(
-                    _BasicModal, "Training failed due to incompatible checkpoint!"
-                )
+            trained_model = core.train(
+                self._widgets[_GUIWidgets.INPUT_PATH].val,
+                file,
+                self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
+                epochs=num_epochs,
+                delay=delay,
+                architecture=architecture,
+                batch_size=batch_size,
+                lr=lr,
+                lr_decay=lr_decay,
+                seed=seed,
+                silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+                save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
+                modelname=basename,
+                ignore_checks=self._checkboxes[
+                    _CheckboxKeys.IGNORE_DATA_CHECKS
+                ].variable.get(),
+                local=True,
+                fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+                threshold_esr=threshold_esr,
+            )
 
             if trained_model is None:
                 print("Model training failed! Skip exporting...")
@@ -755,17 +740,6 @@ def __init__(self, resume_main, parent: _GUI):
             type=_float_or_null,
         )
 
-        # Restart from a checkpoint
-        self._frame_checkpoint = tk.Frame(self._root)
-        self._frame_checkpoint.pack()
-        self._path_button_checkpoint = _ClearablePathButton(
-            self._frame_checkpoint,
-            "Checkpoint",
-            "[Optional] Select a checkpoint (.ckpt file) to restart training from",
-            _PathType.FILE,
-            default=self._parent.advanced_options.checkpoint,
-        )
-
         # "Ok": apply and destory
         self._frame_ok = tk.Frame(self._root)
         self._frame_ok.pack()
@@ -798,10 +772,6 @@ def _apply_and_destroy(self):
         self._parent.advanced_options.threshold_esr = (
             None if threshold_esr == "null" else threshold_esr
         )
-        checkpoint_path = self._path_button_checkpoint.val
-        self._parent.advanced_options.checkpoint = (
-            None if checkpoint_path is None else Path(checkpoint_path)
-        )
         self._root.destroy()
         self._resume_main()
 
4 changes: 2 additions & 2 deletions tests/test_nam/test_models/test_exportable.py
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 
-from nam.models import _exportable
+from nam.models import exportable
 from nam.models import metadata
 
 
@@ -105,7 +105,7 @@ def test_include_snapshot(self, include_snapshot):
 
     @classmethod
     def _get_model(cls):
-        class Model(nn.Module, _exportable.Exportable):
+        class Model(nn.Module, exportable.Exportable):
             def __init__(self):
                 super().__init__()
                 self._scale = nn.Parameter(torch.tensor(0.0))
