Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLN: Refactor model abstraction, fix #737 #726 #753

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
0ead8e9
Change type annotations on models.definition.validate to just be Type…
NickleDave May 5, 2024
4d324fb
Rewrite models.base.Model to not subclass LightningModule, and instea…
NickleDave May 5, 2024
a2d63f7
Change model decorator to not subclass family, and to instead subclas…
NickleDave May 5, 2024
752daa7
Rewrite FrameClassificationModel to subclass LightningModule directly…
NickleDave May 5, 2024
7041c7f
Rewrite ParemetricUMAPModel to subclass LightningModule directly, rem…
NickleDave May 5, 2024
ad81fb2
Rewrite base.Model class, move/rename methods so that we will get a s…
NickleDave May 6, 2024
bef395a
Add load_state_dict_from_path method to FrameClassificationModel and …
NickleDave May 6, 2024
30522e0
Change model_family decorator to check if family_class is a subtype o…
NickleDave May 6, 2024
3e9a6ed
Rewrite models.base.Model.__init__ to take definition and family attr…
NickleDave May 6, 2024
63d36b6
Fix how models.decorator.model makes Model instance -- we don't subcl…
NickleDave May 6, 2024
40e27f0
Fix FrameClassificationModel to subclass lightning.LightningModule (n…
NickleDave May 6, 2024
465de7c
Fix ParametricUMAPModel to subclass lightning.LightningModule (not li…
NickleDave May 6, 2024
84a9de6
Fix how we add a Model instance to MODEL_REGISTRY
NickleDave May 6, 2024
6d1062d
Fix src/vak/models/frame_classification_model.py to set network/loss/…
NickleDave May 6, 2024
e1250ba
Fix src/vak/models/parametric_umap_model.py to set network/loss/optim…
NickleDave May 6, 2024
ce79966
Fix how we get MODEL_FAMILY_FROM_NAME dict in models.registry.__getat…
NickleDave May 6, 2024
7314fa3
Fix classes in tests/test_models/conftest.py so we can use them to ru…
NickleDave May 6, 2024
3509785
Fix tests in tests/test_models/test_base.py
NickleDave May 6, 2024
74e7c37
Add method from_instances to vak.models.base.Model
NickleDave May 6, 2024
2e10679
Rename vak.models.base.Model -> vak.models.factory.ModelFactory
NickleDave May 6, 2024
8899da5
Add tests in tests/test_models/test_factory.py from test_frame_classi…
NickleDave May 6, 2024
f85de68
Fix unit test in tests/test_models/test_convencoder_umap.py
NickleDave May 6, 2024
553682a
Fix unit tests in tests/test_models/test_decorator.py
NickleDave May 6, 2024
efb8c76
Fix unit tests in tests/test_models/test_tweetynet.py
NickleDave May 6, 2024
c380c2a
Fix adding load_from_state_dict method to ParametricUMAPModel
NickleDave May 6, 2024
c0643d9
Fix unit tests in tests/test_models/test_frame_classification_model.py
NickleDave May 6, 2024
a22fb8d
Rename method in tests/test_models/test_convencoder_umap.py
NickleDave May 6, 2024
dcd5249
Fix unit tests in tests/test_models/test_ed_tcn.py
NickleDave May 6, 2024
877d8f9
Add a unit test from another test_models module to test_factory.py
NickleDave May 6, 2024
6dbb886
Add a unit test from another test_models module to test_factory.py
NickleDave May 6, 2024
cf8e863
Fix unit tests in tests/test_models/test_registry.py
NickleDave May 6, 2024
87c7e02
Remove unused fixture 'monkeypath' in tests/test_models/test_frame_cl…
NickleDave May 6, 2024
602add4
Fix unit tests in tests/test_models/test_parametric_umap_model.py
NickleDave May 6, 2024
a2a60b9
BUG: Fix how we check if we need to add empty dicts to model config …
NickleDave May 6, 2024
ce766d9
Rename model_class -> model_factory in src/vak/models/get.py
NickleDave May 6, 2024
3668f42
Clean up docstring in src/vak/models/factory.py
NickleDave May 6, 2024
2367280
Fix how we parametrize two unit tests in tests/test_config/test_model.py
NickleDave May 6, 2024
c0f0847
Fix ConvEncoderUMAP configs in tests/data_for_tests/configs to have n…
NickleDave May 6, 2024
9ca7c7e
Rewrite docstring, fix type annotations, rename vars for clarity in s…
NickleDave May 6, 2024
9435a61
Revise docstring in src/vak/models/definition.py
NickleDave May 6, 2024
18eac27
Revise type hinting + docstring in src/vak/models/get.py
NickleDave May 6, 2024
661bb95
Revise docstring + comment in src/vak/models/registry.py
NickleDave May 6, 2024
e500f98
Fix unit test in tests/test_models/test_factory.py
NickleDave May 6, 2024
140d98a
Fix ParametricUMAPModel to use a ModuleDict
NickleDave May 6, 2024
57ce5b4
Fix unit test in tests/test_models/test_convencoder_umap.py
NickleDave May 6, 2024
5d38d32
Fix unit test in tests/test_models/test_factory.py
NickleDave May 6, 2024
4862f3d
Fix unit test in tests/test_models/test_parametric_umap_model.py
NickleDave May 6, 2024
6b93377
Fix common.tensorboard.events2df to avoid pandas error about re-index…
NickleDave May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/vak/common/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ def events2df(
).set_index("step")
if drop_wall_time:
dfs[scalar_tag].drop("wall_time", axis=1, inplace=True)
df = pd.concat([v for k, v in dfs.items()], axis=1)
df = pd.concat([v for k, v in dfs.items() if k != "epoch"], axis=1)
return df
4 changes: 3 additions & 1 deletion src/vak/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def from_config_dict(cls, config_dict: dict):
f"Model name not found in registry: {model_name}\n"
f"Model names in registry:\n{MODEL_NAMES}"
)

# NOTE: we are getting model_config here
model_config = config_dict[model_name]
if not all(key in MODEL_TABLES for key in model_config.keys()):
invalid_keys = (
Expand All @@ -89,7 +91,7 @@ def from_config_dict(cls, config_dict: dict):
)
# for any tables not specified, default to empty dict so we can still use ``**`` operator on it
for model_table in MODEL_TABLES:
if model_table not in config_dict:
if model_table not in model_config:
model_config[model_table] = {}
return cls(name=model_name, **model_config)

Expand Down
8 changes: 4 additions & 4 deletions src/vak/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import base, decorator, definition, registry
from .base import Model
from . import decorator, definition, factory, registry
from .factory import ModelFactory
from .convencoder_umap import ConvEncoderUMAP
from .decorator import model
from .ed_tcn import ED_TCN
Expand All @@ -10,14 +10,14 @@
from .tweetynet import TweetyNet

__all__ = [
"base",
"factory",
"ConvEncoderUMAP",
"decorator",
"definition",
"ED_TCN",
"FrameClassificationModel",
"get",
"Model",
"ModelFactory",
"model",
"model_family",
"ParametricUMAPModel",
Expand Down
95 changes: 44 additions & 51 deletions src/vak/models/decorator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""Decorator that makes a model class,
"""Decorator that makes a :class:`vak.models.ModelFactory`,
given a definition of the model,
and another class that represents a
and a :class:`lightning.LightningModule` that represents a
family of models that the new model belongs to.

The function returns a newly-created subclass
of the class representing the family of models.
The subclass can then be instantiated
and have all model methods.
The function returns a new instance of :class:`vak.models.ModelFactory`,
that can create new instances of the model with its
:meth:`~:class:`vak.models.ModelFactory.from_config` and
:meth:`~:class:`vak.models.ModelFactory.from_instances` methods.
"""

from __future__ import annotations

from typing import Type
from typing import Type, TYPE_CHECKING

import lightning

from .base import Model
from .definition import validate as validate_definition
from .registry import register_model

if TYPE_CHECKING:
from .factory import ModelFactory

class ModelDefinitionValidationError(Exception):
"""Exception raised when validating a model
Expand All @@ -28,16 +31,16 @@ class ModelDefinitionValidationError(Exception):
pass


def model(family: Type[Model]):
"""Decorator that makes a model class,
def model(family: lightning.pytorch.LightningModule):
"""Decorator that makes a :class:`vak.models.ModelFactory`,
given a definition of the model,
and another class that represents a
and a :class:`lightning.LightningModule` that represents a
family of models that the new model belongs to.

Returns a newly-created subclass
of the class representing the family of models.
The subclass can then be instantiated
and have all model methods.
The function returns a new instance of :class:`vak.models.ModelFactory`,
that can create new instances of the model with its
:meth:`~:class:`vak.models.ModelFactory.from_config` and
:meth:`~:class:`vak.models.ModelFactory.from_instances` methods.

Parameters
----------
Expand All @@ -46,50 +49,40 @@ def model(family: Type[Model]):
A class with all the class variables required
by :func:`vak.models.definition.validate`.
See docstring of that function for specification.
family : subclass of vak.models.Model
See also :class:`vak.models.definition.ModelDefinition`,
but note that it is not necessary to subclass
:class:`~vak.models.definition.ModelDefinition` to
define a model.
family : lightning.LightningModule
The class representing the family of models
that the new model will belong to.
E.g., :class:`vak.models.FrameClassificationModel`.
Should be a subclass of :class:`lightning.LightningModule`
that was registered with the
:func:`vak.models.registry.model_family` decorator.

Returns
-------
model : type
A sub-class of ``model_family``,
with attribute ``definition``,
model_factory : vak.models.ModelFactory
An instance of :class:`~vak.models.ModelFactory`,
with attribute ``definition`` and ``family``,
that will be used when making
new instances of the model.
new instances of the model by calling the
:meth:`~vak.models.ModelFactory.from_config` method
or the :meth:`~:class:`vak.models.ModelFactory.from_instances` method.
"""

def _model(definition: Type):
if not issubclass(family, Model):
raise TypeError(
"The ``family`` argument to the ``vak.models.model`` decorator"
"should be a subclass of ``vak.models.base.Model``,"
f"but the type was: {type(family)}, "
"which was not recognized as a subclass "
"of ``vak.models.base.Model``."
)

try:
validate_definition(definition)
except ValueError as err:
raise ModelDefinitionValidationError(
f"Validation failed for the following model definition:\n{definition}"
) from err
except TypeError as err:
raise ModelDefinitionValidationError(
f"Validation failed for the following model definition:\n{definition}"
) from err

attributes = dict(family.__dict__)
attributes.update({"definition": definition})
subclass_name = definition.__name__
subclass = type(subclass_name, (family,), attributes)
subclass.__module__ = definition.__module__

# finally, add model to registry
register_model(subclass)

return subclass
def _model(definition: Type) -> ModelFactory:
from .factory import ModelFactory # avoid circular import

model_factory = ModelFactory(
definition,
family
)
model_factory.__name__ = definition.__name__
model_factory.__doc__ = definition.__doc__
model_factory.__module__ = definition.__module__
register_model(model_factory)
return model_factory

return _model
17 changes: 10 additions & 7 deletions src/vak/models/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
class ModelDefinition:
"""A class that represents the definition of a neural network model.

Note it is **not** necessary to sub-class this class;
it exists mainly for type-checking purposes.

A model definition is a class that has the following class variables:
A model definition is any class that has the following class variables:

network: torch.nn.Module or dict
Neural network.
Expand All @@ -48,6 +45,12 @@ class ModelDefinition:
Used by ``vak.models.base.Model`` and its
sub-classes that represent model families. E.g., those classes will do:
``network = self.definition.network(**self.definition.default_config['network'])``.

Note it is **not** necessary to sub-class this class;
it exists mainly for type-checking purposes.

For more detail, see :func:`vak.models.decorator.model`
and :class:`vak.models.ModelFactory`.
"""

network: Union[torch.nn.Module, dict]
Expand All @@ -67,7 +70,7 @@ class ModelDefinition:
}


def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]:
def validate(definition: Type) -> Type:
"""Validate a model definition.

A model definition is a class that has the following class variables:
Expand Down Expand Up @@ -124,8 +127,8 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]:
converting it into a sub-class ofhttps://peps.python.org/pep-0416/
``vak.models.Model``.

It's also used by ``vak.models.Model``
to validate a definition when initializing
It's also used by :class:`vak.models.ModelFactory`,
to validate a definition before building
a new model instance from the definition.
"""
# need to set this default first
Expand Down
Loading
Loading