Skip to content

Commit

Permalink
expand architecture tests to cover modelzoo format
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent 1f74601 commit 98272b3
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions tests/components/test_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CNNectomeUNetConfig,
WrappedArchitectureConfig,
ArchitectureConfig,
ModelZooConfig,
)

from pathlib import Path
Expand All @@ -22,6 +23,8 @@
from dacapo.train import train_run
from dacapo.validate import validate_run

from bioimageio.spec.model.v0_5 import Author

import pytest
from pytest_lazy_fixtures import lf

Expand All @@ -38,7 +41,8 @@ def build_test_architecture_config(
upsample: bool,
use_attention: bool,
padding: str,
wrapped: bool,
source: str,
tmp_path: Path,
) -> ArchitectureConfig:
"""
Build the simplest architecture config given the parameters.
Expand Down Expand Up @@ -96,7 +100,9 @@ def build_test_architecture_config(
padding=padding,
)

if wrapped:
if source == "config":
return cnnectom_unet_config
elif source == "module":
return WrappedArchitectureConfig(
name="test_wrapped",
module=cnnectom_unet_config.module(),
Expand All @@ -105,20 +111,30 @@ def build_test_architecture_config(
input_shape=input_shape,
scale=Coordinate(upsample_factors[0]) if upsample else None,
)
else:
return cnnectom_unet_config
elif source == "bioimage_modelzoo":
run = RunConfig(
architecture_config=cnnectom_unet_config,
name="dacapo_modelzoo_test",
)
run.save_bioimage_io_model(
tmp_path / "dacapo_modelzoo_test.zip", authors=[Author(name="Test")]
)
return ModelZooConfig(
model_id = tmp_path / "dacapo_modelzoo_test.zip", name="test_model_zoo"
)


# TODO: Move unet parameters that don't affect interaction with other modules
# to a separate architcture test
@pytest.mark.filterwarnings("ignore::FutureWarning") # pytest treats this as an error but we don't care for now
@pytest.mark.parametrize("data_dims", [2, 3])
@pytest.mark.parametrize("channels", [True, False])
@pytest.mark.parametrize("architecture_dims", [2, 3])
@pytest.mark.parametrize("upsample", [True, False])
@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("padding", ["valid", "same"])
@pytest.mark.parametrize("wrapped", [True, False])
@pytest.mark.parametrize("source", ["config", "module", "bioimage_modelzoo"])
def test_architectures(
data_dims,
channels,
Expand All @@ -127,7 +143,8 @@ def test_architectures(
upsample,
use_attention,
padding,
wrapped,
source,
tmp_path,
):
architecture_config = build_test_architecture_config(
data_dims,
Expand All @@ -137,12 +154,18 @@ def test_architectures(
upsample,
use_attention,
padding,
wrapped,
source,
tmp_path
)

in_data = torch.rand(
(*(1, architecture_config.num_in_channels), *architecture_config.input_shape)
)
out_data = architecture_config.module()(in_data)
scale = architecture_config.scale(Coordinate((2,) * data_dims))
if upsample:
assert scale == Coordinate((1,) * data_dims)
else:
assert scale == Coordinate((2,) * data_dims)

assert out_data.shape[1] == architecture_config.num_out_channels

0 comments on commit 98272b3

Please sign in to comment.