Skip to content

Commit

Permalink
black reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent d6d95ab commit 9987e51
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 36 deletions.
3 changes: 1 addition & 2 deletions dacapo/experiments/architectures/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def module(self) -> torch.nn.Module:
"""
pass


@property
@abstractmethod
def input_shape(self) -> Coordinate:
Expand Down Expand Up @@ -78,4 +77,4 @@ def scale(self, input_voxel_size: Coordinate) -> Coordinate:
"""
Method to scale the input voxel size as required by the architecture.
"""
return input_voxel_size
return input_voxel_size
11 changes: 5 additions & 6 deletions dacapo/experiments/architectures/cnnectome_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,15 @@ class CNNectomeUNetConfig(ArchitectureConfig):
default=False,
metadata={"help_text": "Whether to use batch normalization."},
)

@property
def input_shape(self) -> Coordinate:
return self._input_shape

@property
def input_shape(self) -> Coordinate:
return Coordinate(self._input_shape)

@input_shape.setter
def input_shape(self, value: Coordinate):
self._input_shape = value
Expand All @@ -147,15 +150,11 @@ def module(self) -> torch.nn.Module:
if self.kernel_size_down is not None:
kernel_size_down = self.kernel_size_down
else:
kernel_size_down = [
[(3,) * self.dims, (3,) * self.dims]
] * levels
kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
if self.kernel_size_up is not None:
kernel_size_up = self.kernel_size_up
else:
kernel_size_up = [
[(3,) * self.dims, (3,) * self.dims]
] * (levels - 1)
kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)

# downsample factors has to be a list of tuples
downsample_factors = [tuple(x) for x in self.downsample_factors]
Expand Down
1 change: 1 addition & 0 deletions dacapo/experiments/architectures/cnnectome_unet_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math


class CNNectomeUNetModule(torch.nn.Module):
"""
A U-Net module for 3D or 4D data. The U-Net expects 3D or 4D tensors shaped
Expand Down
24 changes: 12 additions & 12 deletions dacapo/experiments/architectures/model_zoo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ def model_description(self) -> ModelDescr:

@property
def input_desc(self) -> InputTensorDescr:
assert len(self.model_description.inputs) == 1, (
f"Only models with one input are supported, found {self.model_description.inputs}"
)
assert (
len(self.model_description.inputs) == 1
), f"Only models with one input are supported, found {self.model_description.inputs}"
return self.model_description.inputs[0]

@property
def output_desc(self) -> OutputTensorDescr:
assert len(self.model_description.outputs) == 1, (
f"Only models with one output are supported, found {self.model_description.outputs}"
)
assert (
len(self.model_description.outputs) == 1
), f"Only models with one output are supported, found {self.model_description.outputs}"
return self.model_description.outputs[0]

@property
Expand All @@ -115,19 +115,19 @@ def input_shape(self):
@property
def num_in_channels(self) -> int:
channel_axes = [axis for axis in self.input_desc.axes if axis.type == "channel"]
assert len(channel_axes) == 1, (
f"Only models with one input channel axis are supported, found {channel_axes}"
)
assert (
len(channel_axes) == 1
), f"Only models with one input channel axis are supported, found {channel_axes}"
return channel_axes[0].size

@property
def num_out_channels(self) -> int:
channel_axes = [
axis for axis in self.output_desc.axes if axis.type == "channel"
]
assert len(channel_axes) == 1, (
f"Only models with one output channel axis are supported, found {channel_axes}"
)
assert (
len(channel_axes) == 1
), f"Only models with one output channel axis are supported, found {channel_axes}"
return channel_axes[0].size

def scale(self, input_voxel_size: Coordinate) -> Coordinate:
Expand Down
6 changes: 4 additions & 2 deletions dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def __init__(self, task_config):
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
self.evaluator = InstanceEvaluator()

self._channels = [f"aff_{'.'.join(map(str, n))}" for n in task_config.neighborhood]
self._channels = [
f"aff_{'.'.join(map(str, n))}" for n in task_config.neighborhood
]

@property
def channels(self) -> list[str]:
return self._channels
return self._channels
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ def __init__(self, task_config):

@property
def channels(self) -> list[str]:
return self._channels
return self._channels
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/dummy_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def __init__(self, task_config):

@property
def channels(self) -> list[str]:
return self._channels
return self._channels
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/hot_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def __init__(self, task_config):

@property
def channels(self) -> list[str]:
return self._channels
return self._channels
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/inner_distance_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def __init__(self, task_config):

@property
def channels(self) -> list[str]:
return self._channels
return self._channels
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/one_hot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def __init__(self, task_config):

@property
def channels(self) -> list[str]:
return self._classes
return self._classes
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def evaluation_scores(self) -> EvaluationScores:

def create_model(self, architecture):
return self.predictor.create_model(architecture=architecture)

@property
@abstractmethod
def channels(self) -> list[str]:
Expand Down
7 changes: 5 additions & 2 deletions dacapo/store/local_weights_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ def save_trace(self, run: RunConfig):
if not trace_file.parent.exists():
trace_file.parent.mkdir(parents=True, exist_ok=True)
if not trace_file.exists():
in_shape = (1, run.architecture.num_in_channels, *run.architecture.input_shape)
in_shape = (
1,
run.architecture.num_in_channels,
*run.architecture.input_shape,
)
in_data = torch.randn(in_shape)
try:
torch.jit.save(
Expand All @@ -82,7 +86,6 @@ def save_trace(self, run: RunConfig):
except SystemError as e:
print(f"Error saving trace: {e}, this model will not be traced")
trace_file.touch()


def latest_iteration(self, run: str) -> Optional[int]:
"""
Expand Down
8 changes: 5 additions & 3 deletions tests/components/test_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,15 @@ def build_test_architecture_config(
tmp_path / "dacapo_modelzoo_test.zip", authors=[Author(name="Test")]
)
return ModelZooConfig(
model_id = tmp_path / "dacapo_modelzoo_test.zip", name="test_model_zoo"
model_id=tmp_path / "dacapo_modelzoo_test.zip", name="test_model_zoo"
)


# TODO: Move unet parameters that don't affect interaction with other modules
# to a separate architcture test
@pytest.mark.filterwarnings("ignore::FutureWarning") # pytest treats this as an error but we don't care for now
@pytest.mark.filterwarnings(
"ignore::FutureWarning"
) # pytest treats this as an error but we don't care for now
@pytest.mark.parametrize("data_dims", [2, 3])
@pytest.mark.parametrize("channels", [True, False])
@pytest.mark.parametrize("architecture_dims", [2, 3])
Expand Down Expand Up @@ -155,7 +157,7 @@ def test_architectures(
use_attention,
padding,
source,
tmp_path
tmp_path,
)

in_data = torch.rand(
Expand Down
4 changes: 1 addition & 3 deletions tests/operations/test_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def test_stored_architecture(

architecture = retrieved_arch_config

assert (
architecture.dims is not None
), f"Architecture dims are None {architecture}"
assert architecture.dims is not None, f"Architecture dims are None {architecture}"


@pytest.mark.parametrize(
Expand Down

0 comments on commit 9987e51

Please sign in to comment.