diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index becad754..98faf9b6 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -174,6 +174,7 @@ def module(self) -> torch.nn.Module: + [True] * (len(downsample_factors) - 1), use_attention=self.use_attention, batch_norm=self.batch_norm, + dims=self.dims, ) if self.upsample_factors is not None and len(self.upsample_factors) > 0: layers = [unet] @@ -190,7 +191,7 @@ def module(self) -> torch.nn.Module: conv = ConvPass( self.fmaps_out, self.fmaps_out, - kernel_size_up[-1], + kernel_size_down[-1], activation="ReLU", batch_norm=self.batch_norm, ) diff --git a/dacapo/experiments/architectures/cnnectome_unet_impl.py b/dacapo/experiments/architectures/cnnectome_unet_impl.py index 963fcf21..9ed57076 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_impl.py +++ b/dacapo/experiments/architectures/cnnectome_unet_impl.py @@ -84,6 +84,7 @@ def __init__( activation_on_upsample=False, use_attention=False, batch_norm=True, + dims: int | None = None, ): """ Create a U-Net:: @@ -200,7 +201,10 @@ def __init__( else upsample_channel_contraction ) - self.dims = len(downsample_factors[0]) + if dims is None: + self.dims = len(downsample_factors[0]) + else: + self.dims = dims self.use_attention = use_attention self.batch_norm = batch_norm diff --git a/tests/operations/helpers.py b/tests/operations/helpers.py index a3aa421f..9a5eefb3 100644 --- a/tests/operations/helpers.py +++ b/tests/operations/helpers.py @@ -121,37 +121,37 @@ def build_test_architecture_config( Build the simplest architecture config given the parameters. """ if data_dims == 2: - input_shape = (32, 32) - eval_shape_increase = (8, 8) - downsample_factors = [(2, 2)] + input_shape = (8, 8) + eval_shape_increase = (24, 24) + downsample_factors = [(2, 2)] * 0 upsample_factors = [(2, 2)] * int(upsample) - kernel_size_down = [[(3, 3)] * 2] * 2 - kernel_size_up = [[(3, 3)] * 2] * 1 + kernel_size_down = [[(3, 3)] * 2] * 1 + kernel_size_up = [[(3, 3)] * 2] * 0 kernel_size_down = None # the default should work kernel_size_up = None # the default should work elif data_dims == 3 and architecture_dims == 2: - input_shape = (1, 32, 32) - eval_shape_increase = (15, 8, 8) - downsample_factors = [(1, 2, 2)] + input_shape = (1, 8, 8) + eval_shape_increase = (15, 24, 24) + downsample_factors = [(1, 2, 2)] * 0 # test data upsamples in all dimensions so we have # to here too upsample_factors = [(2, 2, 2)] * int(upsample) # we have to force the 3D kernels to be 2D - kernel_size_down = [[(1, 3, 3)] * 2] * 2 - kernel_size_up = [[(1, 3, 3)] * 2] * 1 + kernel_size_down = [[(1, 3, 3)] * 2] * 1 + kernel_size_up = [[(1, 3, 3)] * 2] * 0 elif data_dims == 3 and architecture_dims == 3: - input_shape = (32, 32, 32) - eval_shape_increase = (8, 8, 8) - downsample_factors = [(2, 2, 2)] + input_shape = (8, 8, 8) + eval_shape_increase = (24, 24, 24) + downsample_factors = [(2, 2, 2)] * 0 upsample_factors = [(2, 2, 2)] * int(upsample) - kernel_size_down = [[(3, 3, 3)] * 2] * 2 - kernel_size_up = [[(3, 3, 3)] * 2] * 1 + kernel_size_down = [[(3, 3, 3)] * 2] * 1 + kernel_size_up = [[(3, 3, 3)] * 2] * 0 kernel_size_down = None # the default should work kernel_size_up = None # the default should work diff --git a/tests/operations/test_mini.py b/tests/operations/test_mini.py index 57afc134..ea60fe57 100644 --- a/tests/operations/test_mini.py +++ b/tests/operations/test_mini.py @@ -30,6 +30,7 @@ @pytest.mark.parametrize("padding", ["valid", "same"]) @pytest.mark.parametrize("func", ["train", "validate"]) @pytest.mark.parametrize("multiprocessing", [True, False]) +@pytest.mark.skip("This test is too slow to run on CI") def test_mini( tmpdir, data_dims,