Skip to content

Commit

Permalink
Add tests/test_models/test_windowed_frame_classification_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NickleDave committed Jan 21, 2023
1 parent bbac184 commit 2994eb3
Showing 1 changed file with 147 additions and 0 deletions.
147 changes: 147 additions & 0 deletions tests/test_models/test_windowed_frame_classification_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import inspect
import itertools

import pytest
import torch

import vak.models

from .test_definition import (
TeenyTweetyNetDefinition,
TweetyNetDefinition,
InvalidMetricsDictKeyModelDefinition,
)
from .test_tweetynet import LABELMAPS, INPUT_SHAPES


# pytest.mark.parametrize vals for test_init_with_definition
MODEL_DEFS = (
TweetyNetDefinition,
TeenyTweetyNetDefinition,
)

TEST_INIT_ARGVALS = itertools.product(LABELMAPS, INPUT_SHAPES, MODEL_DEFS)


class TestWindowedFrameClassificationModel:

@pytest.mark.parametrize(
'labelmap, input_shape, definition',
TEST_INIT_ARGVALS
)
def test_init(self,
labelmap,
input_shape,
definition,
monkeypatch):
"""Test Model.__init__ works as expected"""
# monkeypatch a definition so we can test __init__
definition = vak.models.definition.validate(definition)
monkeypatch.setattr(
vak.models.WindowedFrameClassificationModel,
'definition',
definition,
raising=False
)
# network has required args that need to be determined dynamically
network = definition.network(num_classes=len(labelmap), input_shape=input_shape)
model = vak.models.WindowedFrameClassificationModel(labelmap=labelmap, network=network)

# now test that attributes are what we expect
assert isinstance(model, vak.models.WindowedFrameClassificationModel)
for attr in ('network', 'loss', 'optimizer', 'metrics'):
assert hasattr(model, attr)
model_attr = getattr(model, attr)
definition_attr = getattr(definition, attr)
if inspect.isclass(definition_attr):
assert isinstance(model_attr, definition_attr)
elif isinstance(definition_attr, dict):
assert isinstance(model_attr, dict)
for definition_key, definition_val in definition_attr.items():
assert definition_key in model_attr
model_val = model_attr[definition_key]
if inspect.isclass(definition_val):
assert isinstance(model_val, definition_val)
else:
assert callable(definition_val)
assert model_val is definition_val
else:
# must be a function
assert callable(model_attr)
assert model_attr is definition_attr

@pytest.mark.parametrize(
'definition',
[
TeenyTweetyNetDefinition,
TweetyNetDefinition,
]
)
def test_from_config(self,
definition,
# our fixtures
specific_config,
# pytest fixtures
monkeypatch,
):
definition = vak.models.definition.validate(definition)
model_name = definition.__name__.replace('Definition', '').lower()
toml_path = specific_config('train', model_name, audio_format='cbin', annot_format='notmat')
cfg = vak.config.parse.from_toml_path(toml_path)

# stuff we need just to be able to instantiate network
labelmap = vak.labels.to_map(cfg.prep.labelset, map_unlabeled=True)
transform, target_transform = vak.transforms.get_defaults("train")
train_dataset = vak.datasets.WindowDataset.from_csv(
csv_path=cfg.train.csv_path,
split="train",
labelmap=labelmap,
window_size=cfg.dataloader.window_size,
spect_key='s',
timebins_key='t',
transform=transform,
target_transform=target_transform,
)
input_shape = train_dataset.shape

monkeypatch.setattr(
vak.models.WindowedFrameClassificationModel, 'definition', definition, raising=False
)

model_config_map = vak.config.models.map_from_path(toml_path, cfg.train.models)
model_name, config = list(model_config_map.items())[0]
config["network"].update(
num_classes=len(labelmap),
input_shape=input_shape,
)

model = vak.models.WindowedFrameClassificationModel.from_config(config=config, labelmap=labelmap)
assert isinstance(model, vak.models.WindowedFrameClassificationModel)

if 'network' in config:
if inspect.isclass(definition.network):
for network_kwarg, network_kwargval in config['network'].items():
assert hasattr(model.network, network_kwarg)
assert getattr(model.network, network_kwarg) == network_kwargval
elif isinstance(definition.network, dict):
for net_name, net_kwargs in config['network'].items():
for network_kwarg, network_kwargval in net_kwargs.items():
assert hasattr(model.network[net_name], network_kwarg)
assert getattr(model.network[net_name], network_kwarg) == network_kwargval

if 'loss' in config:
for loss_kwarg, loss_kwargval in config['loss'].items():
assert hasattr(model.loss, loss_kwarg)
assert getattr(model.loss, loss_kwarg) == loss_kwargval

if 'optimizer' in config:
for optimizer_kwarg, optimizer_kwargval in config['optimizer'].items():
assert optimizer_kwarg in model.optimizer.param_groups[0]
assert model.optimizer.param_groups[0][optimizer_kwarg] == optimizer_kwargval

if 'metrics' in config:
for metric_name, metric_kwargs in config['metrics'].items():
assert metric_name in model.metrics
for metric_kwarg, metric_kwargval in metric_kwargs.items():
assert hasattr(model.metrics[metric_name], metric_kwarg)
assert getattr(model.metrics[metric_name], metric_kwarg) == metric_kwargval

0 comments on commit 2994eb3

Please sign in to comment.