Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix models to log train loss on step, fixes #720 #722

Merged
merged 3 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/vak/cli/prep.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Function called by command-line interface for prep command"""
from __future__ import annotations

import pathlib
import shutil
import warnings
import pathlib

import toml

Expand All @@ -13,7 +13,9 @@
from ..config.validators import are_sections_valid


def purpose_from_toml(config_toml: dict, toml_path: str | pathlib.Path | None = None) -> str:
def purpose_from_toml(
config_toml: dict, toml_path: str | pathlib.Path | None = None
) -> str:
"""determine "purpose" from toml config,
i.e., the command that will be run after we ``prep`` the data.

Expand Down
4 changes: 3 additions & 1 deletion src/vak/common/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,9 @@ def map_annotated_to_annot(
reference section of the documentation:
https://vak.readthedocs.io/en/latest/reference/filenames.html
"""
if isinstance(annotated_files, np.ndarray): # e.g., vak DataFrame['spect_path'].values
if isinstance(
annotated_files, np.ndarray
): # e.g., vak DataFrame['spect_path'].values
annotated_files = annotated_files.tolist()

if annot_format in (
Expand Down
1 change: 0 additions & 1 deletion src/vak/common/files/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import spect
from .files import find_fname, from_dir


__all__ = [
"find_fname",
"from_dir",
Expand Down
8 changes: 6 additions & 2 deletions src/vak/common/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from ..common.typing import PathLike


def get_summary_writer(log_dir: PathLike, filename_suffix: str) -> SummaryWriter:
def get_summary_writer(
log_dir: PathLike, filename_suffix: str
) -> SummaryWriter:
"""Get an instance of ``tensorboard.SummaryWriter``,
to use with a vak.Model during training.

Expand Down Expand Up @@ -45,7 +47,9 @@ def get_summary_writer(log_dir: PathLike, filename_suffix: str) -> SummaryWriter


def events2df(
events_path: PathLike, size_guidance: dict | None = None, drop_wall_time: bool = True
events_path: PathLike,
size_guidance: dict | None = None,
drop_wall_time: bool = True,
) -> pd.DataFrame:
"""Convert :mod:`tensorboard` events file to pandas.DataFrame

Expand Down
1 change: 0 additions & 1 deletion src/vak/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
validators,
)


__all__ = [
"config",
"eval",
Expand Down
6 changes: 1 addition & 5 deletions src/vak/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from . import frame_classification, parametric_umap


__all__ = [
"frame_classification",
"parametric_umap"
]
__all__ = ["frame_classification", "parametric_umap"]
22 changes: 16 additions & 6 deletions src/vak/datasets/frame_classification/frames_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from . import constants, helper
from .metadata import Metadata
from ... import common


class FramesDataset:
Expand Down Expand Up @@ -119,7 +118,10 @@ def __init__(
Transform applied to each item :math:`(x, y)`
returned by :meth:`FramesDataset.__getitem__`.
"""
from ... import prep # avoid circular import, use for constants.INPUT_TYPES
from ... import (
prep,
) # avoid circular import, use for constants.INPUT_TYPES

if input_type not in prep.constants.INPUT_TYPES:
raise ValueError(
f"``input_type`` must be one of: {prep.constants.INPUT_TYPES}\n"
Expand Down Expand Up @@ -165,7 +167,7 @@ def _load_frames(self, frames_path):
the input to the frame classification model.
Loads audio or spectrogram, depending on
:attr:`self.input_type`.
This function assumes that audio is in wav format
This function assumes that audio is in wav format
and spectrograms are in npz files.
"""
return helper.load_frames(frames_path, self.input_type)
Expand Down Expand Up @@ -233,15 +235,23 @@ def from_dataset_path(

split_path = dataset_path / split
if subset:
sample_ids_path = split_path / helper.sample_ids_array_filename_for_subset(subset)
sample_ids_path = (
split_path
/ helper.sample_ids_array_filename_for_subset(subset)
)
else:
sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME
sample_ids = np.load(sample_ids_path)

if subset:
inds_in_sample_path = split_path / helper.inds_in_sample_array_filename_for_subset(subset)
inds_in_sample_path = (
split_path
/ helper.inds_in_sample_array_filename_for_subset(subset)
)
else:
inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
inds_in_sample_path = (
split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
)
inds_in_sample = np.load(inds_in_sample_path)

return cls(
Expand Down
8 changes: 4 additions & 4 deletions src/vak/datasets/frame_classification/helper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
"""Helper functions used with frame classification datasets."""
from __future__ import annotations

from . import constants
from ... import common
from . import constants


def sample_ids_array_filename_for_subset(subset: str) -> str:
"""Returns name of sample IDs array file for a subset of the training data."""
return constants.SAMPLE_IDS_ARRAY_FILENAME.replace(
'.npy', f'-{subset}.npy'
)
".npy", f"-{subset}.npy"
)


def inds_in_sample_array_filename_for_subset(subset: str) -> str:
"""Returns name of inds in sample array file for a subset of the training data."""
return constants.INDS_IN_SAMPLE_ARRAY_FILENAME.replace(
'.npy', f'-{subset}.npy'
".npy", f"-{subset}.npy"
)


Expand Down
32 changes: 20 additions & 12 deletions src/vak/datasets/frame_classification/window_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from . import constants, helper
from .metadata import Metadata
from ... import common


def get_window_inds(n_frames: int, window_size: int, stride: int = 1):
Expand Down Expand Up @@ -231,7 +230,10 @@ def __init__(
The transform applied to the target for the output
of the neural network :math:`y`.
"""
from ... import prep # avoid circular import, use for constants.INPUT_TYPES
from ... import (
prep,
) # avoid circular import, use for constants.INPUT_TYPES

if input_type not in prep.constants.INPUT_TYPES:
raise ValueError(
f"``input_type`` must be one of: {prep.constants.INPUT_TYPES}\n"
Expand Down Expand Up @@ -284,15 +286,15 @@ def _load_frames(self, frames_path):
the input to the frame classification model.
Loads audio or spectrogram, depending on
:attr:`self.input_type`.
This function assumes that audio is in wav format
This function assumes that audio is in wav format
and spectrograms are in npz files.
"""
return helper.load_frames(frames_path, self.input_type)

def __getitem__(self, idx):
window_idx = self.window_inds[idx]
sample_ids = self.sample_ids[
window_idx: window_idx + self.window_size
window_idx : window_idx + self.window_size # noqa: E203
]
uniq_sample_ids = np.unique(sample_ids)
if len(uniq_sample_ids) == 1:
Expand All @@ -309,9 +311,7 @@ def __getitem__(self, idx):
frame_labels = []
for sample_id in sorted(uniq_sample_ids):
frames_path = self.dataset_path / self.frames_paths[sample_id]
frames.append(
self._load_frames(frames_path)
)
frames.append(self._load_frames(frames_path))
frame_labels.append(
np.load(
self.dataset_path / self.frame_labels_paths[sample_id]
Expand All @@ -331,10 +331,10 @@ def __getitem__(self, idx):

inds_in_sample = self.inds_in_sample[window_idx]
frames = frames[
..., inds_in_sample: inds_in_sample + self.window_size
..., inds_in_sample : inds_in_sample + self.window_size # noqa: E203
]
frame_labels = frame_labels[
inds_in_sample: inds_in_sample + self.window_size
inds_in_sample : inds_in_sample + self.window_size # noqa: E203
]
if self.transform:
frames = self.transform(frames)
Expand Down Expand Up @@ -405,15 +405,23 @@ def from_dataset_path(

split_path = dataset_path / split
if subset:
sample_ids_path = split_path / helper.sample_ids_array_filename_for_subset(subset)
sample_ids_path = (
split_path
/ helper.sample_ids_array_filename_for_subset(subset)
)
else:
sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME
sample_ids = np.load(sample_ids_path)

if subset:
inds_in_sample_path = split_path / helper.inds_in_sample_array_filename_for_subset(subset)
inds_in_sample_path = (
split_path
/ helper.inds_in_sample_array_filename_for_subset(subset)
)
else:
inds_in_sample_path = split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
inds_in_sample_path = (
split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME
)
inds_in_sample = np.load(inds_in_sample_path)

window_inds_path = split_path / constants.WINDOW_INDS_ARRAY_FILENAME
Expand Down
3 changes: 2 additions & 1 deletion src/vak/datasets/parametric_umap/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from pynndescent import NNDescent
import scipy.sparse._coo
from pynndescent import NNDescent
from sklearn.utils import check_random_state
from torch.utils.data import Dataset

Expand All @@ -21,6 +21,7 @@

warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
from umap.umap_ import fuzzy_simplicial_set # noqa: E402

# isort: on


Expand Down
1 change: 0 additions & 1 deletion src/vak/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import eval_, frame_classification, parametric_umap
from .eval_ import eval


__all__ = [
"eval",
"eval_",
Expand Down
5 changes: 1 addition & 4 deletions src/vak/eval/parametric_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytorch_lightning as lightning
import torch.utils.data

from .. import datasets, models, transforms
from .. import models, transforms
from ..common import validators
from ..datasets.parametric_umap import ParametricUMAPDataset

Expand Down Expand Up @@ -85,9 +85,6 @@ def eval_parametric_umap_model(
logger.info(
f"Loading metadata from dataset path: {dataset_path}",
)
metadata = datasets.parametric_umap.Metadata.from_dataset_path(
dataset_path
)

if not validators.is_a_directory(output_dir):
raise NotADirectoryError(
Expand Down
8 changes: 1 addition & 7 deletions src/vak/learncurve/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from . import (
curvefit,
dirname,
frame_classification,
learncurve,
)
from . import curvefit, dirname, frame_classification, learncurve
from .learncurve import learning_curve


__all__ = [
"curvefit",
"dirname",
Expand Down
1 change: 0 additions & 1 deletion src/vak/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .classification import Accuracy


__all__ = [
"Accuracy",
]
2 changes: 1 addition & 1 deletion src/vak/models/frame_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def training_step(self, batch: tuple, batch_idx: int):
x, y = batch[0], batch[1]
out = self.network(x)
loss = self.loss(out, y)
self.log("train_loss", loss)
self.log("train_loss", loss, on_step=True)
return loss

def validation_step(self, batch: tuple, batch_idx: int):
Expand Down
8 changes: 5 additions & 3 deletions src/vak/models/parametric_umap_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def training_step(self, batch, batch_idx):
loss_umap, loss_reconstruction, loss = self.loss(
embedding_to, embedding_from, reconstruction, before_encoding
)
self.log("train_umap_loss", loss_umap)
self.log("train_umap_loss", loss_umap, on_step=True)
if loss_reconstruction:
self.log("train_reconstruction_loss", loss_reconstruction)
self.log(
"train_reconstruction_loss", loss_reconstruction, on_step=True
)
# note if there's no ``loss_reconstruction``, then ``loss`` == ``loss_umap``
self.log("train_loss", loss)
self.log("train_loss", loss, on_step=True)
return loss

def validation_step(self, batch, batch_idx):
Expand Down
4 changes: 1 addition & 3 deletions src/vak/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def __getattr__(name: str) -> Any:
model_name_family_name_map[model_name] = family_name
return model_name_family_name_map
elif name == "MODEL_NAMES":
return list(
MODEL_REGISTRY.keys()
)
return list(MODEL_REGISTRY.keys())
else:
raise AttributeError(
f"Not an attribute of `vak.models.registry`: {name}"
Expand Down
1 change: 0 additions & 1 deletion src/vak/nn/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .dice import DiceLoss, dice_loss
from .umap import UmapLoss, umap_loss


__all__ = [
"DiceLoss",
"dice_loss",
Expand Down
1 change: 1 addition & 0 deletions src/vak/nn/loss/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

warnings.simplefilter("ignore", category=NumbaDeprecationWarning)
from umap.umap_ import find_ab_params # noqa : E402

# isort: on


Expand Down
1 change: 0 additions & 1 deletion src/vak/predict/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from . import frame_classification, parametric_umap, predict_
from .predict_ import predict


__all__ = [
"frame_classification",
"parametric_umap",
Expand Down
1 change: 0 additions & 1 deletion src/vak/prep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
)
from .prep_ import prep


__all__ = [
"audio_dataset",
"constants",
Expand Down
1 change: 0 additions & 1 deletion src/vak/prep/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ..common.typing import PathLike
from .spectrogram_dataset.audio_helper import files_from_dir


logger = logging.getLogger(__name__)


Expand Down
Loading