Train overhaul (#373)
Overhaul `Trainer`s.
Removed the `Trainer` class and all of its subclasses; only `TrainerConfig`
and its subclasses remain.

*Breaking API change*
Moved much of the functionality (parallelization, snapshot saving,
batching, etc.) from the `TrainerConfig` to the `RunConfig`.
This includes arguments such as `num_workers`, `snapshot_interval`, and
`batch_size`. To update existing configurations, move these arguments from
the `TrainerConfig` to the `RunConfig`, as sketched below.
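Schematically (the `TrainerConfig` subclass name, and any fields other than
the moved arguments, are placeholders):

```python
# Before: batching/parallelization/snapshot arguments lived on the trainer config.
trainer_config = MyTrainerConfig(   # placeholder for your TrainerConfig subclass
    name="default-trainer",
    batch_size=2,
    num_workers=8,
    snapshot_interval=1000,
    # ... other trainer-specific arguments ...
)
run_config = RunConfig(name="my-run", trainer_config=trainer_config)

# After: the same arguments are now passed to the RunConfig instead.
trainer_config = MyTrainerConfig(
    name="default-trainer",
    # ... other trainer-specific arguments ...
)
run_config = RunConfig(
    name="my-run",
    trainer_config=trainer_config,
    batch_size=2,
    num_workers=8,
    snapshot_interval=1000,
)
```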
mzouink authored Feb 26, 2025
2 parents 2e149b2 + 28777cd commit 380fbb3
Showing 61 changed files with 3,163 additions and 4,997 deletions.
14 changes: 7 additions & 7 deletions dacapo/apply.py
@@ -91,7 +91,7 @@ def apply(
), "Either validation_dataset and criterion, or iteration must be provided."

# retrieving run
print(f"Loading run {run_name}")
logger.info(f"Loading run {run_name}")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
@@ -102,7 +102,7 @@
# load weights
if iteration is None:
iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion) # type: ignore
print(f"Loading weights for iteration {iteration}")
logger.info(f"Loading weights for iteration {iteration}")
weights_store.retrieve_weights(run_name, iteration)

if parameters is None:
@@ -121,7 +121,7 @@ def apply(
raise ValueError(
"validation_dataset must be a dataset name or a Dataset object, or parameters must be provided explicitly."
)
print(f"Finding best parameters for validation dataset {_validation_dataset}")
logger.info(f"Finding best parameters for validation dataset {_validation_dataset}")
parameters = run.task.evaluator.get_overall_best_parameters(
_validation_dataset, criterion
)
@@ -183,7 +183,7 @@ def apply(
output_container, f"output_{run_name}_{iteration}_{parameters}"
)

print(
logger.info(
f"Applying best results from run {run.name} at iteration {iteration} to dataset {Path(input_container, input_dataset)}"
)
return apply_run(
@@ -243,7 +243,7 @@ def apply_run(
... )
"""
# render prediction dataset
print(f"Predicting on dataset {prediction_array_identifier}")
logger.info(f"Predicting on dataset {prediction_array_identifier}")
predict(
run.name,
iteration,
@@ -257,13 +257,13 @@
)

# post-process the output
print(
logger.info(
f"Post-processing output to dataset {output_array_identifier}",
output_array_identifier,
)
post_processor = run.task.post_processor
post_processor.set_prediction(prediction_array_identifier)
post_processor.process(parameters, output_array_identifier, num_workers=num_workers)

print("Done")
logger.info("Done")
return
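In the hunks above, `print` calls become `logger.info` calls. The `logger`
object itself is not shown in this diff; a module-level logger of the usual
form is assumed:

```python
import logging

# Assumed setup (not shown in the hunks above): the conventional module-level logger.
logger = logging.getLogger(__name__)
```

Unlike `print`, these messages only appear if the calling application
configures logging, for example with `logging.basicConfig(level=logging.INFO)`.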
8 changes: 4 additions & 4 deletions dacapo/cli.py
@@ -680,7 +680,7 @@ def segment_blockwise(
overwrite=overwrite,
write_size=write_roi.shape,
)
print(
logger.info(
f"Created output array {output_array_identifier.container}:{output_array_identifier.dataset} with ROI {_total_roi}."
)

@@ -791,7 +791,7 @@ def config():
def generate_dacapo_yaml(config):
with open("dacapo.yaml", "w") as f:
yaml.dump(config.serialize(), f, default_flow_style=False)
print("dacapo.yaml has been created.")
logger.info("dacapo.yaml has been created.")


def generate_config(
@@ -832,7 +832,7 @@ def unpack_ctx(ctx):
Example:
>>> ctx = ...
>>> kwargs = unpack_ctx(ctx)
>>> print(kwargs)
>>> logger.info(kwargs)
{'arg1': value1, 'arg2': value2, ...}
"""
kwargs = {
@@ -843,7 +843,7 @@ def unpack_ctx(ctx):
kwargs[k] = int(v)
elif v.replace(".", "").isnumeric():
kwargs[k] = float(v)
print(f"{k}: {kwargs[k]}")
logger.info(f"{k}: {kwargs[k]}")
return kwargs
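As a standalone illustration of the value coercion in `unpack_ctx` above (the
click context plumbing is omitted and the input values are made up):

```python
# Coerce string values to int or float where possible, mirroring the logic above.
raw = {"iterations": "1000", "learning_rate": "0.0001", "run_name": "my_run"}

kwargs = {}
for k, v in raw.items():
    if v.isnumeric():
        kwargs[k] = int(v)              # "1000"   -> 1000
    elif v.replace(".", "").isnumeric():
        kwargs[k] = float(v)            # "0.0001" -> 0.0001
    else:
        kwargs[k] = v                   # "my_run" stays a string

print(kwargs)  # {'iterations': 1000, 'learning_rate': 0.0001, 'run_name': 'my_run'}
```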


8 changes: 4 additions & 4 deletions dacapo/experiments/architectures/__init__.py
@@ -1,7 +1,7 @@
from .architecture import Architecture # noqa
from .architecture_config import ArchitectureConfig # noqa
from .architecture import ArchitectureConfig # noqa
from .dummy_architecture_config import (
DummyArchitectureConfig,
DummyArchitecture,
) # noqa
from .cnnectome_unet_config import CNNectomeUNetConfig, CNNectomeUNet # noqa
from .cnnectome_unet import CNNectomeUNetConfig # noqa
from .wrapped_architecture import WrappedArchitectureConfig # noqa
from .model_zoo_config import ModelZooConfig # noqa
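With the reorganized `__init__.py`, the public import path is unchanged even
though the defining modules moved; for example (assuming `dacapo` is installed):

```python
# ArchitectureConfig now lives in architecture.py and CNNectomeUNetConfig in
# cnnectome_unet.py, but both remain importable from the package root.
from dacapo.experiments.architectures import (
    ArchitectureConfig,
    CNNectomeUNetConfig,
    WrappedArchitectureConfig,
    ModelZooConfig,
)

print(ArchitectureConfig.__module__)  # dacapo.experiments.architectures.architecture
```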
149 changes: 61 additions & 88 deletions dacapo/experiments/architectures/architecture.py
@@ -1,81 +1,73 @@
from funlib.geometry import Coordinate
import attr

from funlib.geometry import Coordinate
import torch

from pathlib import Path
from abc import ABC, abstractmethod

from bioimageio.spec.model.v0_5 import (
Author,
CiteEntry,
)

class Architecture(torch.nn.Module, ABC):

@attr.s
class ArchitectureConfig(ABC):
"""
An abstract base class for defining the architecture of a neural network model.
It is inherited from PyTorch's Module and built-in class `ABC` (Abstract Base Classes).
Other classes can inherit this class to define their own specific variations of architecture.
It requires to implement several property methods, and also includes additional methods related to the architecture design.
A base class for an configurable architecture that can be used in DaCapo
Attributes:
input_shape (Coordinate): The spatial input shape for the neural network architecture.
eval_shape_increase (Coordinate): The amount to increase the input shape during prediction.
num_in_channels (int): The number of input channels required by the architecture.
num_out_channels (int): The number of output channels provided by the architecture.
name : str
a unique name for the architecture.
Methods:
dims: Returns the number of dimensions of the input shape.
scale: Scales the input voxel size as required by the architecture.
verify()
validates the given architecture.
Note:
The class is abstract and requires to implement the abstract methods.
"""

name: str = attr.ib(
metadata={
"help_text": "A unique name for this architecture. This will be saved so "
"you and others can find and reuse this task. Keep it short "
"and avoid special characters."
}
)

@abstractmethod
def module(self) -> torch.nn.Module:
"""
Returns the `torch.nn.Module` object for a given architecture such that it may be
trained or used for prediction.
"""
pass

@property
@abstractmethod
def input_shape(self) -> Coordinate:
"""
Abstract method to define the spatial input shape for the neural network architecture.
The shape should not account for the channels and batch dimensions.
Returns:
Coordinate: The spatial input shape.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> input_shape = Coordinate((128, 128, 128))
>>> model = MyModel(input_shape)
Note:
The method should be implemented in the derived class.
"""
pass

@property
def dims(self) -> int:
return self.input_shape.dims

@property
def eval_shape_increase(self) -> Coordinate:
"""
Provides information about how much to increase the input shape during prediction.
Returns:
Coordinate: An instance representing the amount to increase in each dimension of the input shape.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> eval_shape_increase = Coordinate((0, 0, 0))
>>> model = MyModel(input_shape, eval_shape_increase)
Note:
The method is optional and can be overridden in the derived class.
"""
return Coordinate((0,) * self.input_shape.dims)
return Coordinate((0,) * self.dims)

@property
@abstractmethod
def num_in_channels(self) -> int:
"""
Abstract method to return number of input channels required by the architecture.
Returns:
int: Required number of input channels.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> num_in_channels = 1
>>> model = MyModel(input_shape, num_in_channels)
Note:
The method should be implemented in the derived class.
"""
pass

@@ -84,55 +76,36 @@ def num_in_channels(self) -> int:
def num_out_channels(self) -> int:
"""
Abstract method to return the number of output channels provided by the architecture.
Returns:
int: Number of output channels.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> num_out_channels = 1
>>> model = MyModel(input_shape, num_out_channels)
Note:
The method should be implemented in the derived class.
"""
pass

@property
def dims(self) -> int:
"""
Returns the number of dimensions of the input shape.
Returns:
int: The number of dimensions.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> input_shape = Coordinate((128, 128, 128))
>>> model = MyModel(input_shape)
>>> model.dims
3
Note:
The method is optional and can be overridden in the derived class.
"""
return self.input_shape.dims

def scale(self, input_voxel_size: Coordinate) -> Coordinate:
"""
Method to scale the input voxel size as required by the architecture.
Args:
input_voxel_size (Coordinate): The original size of the input voxel.
Returns:
Coordinate: The scaled voxel size.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> input_voxel_size = Coordinate((1, 1, 1))
>>> model = MyModel(input_shape)
>>> model.scale(input_voxel_size)
Coordinate((1, 1, 1))
Note:
The method is optional and can be overridden in the derived class.
"""
return input_voxel_size

def save_bioimage_io_model(
self,
path: Path,
authors: list[Author],
cite: list[CiteEntry] | None = None,
license: str = "MIT",
input_test_image_path: Path | None = None,
output_test_image_path: Path | None = None,
checkpoint: int | str | None = None,
in_voxel_size: Coordinate | None = None,
):
from dacapo.experiments.run_config import RunConfig

run = RunConfig(name=f"{self.name}-bioimage-io", architecture_config=self)
run.save_bioimage_io_model(
path,
authors=authors,
cite=cite,
license=license,
input_test_image_path=input_test_image_path,
output_test_image_path=output_test_image_path,
checkpoint=checkpoint,
in_voxel_size=in_voxel_size,
)
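To make the new contract concrete, here is a minimal, hypothetical subclass of
the new `ArchitectureConfig` (the class name and the wrapped network are
invented for illustration; the overridden members follow the abstract
interface shown above):

```python
import attr
import torch
from funlib.geometry import Coordinate

from dacapo.experiments.architectures import ArchitectureConfig


@attr.s
class TinyConvNetConfig(ArchitectureConfig):
    """Hypothetical single-convolution 3D architecture, for illustration only."""

    channels: int = attr.ib(default=1)

    def module(self) -> torch.nn.Module:
        # The torch.nn.Module that will actually be trained or used for prediction.
        return torch.nn.Conv3d(self.channels, self.channels, kernel_size=3, padding=1)

    @property
    def input_shape(self) -> Coordinate:
        # Spatial shape only; batch and channel dimensions are not included.
        return Coordinate((64, 64, 64))

    @property
    def num_in_channels(self) -> int:
        return self.channels

    @property
    def num_out_channels(self) -> int:
        return self.channels


config = TinyConvNetConfig(name="tiny-conv")
print(config.dims)                            # 3, derived from input_shape
print(config.eval_shape_increase)             # (0, 0, 0) by default
print(config.scale(Coordinate((8, 8, 8))))    # voxel size unchanged by default
```

`save_bioimage_io_model` then wraps such a config in a `RunConfig` named
`"<architecture name>-bioimage-io"` and delegates the export to
`RunConfig.save_bioimage_io_model`, as shown in the diff above.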
43 changes: 0 additions & 43 deletions dacapo/experiments/architectures/architecture_config.py

This file was deleted.

(Diffs for the remaining changed files are not shown.)
