Skip to content

Commit

Permalink
BUG: Fix WindowedFrameClassificationModel attribute
Browse files Browse the repository at this point in the history
Change `lbl_tb2labels` to an instance of
`transforms.labeled_timebins.ToLabels`.
Fix needed after rebasing version 1.0 on #621
  • Loading branch information
NickleDave committed Feb 11, 2023
1 parent 3ffd729 commit 948efa0
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/vak/models/windowed_frame_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
in a window, e.g., each time bin in
a window from a spectrogram."""
from __future__ import annotations
import functools
from typing import Callable, ClassVar, Mapping, Type

import torch

from . import base
from .definition import ModelDefinition
from .. import labeled_timebins
from .. import transforms


class WindowedFrameClassificationModel(base.Model):
Expand Down Expand Up @@ -75,8 +74,7 @@ def __init__(self,
"""
super().__init__(network=network, loss=loss,
optimizer=optimizer, metrics=metrics)
lbl_tb2labels = functools.partial(labeled_timebins.lbl_tb2labels,
labels_mapping=labelmap)
lbl_tb2labels = transforms.labeled_timebins.ToLabels(labelmap=labelmap)
self.lbl_tb2labels = lbl_tb2labels
self.post_tfm = post_tfm

Expand Down

0 comments on commit 948efa0

Please sign in to comment.