From 98272b3909bfe17f633b14db1046f0a91ad43b45 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 21 Jan 2025 09:54:45 -0800 Subject: [PATCH] expand architecture tests to cover modelzoo format --- tests/components/test_architectures.py | 37 +++++++++++++++++++++----- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/tests/components/test_architectures.py b/tests/components/test_architectures.py index 18503a32..45fa4cf3 100644 --- a/tests/components/test_architectures.py +++ b/tests/components/test_architectures.py @@ -14,6 +14,7 @@ CNNectomeUNetConfig, WrappedArchitectureConfig, ArchitectureConfig, + ModelZooConfig, ) from pathlib import Path @@ -22,6 +23,8 @@ from dacapo.train import train_run from dacapo.validate import validate_run +from bioimageio.spec.model.v0_5 import Author + import pytest from pytest_lazy_fixtures import lf @@ -38,7 +41,8 @@ def build_test_architecture_config( upsample: bool, use_attention: bool, padding: str, - wrapped: bool, + source: str, + tmp_path: Path, ) -> ArchitectureConfig: """ Build the simplest architecture config given the parameters. @@ -96,7 +100,9 @@ def build_test_architecture_config( padding=padding, ) - if wrapped: + if source == "config": + return cnnectom_unet_config + elif source == "module": return WrappedArchitectureConfig( name="test_wrapped", module=cnnectom_unet_config.module(), @@ -105,12 +111,22 @@ def build_test_architecture_config( input_shape=input_shape, scale=Coordinate(upsample_factors[0]) if upsample else None, ) - else: - return cnnectom_unet_config + elif source == "bioimage_modelzoo": + run = RunConfig( + architecture_config=cnnectom_unet_config, + name="dacapo_modelzoo_test", + ) + run.save_bioimage_io_model( + tmp_path / "dacapo_modelzoo_test.zip", authors=[Author(name="Test")] + ) + return ModelZooConfig( + model_id = tmp_path / "dacapo_modelzoo_test.zip", name="test_model_zoo" + ) # TODO: Move unet parameters that don't affect interaction with other modules # to a separate architcture test +@pytest.mark.filterwarnings("ignore::FutureWarning") # pytest treats this as an error but we don't care for now @pytest.mark.parametrize("data_dims", [2, 3]) @pytest.mark.parametrize("channels", [True, False]) @pytest.mark.parametrize("architecture_dims", [2, 3]) @@ -118,7 +134,7 @@ def build_test_architecture_config( @pytest.mark.parametrize("batch_norm", [True, False]) @pytest.mark.parametrize("use_attention", [True, False]) @pytest.mark.parametrize("padding", ["valid", "same"]) -@pytest.mark.parametrize("wrapped", [True, False]) +@pytest.mark.parametrize("source", ["config", "module", "bioimage_modelzoo"]) def test_architectures( data_dims, channels, @@ -127,7 +143,8 @@ def test_architectures( upsample, use_attention, padding, - wrapped, + source, + tmp_path, ): architecture_config = build_test_architecture_config( data_dims, @@ -137,12 +154,18 @@ def test_architectures( upsample, use_attention, padding, - wrapped, + source, + tmp_path ) in_data = torch.rand( (*(1, architecture_config.num_in_channels), *architecture_config.input_shape) ) out_data = architecture_config.module()(in_data) + scale = architecture_config.scale(Coordinate((2,) * data_dims)) + if upsample: + assert scale == Coordinate((1,) * data_dims) + else: + assert scale == Coordinate((2,) * data_dims) assert out_data.shape[1] == architecture_config.num_out_channels