Skip to content

Commit

Permalink
ENH: Add eval with post-processing, fix #472
Browse files Browse the repository at this point in the history
- Add post_tfm_kwargs to config/eval.py
- Add post_tfm_kwargs attribute to LearncurveConfig
- Add 'post_tfm_kwargs' option to config/valid.toml
- Add post_tfm_kwargs to LEARNCURVE section of vak/config/valid.toml
- Add use of post_tfm eval in engine.Model
- Add post_tfm_kwargs to core.eval and use with model
  - Add logic in core/eval.py to use post_tfm_kwargs to make post_tfm
  - Use multi_char_labels_to_single_char in core.eval,
    not in transforms, to make sure edit distance is computed
    correctly
- Add post_tfm parameter to vak.models.from_model_config_map
  - Add parameter and put in docstring,
  - Pass argument into Model.from_config
- Add post_tfm_kwargs to TeenyTweetyNet.from_config
- Add post_tfm_kwargs to unit test in test_core/test_eval.py
- Pass post_tfm_kwargs into core.eval in cli/eval.py
- Add parameter post_tfm_kwargs to vak.core.learncurve function,
  pass into calls to core.eval
- Pass post_tfm_kwargs into core.learncurve inside cli.learncurve
  • Loading branch information
NickleDave committed Feb 8, 2023
1 parent 8c6aee1 commit 738414c
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/vak/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .. import (
config,
core,
validators
)
from ..logging import config_logging_for_cli, log_version

Expand Down Expand Up @@ -65,4 +64,5 @@ def eval(toml_path):
spect_key=cfg.spect_params.spect_key,
timebins_key=cfg.spect_params.timebins_key,
device=cfg.eval.device,
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
)
1 change: 1 addition & 0 deletions src/vak/cli/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def learning_curve(toml_path):
num_workers=cfg.learncurve.num_workers,
results_path=results_path,
previous_run_path=cfg.learncurve.previous_run_path,
post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs,
spect_key=cfg.spect_params.spect_key,
timebins_key=cfg.spect_params.timebins_key,
normalize_spectrograms=cfg.learncurve.normalize_spectrograms,
Expand Down
71 changes: 70 additions & 1 deletion src/vak/config/eval.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,63 @@
"""parses [EVAL] section of config"""
import attr
from attr import converters
from attr import converters, validators
from attr.validators import instance_of

from .validators import is_valid_model_name
from .. import device
from ..converters import comma_separated_list, expanded_user_path


def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
    """Return a copy of ``post_tfm_kwargs`` with defaults filled in and values coerced.

    Parameters
    ----------
    post_tfm_kwargs : dict
        Mapping parsed from the config. It is copied with ``dict()``,
        so the argument itself is not modified.

    Returns
    -------
    dict
        Copy that always has the keys ``'min_segment_dur'`` and
        ``'majority_vote'``. A missing ``'min_segment_dur'`` becomes ``None``,
        otherwise its value is passed through ``float``. A missing
        ``'majority_vote'`` becomes ``False``, otherwise its value is
        passed through ``bool``.
    """
    converted = dict(post_tfm_kwargs)

    # There is no null in TOML, so users leave the option out of the config
    # and the absent key is mapped to None here.
    converted['min_segment_dur'] = (
        float(converted['min_segment_dur'])
        if 'min_segment_dur' in converted
        else None
    )

    # This option gets an explicit default as well.
    converted['majority_vote'] = (
        bool(converted['majority_vote'])
        if 'majority_vote' in converted
        else False
    )

    return converted


def are_valid_post_tfm_kwargs(instance, attribute, value):
    """Check that ``post_tfm_kwargs`` is valid.

    Has the ``(instance, attribute, value)`` signature so that it can be
    passed as a ``validator`` to ``attr.ib``.

    Parameters
    ----------
    instance : object
        Instance whose attribute is being validated. Not used.
    attribute : object
        Attribute being validated. Not used.
    value : dict
        The ``post_tfm_kwargs`` to validate. The only keys allowed are
        ``'majority_vote'`` (must be a ``bool``) and ``'min_segment_dur'``
        (must be a ``float`` or ``None``).

    Raises
    ------
    TypeError
        If ``value`` is not a ``dict``, if ``'majority_vote'`` is not a
        ``bool``, or if ``'min_segment_dur'`` is neither ``None`` nor a ``float``.
    ValueError
        If ``value`` contains any key other than the two valid names.
    """
    valid_names = {'majority_vote', 'min_segment_dur'}

    if not isinstance(value, dict):
        raise TypeError(
            "'post_tfm_kwargs' should be declared in toml config as an inline table "
            f"that parses as a dict, but type was: {type(value)}. "
            # TOML booleans are lowercase, so the example uses ``true``
            "Please declare in a similar fashion: `{majority_vote = true, min_segment_dur = 0.02}`"
        )

    # Collect the offending names once; an empty list means all keys are valid.
    invalid_kwargs = [k for k in value if k not in valid_names]
    if invalid_kwargs:
        raise ValueError(
            f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}. "
            "Valid names are: {'majority_vote', 'min_segment_dur'}"
        )

    if 'majority_vote' in value and not isinstance(value['majority_vote'], bool):
        raise TypeError(
            "'post_tfm_kwargs' keyword argument 'majority_vote' "
            f"should be of type bool but was: {type(value['majority_vote'])}"
        )

    # Compare against None explicitly: a truthiness test would let falsy
    # non-float values such as ``0`` or ``False`` slip through unvalidated.
    min_segment_dur = value.get('min_segment_dur')
    if min_segment_dur is not None and not isinstance(min_segment_dur, float):
        raise TypeError(
            "'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
            f"should be float but was: {type(min_segment_dur)}"
        )


@attr.s
class EvalConfig:
"""class that represents [EVAL] section of config.toml file
Expand Down Expand Up @@ -36,6 +86,19 @@ class EvalConfig:
path to a saved SpectScaler object used to normalize spectrograms.
If spectrograms were normalized and this is not provided, will give
incorrect results.
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
when transforming labeled timebins to segments,
the default behavior.
The transform used is
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing``.
Valid keyword argument names are 'majority_vote'
and 'min_segment_dur', and should be appropriate
values for those arguments: Boolean for ``majority_vote``,
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
"""
# required, external files
checkpoint_path = attr.ib(converter=expanded_user_path)
Expand All @@ -62,6 +125,12 @@ class EvalConfig:
default=None,
)

post_tfm_kwargs = attr.ib(
validator=validators.optional(are_valid_post_tfm_kwargs),
converter=converters.optional(convert_post_tfm_kwargs),
default={}, # empty dict so we can pass into transform with **kwargs expansion
)

# optional, data loader
num_workers = attr.ib(validator=instance_of(int), default=2)
device = attr.ib(validator=instance_of(str), default=device.get_default())
22 changes: 21 additions & 1 deletion src/vak/config/learncurve.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""parses [LEARNCURVE] section of config"""
import attr
from attr import converters
from attr import converters, validators
from attr.validators import instance_of

from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
from .train import TrainConfig
from ..converters import expanded_user_path

Expand Down Expand Up @@ -49,10 +50,29 @@ class LearncurveConfig(TrainConfig):
previous_run_path : str
path to results directory from a previous run.
Used for training if use_train_subsets_from_previous_run is True.
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
when transforming labeled timebins to segments,
the default behavior.
The transform used is
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing``.
Valid keyword argument names are 'majority_vote'
and 'min_segment_dur', and should be appropriate
values for those arguments: Boolean for ``majority_vote``,
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
"""
train_set_durs = attr.ib(validator=instance_of(list), kw_only=True)
num_replicates = attr.ib(validator=instance_of(int), kw_only=True)
previous_run_path = attr.ib(
converter=converters.optional(expanded_user_path),
default=None,
)

post_tfm_kwargs = attr.ib(
validator=validators.optional(are_valid_post_tfm_kwargs),
converter=converters.optional(convert_post_tfm_kwargs),
default={}, # empty dict so we can pass into transform with **kwargs expansion
)
3 changes: 2 additions & 1 deletion src/vak/config/valid.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ batch_size = 11
num_workers = 4
device = 'cuda'
spect_scaler_path = '/home/user/results_181014_194418/spect_scaler'

post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}

[LEARNCURVE]
models = 'TweetyNet'
Expand All @@ -79,6 +79,7 @@ num_replicates = 2
csv_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv'
results_dir_made_by_main_script = '/some/path/to/learncurve/'
previous_run_path = '/some/path/to/learncurve/results_20210106_132152'
post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}
num_workers = 4
device = 'cuda'

Expand Down
51 changes: 50 additions & 1 deletion src/vak/core/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
import torch.utils.data

from .. import (
files,
models,
timebins,
transforms,
validators
)
from ..datasets.vocal_dataset import VocalDataset
from ..labels import multi_char_labels_to_single_char


logger = logging.getLogger(__name__)
Expand All @@ -28,6 +31,7 @@ def eval(
num_workers,
split="test",
spect_scaler_path=None,
post_tfm_kwargs=None,
spect_key="s",
timebins_key="t",
device=None,
Expand Down Expand Up @@ -64,13 +68,34 @@ def eval(
If spectrograms were normalized and this is not provided, will give
incorrect results.
Default is None.
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
when transforming labeled timebins to segments,
the default behavior. The transform used is
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing``.
Valid keyword argument names are 'majority_vote'
and 'min_segment_dur', and should be appropriate
values for those arguments: Boolean for ``majority_vote``,
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
spect_key : str
key for accessing spectrogram in files. Default is 's'.
timebins_key : str
key for accessing vector of time bins in files. Default is 't'.
device : str
Device on which to work with model + data.
Defaults to 'cuda' if torch.cuda.is_available is True.
Notes
-----
Note that unlike ``core.predict``, this function
can modify ``labelmap`` so that metrics like edit distance
are correctly computed, by converting any string labels
in ``labelmap`` with multiple characters
to (mock) single-character labels,
with ``vak.labels.multi_char_labels_to_single_char``.
"""
# ---- pre-conditions ----------------------------------------------------------------------------------------------
for path, path_name in zip(
Expand Down Expand Up @@ -102,6 +127,15 @@ def eval(
with labelmap_path.open("r") as f:
labelmap = json.load(f)

# replace any multiple character labels in mapping
# with dummy single-character labels
# so that we do not affect edit distance computation
# see https://github.com/NickleDave/vak/issues/373
labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != 'unlabeled']
if any([len(label) > 1 for label in labelmap_keys]): # only re-map if necessary
# (to minimize chance of knock-on bugs)
labelmap = multi_char_labels_to_single_char(labelmap)

item_transform = transforms.get_defaults(
"eval",
spect_standardizer,
Expand Down Expand Up @@ -132,8 +166,23 @@ def eval(
if len(input_shape) == 4:
input_shape = input_shape[1:]

if post_tfm_kwargs:
dataset_df = pd.read_csv(csv_path)
# we use the timebins vector from the first spect path to get timebin dur.
# this is less careful than calling io.dataframe.validate_and_get_timebin_dur
# but it's also much faster, and we can assume dataframe was validated when it was made
spect_dict = files.spect.load(dataset_df['spect_path'].values[0])
timebin_dur = timebins.timebin_dur_from_vec(spect_dict[timebins_key])

post_tfm = transforms.labeled_timebins.PostProcess(
timebin_dur=timebin_dur,
**post_tfm_kwargs,
)
else:
post_tfm = None

models_map = models.from_model_config_map(
model_config_map, num_classes=len(labelmap), input_shape=input_shape
model_config_map, num_classes=len(labelmap), input_shape=input_shape, post_tfm=post_tfm
)

for model_name, model in models_map.items():
Expand Down
15 changes: 15 additions & 0 deletions src/vak/core/learncurve/learncurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
logger = logging.getLogger(__name__)


# TODO: add post_tfm_kwargs here
def learning_curve(
model_config_map,
train_set_durs,
Expand All @@ -32,6 +33,7 @@ def learning_curve(
root_results_dir=None,
results_path=None,
previous_run_path=None,
post_tfm_kwargs=None,
spect_key="s",
timebins_key="t",
normalize_spectrograms=True,
Expand Down Expand Up @@ -86,6 +88,18 @@ def learning_curve(
Typically directory will have a name like ``results_{timestamp}``
and the actual .csv splits will be in sub-directories with names
corresponding to the training set duration
post_tfm_kwargs : dict
Keyword arguments to post-processing transform.
If None, then no additional clean-up is applied
when transforming labeled timebins to segments,
the default behavior. The transform used is
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing``.
Valid keyword argument names are 'majority_vote'
and 'min_segment_dur', and should be appropriate
values for those arguments: Boolean for ``majority_vote``,
a float value for ``min_segment_dur``.
See the docstring of the transform for more details on
these arguments and how they work.
spect_key : str
key for accessing spectrogram in files. Default is 's'.
timebins_key : str
Expand Down Expand Up @@ -318,6 +332,7 @@ def learning_curve(
num_workers=num_workers,
split="test",
spect_scaler_path=spect_scaler_path,
post_tfm_kwargs=post_tfm_kwargs,
spect_key=spect_key,
timebins_key=timebins_key,
device=device,
Expand Down
Loading

0 comments on commit 738414c

Please sign in to comment.