Skip to content

Commit 01c71cf

Browse files
committed
custom head
1 parent e18d305 commit 01c71cf

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

tests/unet/test_fcmae.py

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from viscy.unet.networks.fcmae import (
44
FullyConvolutionalMAE,
5+
PixelToVoxelShuffleHead,
56
MaskedAdaptiveProjection,
67
MaskedConvNeXtV2Block,
78
MaskedConvNeXtV2Stage,
@@ -104,6 +105,13 @@ def test_masked_multiscale_encoder():
104105
assert afeat.shape[2] == afeat.shape[3] == xy_size // stride
105106

106107

108+
def test_pixel_to_voxel_shuffle_head():
109+
head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
110+
x = torch.rand(2, 240, 16, 16)
111+
y = head(x)
112+
assert y.shape == (2, 3, 5, 64, 64)
113+
114+
107115
def test_fcmae():
108116
x = torch.rand(2, 3, 5, 128, 128)
109117
model = FullyConvolutionalMAE(3, 3)

viscy/unet/networks/fcmae.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
and timm's dense implementation of the encoder in ``timm.models.convnext``
66
"""
77

8+
import math
89
from typing import Sequence
910

1011
import torch
@@ -416,7 +417,7 @@ def __init__(
416417
)
417418
decoder_channels = list(dims)
418419
decoder_channels.reverse()
419-
decoder_channels[-1] = (in_stack_depth + 2) * in_channels * 2**2
420+
decoder_channels[-1] = out_channels * in_stack_depth * stem_kernel_size[-1] ** 2
420421
self.decoder = Unet2dDecoder(
421422
decoder_channels,
422423
norm_name="instance",
@@ -433,7 +434,8 @@ def __init__(
433434
pool=True,
434435
)
435436
self.out_stack_depth = in_stack_depth
436-
self.num_blocks = 6
437+
# TODO: replace num_blocks with explicit strides for all models
438+
self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1]))
437439
self.pretraining = pretraining
438440

439441
def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor:

0 commit comments

Comments
 (0)