diff --git a/src/vak/cli/eval.py b/src/vak/cli/eval.py index 62b02599e..2361d6c1e 100644 --- a/src/vak/cli/eval.py +++ b/src/vak/cli/eval.py @@ -4,7 +4,6 @@ from .. import ( config, core, - validators ) from ..logging import config_logging_for_cli, log_version @@ -65,4 +64,5 @@ def eval(toml_path): spect_key=cfg.spect_params.spect_key, timebins_key=cfg.spect_params.timebins_key, device=cfg.eval.device, + post_tfm_kwargs=cfg.eval.post_tfm_kwargs, ) diff --git a/src/vak/cli/learncurve.py b/src/vak/cli/learncurve.py index a99a01edd..1f84b0d89 100644 --- a/src/vak/cli/learncurve.py +++ b/src/vak/cli/learncurve.py @@ -71,6 +71,7 @@ def learning_curve(toml_path): num_workers=cfg.learncurve.num_workers, results_path=results_path, previous_run_path=cfg.learncurve.previous_run_path, + post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs, spect_key=cfg.spect_params.spect_key, timebins_key=cfg.spect_params.timebins_key, normalize_spectrograms=cfg.learncurve.normalize_spectrograms, diff --git a/src/vak/config/eval.py b/src/vak/config/eval.py index e231c8be1..cc0bbe996 100644 --- a/src/vak/config/eval.py +++ b/src/vak/config/eval.py @@ -1,6 +1,6 @@ """parses [EVAL] section of config""" import attr -from attr import converters +from attr import converters, validators from attr.validators import instance_of from .validators import is_valid_model_name @@ -8,6 +8,56 @@ from ..converters import comma_separated_list, expanded_user_path +def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict: + post_tfm_kwargs = dict(post_tfm_kwargs) + + if 'min_segment_dur' not in post_tfm_kwargs: + # because there's no null in TOML, + # users leave arg out of config then we set it to None + post_tfm_kwargs['min_segment_dur'] = None + else: + post_tfm_kwargs['min_segment_dur'] = float(post_tfm_kwargs['min_segment_dur']) + + if 'majority_vote' not in post_tfm_kwargs: + # set default for this one too + post_tfm_kwargs['majority_vote'] = False + else: + post_tfm_kwargs['majority_vote'] = bool(post_tfm_kwargs['majority_vote']) + + return post_tfm_kwargs + + +def are_valid_post_tfm_kwargs(instance, attribute, value): + """check if ``post_tfm_kwargs`` is valid""" + if not isinstance(value, dict): + raise TypeError( + "'post_tfm_kwargs' should be declared in toml config as an inline table " + f"that parses as a dict, but type was: {type(value)}. " + "Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`" + ) + if any( + [k not in {'majority_vote', 'min_segment_dur'} for k in value.keys()] + ): + invalid_kwargs = [k for k in value.keys() + if k not in {'majority_vote', 'min_segment_dur'}] + raise ValueError( + f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}." + "Valid names are: {'majority_vote', 'min_segment_dur'}" + ) + if 'majority_vote' in value: + if not isinstance(value['majority_vote'], bool): + raise TypeError( + "'post_tfm_kwargs' keyword argument 'majority_vote' " + f"should be of type bool but was: {type(value['majority_vote'])}" + ) + if 'min_segment_dur' in value: + if value['min_segment_dur'] and not isinstance(value['min_segment_dur'], float): + raise TypeError( + "'post_tfm_kwargs' keyword argument 'min_segment_dur' type " + f"should be float but was: {type(value['min_segment_dur'])}" + ) + + @attr.s class EvalConfig: """class that represents [EVAL] section of config.toml file @@ -36,6 +86,19 @@ class EvalConfig: path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. + post_tfm_kwargs : dict + Keyword arguments to post-processing transform. + If None, then no additional clean-up is applied + when transforming labeled timebins to segments, + the default behavior. + The transform used is + ``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`. + Valid keyword argument names are 'majority_vote' + and 'min_segment_dur', and should be appropriate + values for those arguments: Boolean for ``majority_vote``, + a float value for ``min_segment_dur``. + See the docstring of the transform for more details on + these arguments and how they work. """ # required, external files checkpoint_path = attr.ib(converter=expanded_user_path) @@ -62,6 +125,12 @@ class EvalConfig: default=None, ) + post_tfm_kwargs = attr.ib( + validator=validators.optional(are_valid_post_tfm_kwargs), + converter=converters.optional(convert_post_tfm_kwargs), + default={}, # empty dict so we can pass into transform with **kwargs expansion + ) + # optional, data loader num_workers = attr.ib(validator=instance_of(int), default=2) device = attr.ib(validator=instance_of(str), default=device.get_default()) diff --git a/src/vak/config/learncurve.py b/src/vak/config/learncurve.py index a58e7d463..a70ba22b2 100644 --- a/src/vak/config/learncurve.py +++ b/src/vak/config/learncurve.py @@ -1,8 +1,9 @@ """parses [LEARNCURVE] section of config""" import attr -from attr import converters +from attr import converters, validators from attr.validators import instance_of +from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs from .train import TrainConfig from ..converters import expanded_user_path @@ -49,6 +50,19 @@ class LearncurveConfig(TrainConfig): previous_run_path : str path to results directory from a previous run. Used for training if use_train_subsets_from_previous_run is True. + post_tfm_kwargs : dict + Keyword arguments to post-processing transform. + If None, then no additional clean-up is applied + when transforming labeled timebins to segments, + the default behavior. + The transform used is + ``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`. + Valid keyword argument names are 'majority_vote' + and 'min_segment_dur', and should be appropriate + values for those arguments: Boolean for ``majority_vote``, + a float value for ``min_segment_dur``. + See the docstring of the transform for more details on + these arguments and how they work. """ train_set_durs = attr.ib(validator=instance_of(list), kw_only=True) num_replicates = attr.ib(validator=instance_of(int), kw_only=True) @@ -56,3 +70,9 @@ class LearncurveConfig(TrainConfig): converter=converters.optional(expanded_user_path), default=None, ) + + post_tfm_kwargs = attr.ib( + validator=validators.optional(are_valid_post_tfm_kwargs), + converter=converters.optional(convert_post_tfm_kwargs), + default={}, # empty dict so we can pass into transform with **kwargs expansion + ) diff --git a/src/vak/config/valid.toml b/src/vak/config/valid.toml index 5c2c6bd31..8c41e57fe 100644 --- a/src/vak/config/valid.toml +++ b/src/vak/config/valid.toml @@ -62,7 +62,7 @@ batch_size = 11 num_workers = 4 device = 'cuda' spect_scaler_path = '/home/user/results_181014_194418/spect_scaler' - +post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} [LEARNCURVE] models = 'TweetyNet' @@ -79,6 +79,7 @@ num_replicates = 2 csv_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv' results_dir_made_by_main_script = '/some/path/to/learncurve/' previous_run_path = '/some/path/to/learncurve/results_20210106_132152' +post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01} num_workers = 4 device = 'cuda' diff --git a/src/vak/core/eval.py b/src/vak/core/eval.py index 77765386c..5b821d322 100644 --- a/src/vak/core/eval.py +++ b/src/vak/core/eval.py @@ -8,11 +8,14 @@ import torch.utils.data from .. import ( + files, models, + timebins, transforms, validators ) from ..datasets.vocal_dataset import VocalDataset +from ..labels import multi_char_labels_to_single_char logger = logging.getLogger(__name__) @@ -28,6 +31,7 @@ def eval( num_workers, split="test", spect_scaler_path=None, + post_tfm_kwargs=None, spect_key="s", timebins_key="t", device=None, @@ -64,6 +68,18 @@ def eval( If spectrograms were normalized and this is not provided, will give incorrect results. Default is None. + post_tfm_kwargs : dict + Keyword arguments to post-processing transform. + If None, then no additional clean-up is applied + when transforming labeled timebins to segments, + the default behavior. The transform used is + ``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`. + Valid keyword argument names are 'majority_vote' + and 'min_segment_dur', and should be appropriate + values for those arguments: Boolean for ``majority_vote``, + a float value for ``min_segment_dur``. + See the docstring of the transform for more details on + these arguments and how they work. spect_key : str key for accessing spectrogram in files. Default is 's'. timebins_key : str @@ -71,6 +87,15 @@ def eval( device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. + + Notes + ----- + Note that unlike ``core.predict``, this function + can modify ``labelmap`` so that metrics like edit distance + are correctly computed, by converting any string labels + in ``labelmap`` with multiple characters + to (mock) single-character labels, + with ``vak.labels.multi_char_labels_to_single_char``. """ # ---- pre-conditions ---------------------------------------------------------------------------------------------- for path, path_name in zip( @@ -102,6 +127,15 @@ def eval( with labelmap_path.open("r") as f: labelmap = json.load(f) + # replace any multiple character labels in mapping + # with dummy single-character labels + # so that we do not affect edit distance computation + # see https://github.com/NickleDave/vak/issues/373 + labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != 'unlabeled'] + if any([len(label) > 1 for label in labelmap_keys]): # only re-map if necessary + # (to minimize chance of knock-on bugs) + labelmap = multi_char_labels_to_single_char(labelmap) + item_transform = transforms.get_defaults( "eval", spect_standardizer, @@ -132,8 +166,23 @@ def eval( if len(input_shape) == 4: input_shape = input_shape[1:] + if post_tfm_kwargs: + dataset_df = pd.read_csv(csv_path) + # we use the timebins vector from the first spect path to get timebin dur. + # this is less careful than calling io.dataframe.validate_and_get_timebin_dur + # but it's also much faster, and we can assume dataframe was validated when it was made + spect_dict = files.spect.load(dataset_df['spect_path'].values[0]) + timebin_dur = timebins.timebin_dur_from_vec(spect_dict[timebins_key]) + + post_tfm = transforms.labeled_timebins.PostProcess( + timebin_dur=timebin_dur, + **post_tfm_kwargs, + ) + else: + post_tfm = None + models_map = models.from_model_config_map( - model_config_map, num_classes=len(labelmap), input_shape=input_shape + model_config_map, num_classes=len(labelmap), input_shape=input_shape, post_tfm=post_tfm ) for model_name, model in models_map.items(): diff --git a/src/vak/core/learncurve/learncurve.py b/src/vak/core/learncurve/learncurve.py index e8fc3d2e0..68e037ec3 100644 --- a/src/vak/core/learncurve/learncurve.py +++ b/src/vak/core/learncurve/learncurve.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) +# TODO: add post_tfm_kwargs here def learning_curve( model_config_map, train_set_durs, @@ -32,6 +33,7 @@ def learning_curve( root_results_dir=None, results_path=None, previous_run_path=None, + post_tfm_kwargs=None, spect_key="s", timebins_key="t", normalize_spectrograms=True, @@ -86,6 +88,18 @@ def learning_curve( Typically directory will have a name like ``results_{timestamp}`` and the actual .csv splits will be in sub-directories with names corresponding to the training set duration + post_tfm_kwargs : dict + Keyword arguments to post-processing transform. + If None, then no additional clean-up is applied + when transforming labeled timebins to segments, + the default behavior. The transform used is + ``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`. + Valid keyword argument names are 'majority_vote' + and 'min_segment_dur', and should be appropriate + values for those arguments: Boolean for ``majority_vote``, + a float value for ``min_segment_dur``. + See the docstring of the transform for more details on + these arguments and how they work. spect_key : str key for accessing spectrogram in files. Default is 's'. timebins_key : str @@ -318,6 +332,7 @@ def learning_curve( num_workers=num_workers, split="test", spect_scaler_path=spect_scaler_path, + post_tfm_kwargs=post_tfm_kwargs, spect_key=spect_key, timebins_key=timebins_key, device=device, diff --git a/src/vak/engine/model.py b/src/vak/engine/model.py index 163089625..f8522eb81 100644 --- a/src/vak/engine/model.py +++ b/src/vak/engine/model.py @@ -30,6 +30,14 @@ class Model: metrics : dict where keys are metric names, and values are callables that compute that a metric, e.g. accuracy. Metrics should accept arguments y_pred and y_true. + post_tfm : callable + Post-processing transform that models applies to network output + during evaluation. Default is None, in which case no post-processing + is applied, and network outputs are converted directly to string labels + with ``vak.transforms.labeled_timebins.ToLabels`` (that does not + apply any post-processing clean-ups). + To be valid, ``post_tfm`` must be an instance of + ``vak.transforms.labeled_timebins.PostProcess``. Attributes ---------- @@ -71,6 +79,7 @@ def __init__( loss, optimizer, metrics, + post_tfm=None, summary_writer=None, global_step=0, ): @@ -79,6 +88,13 @@ def __init__( self.loss = loss self.metrics = metrics + if post_tfm and not isinstance(post_tfm, transforms.labeled_timebins.PostProcess): + raise TypeError( + "post_tfm must be an instance of transforms.labeled_timebins.PostProcess, " + f"but type was: {type(post_tfm)}" + ) + self.post_tfm = post_tfm + self.summary_writer = summary_writer self.global_step = global_step # used for summary writer @@ -256,26 +272,20 @@ def _eval(self, eval_data): out = out[:, :, padding_mask] y_pred = y_pred[:, padding_mask] - if any( - [ - "levenshtein" in metric_name - for metric_name in self.metrics.keys() - ] - ) or any( - [ - "segment_error_rate" in metric_name - for metric_name in self.metrics.keys() - ] - ): - y_labels = transforms.labeled_timebins.lbl_tb2labels( - y.cpu().numpy(), eval_data.dataset.labelmap + y_labels = transforms.labeled_timebins.to_labels( + y.cpu().numpy(), eval_data.dataset.labelmap ) - y_pred_labels = transforms.labeled_timebins.lbl_tb2labels( - y_pred.cpu().numpy(), eval_data.dataset.labelmap + + # need to keep y_pred as tensor for acc metric; + # TODO: make post_tfm take tensors + y_pred_lbl_tb = y_pred.cpu().numpy() + if self.post_tfm: + y_pred_lbl_tb = self.post_tfm( + lbl_tb=y_pred_lbl_tb, ) - else: - y_labels = None - y_pred_labels = None + y_pred_labels = transforms.labeled_timebins.to_labels( + y_pred_lbl_tb, eval_data.dataset.labelmap + ) for metric_name, metric_callable in self.metrics.items(): if metric_name == "loss": @@ -465,7 +475,7 @@ def predict(self, pred_data, device=None): return self._predict(pred_data) @classmethod - def from_config(cls, config): + def from_config(cls, config, post_tfm=None): """any model that inherits from this class should do whatever it needs to in this factory method to create the network, optimizer, and loss, and then pass those to the init function diff --git a/src/vak/models/models.py b/src/vak/models/models.py index 91fadf665..c9f1d4b8a 100644 --- a/src/vak/models/models.py +++ b/src/vak/models/models.py @@ -32,7 +32,7 @@ def find(): yield entrypoint.name, entrypoint.load() -def from_model_config_map(model_config_map, num_classes, input_shape): +def from_model_config_map(model_config_map, num_classes, input_shape, post_tfm=None): """get models that are ready to train, given their names and configurations. Given a dictionary that maps model names to configurations, @@ -48,6 +48,14 @@ def from_model_config_map(model_config_map, num_classes, input_shape): input_shape : tuple, list e.g. (channels, height, width). Batch size is not required for input shape. + post_tfm : callable + Post-processing transform that models applies during evaluation. + Default is None, in which case the model defaults to using + ``vak.transforms.labeled_timebins.ToLabels`` (that does not + apply any post-processing clean-ups). + To be valid, ``post_tfm`` must be either an instance of + ``vak.transforms.labeled_timebins.ToLabels`` or + ``vak.transforms.labeled_timebins.ToLabelsWithPostprocessing``. Returns ------- @@ -67,7 +75,7 @@ def from_model_config_map(model_config_map, num_classes, input_shape): model = MODELS[model_name].from_config(config=model_config) except KeyError: model = MODELS[f"{model_name}Model"].from_config( - config=model_config + config=model_config, post_tfm=post_tfm, ) models_map[model_name] = model return models_map diff --git a/src/vak/models/teenytweetynet.py b/src/vak/models/teenytweetynet.py index a1c85d560..4ea3f53d3 100644 --- a/src/vak/models/teenytweetynet.py +++ b/src/vak/models/teenytweetynet.py @@ -119,7 +119,7 @@ def forward(self, x): class TeenyTweetyNetModel(Model): @classmethod - def from_config(cls, config): + def from_config(cls, config, post_tfm=None): network = TeenyTweetyNet(**config["network"]) loss = nn.CrossEntropyLoss(**config["loss"]) optimizer = torch.optim.Adam(params=network.parameters(), **config["optimizer"]) @@ -134,4 +134,5 @@ def from_config(cls, config): optimizer=optimizer, loss=loss, metrics=metrics, + post_tfm=post_tfm ) diff --git a/tests/test_core/test_eval.py b/tests/test_core/test_eval.py index 1a1fb0472..f9f3dbf62 100644 --- a/tests/test_core/test_eval.py +++ b/tests/test_core/test_eval.py @@ -16,6 +16,26 @@ def eval_output_matches_expected(model_config_map, output_dir): return True +# -- we do eval with all possible configurations of post_tfm_kwargs +POST_TFM_KWARGS = [ + # default, will use ToLabels + None, + # no cleanup but uses ToLabelsWithPostprocessing + {'majority_vote': False, 'min_segment_dur': None}, + # use ToLabelsWithPostprocessing with *just* majority_vote + {'majority_vote': True, 'min_segment_dur': None}, + # use ToLabelsWithPostprocessing with *just* min_segment_dur + {'majority_vote': False, 'min_segment_dur': 0.002}, + # use ToLabelsWithPostprocessing with majority_vote *and* min_segment_dur + {'majority_vote': True, 'min_segment_dur': 0.002}, +] + + +@pytest.fixture(params=POST_TFM_KWARGS) +def post_tfm_kwargs(request): + return request.param + + @pytest.mark.parametrize( "audio_format, spect_format, annot_format", [ @@ -23,7 +43,14 @@ def eval_output_matches_expected(model_config_map, output_dir): ], ) def test_eval( - audio_format, spect_format, annot_format, specific_config, tmp_path, model, device + audio_format, + spect_format, + annot_format, + specific_config, + tmp_path, + model, + device, + post_tfm_kwargs ): output_dir = tmp_path.joinpath( f"test_eval_{audio_format}_{spect_format}_{annot_format}" @@ -58,6 +85,7 @@ def test_eval( spect_key=cfg.spect_params.spect_key, timebins_key=cfg.spect_params.timebins_key, device=cfg.eval.device, + post_tfm_kwargs=post_tfm_kwargs, ) assert eval_output_matches_expected(model_config_map, output_dir)