From 948efa0c8cfa4924521b040a82796d03e00bce14 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 10 Feb 2023 19:56:12 -0500 Subject: [PATCH] BUG: Fix WindowedFrameClassificationModel attribute Change `lbl_tb2labels` to an instance of `transforms.labeled_timebins.ToLabels`. Fix needed after rebasing version 1.0 on #621 --- src/vak/models/windowed_frame_classification_model.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/vak/models/windowed_frame_classification_model.py b/src/vak/models/windowed_frame_classification_model.py index bf91fd626..7f1c86855 100644 --- a/src/vak/models/windowed_frame_classification_model.py +++ b/src/vak/models/windowed_frame_classification_model.py @@ -3,14 +3,13 @@ in a window, e.g., each time bin in a window from a spectrogram.""" from __future__ import annotations -import functools from typing import Callable, ClassVar, Mapping, Type import torch from . import base from .definition import ModelDefinition -from .. import labeled_timebins +from .. import transforms class WindowedFrameClassificationModel(base.Model): @@ -75,8 +74,7 @@ def __init__(self, """ super().__init__(network=network, loss=loss, optimizer=optimizer, metrics=metrics) - lbl_tb2labels = functools.partial(labeled_timebins.lbl_tb2labels, - labels_mapping=labelmap) + lbl_tb2labels = transforms.labeled_timebins.ToLabels(labelmap=labelmap) self.lbl_tb2labels = lbl_tb2labels self.post_tfm = post_tfm