Skip to content

Commit

Permalink
Overhaul Trainers:
Browse files Browse the repository at this point in the history
- consolidate `Trainer` and `TrainerConfig` classes
- move snapshots/num_workers/batch_size etc. to `RunConfig`
- make interface a simple `torch.utils.data.IterableDataset`.
- update `train.py` and `validate.py` to use overhauled `Trainer` classes
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent a731f38 commit d523f36
Show file tree
Hide file tree
Showing 11 changed files with 870 additions and 1,312 deletions.
327 changes: 281 additions & 46 deletions dacapo/experiments/run_config.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dacapo/experiments/trainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .trainer import Trainer # noqa
from .trainer_config import TrainerConfig # noqa
from .dummy_trainer_config import DummyTrainerConfig, DummyTrainer # noqa
from .gunpowder_trainer_config import GunpowderTrainerConfig, GunpowderTrainer # noqa
from .dummy_trainer_config import DummyTrainerConfig # noqa
from .gunpowder_trainer_config import GunpowderTrainerConfig # noqa
from .gp_augments import AugmentConfig # noqa
171 changes: 0 additions & 171 deletions dacapo/experiments/trainers/dummy_trainer.py

This file was deleted.

59 changes: 46 additions & 13 deletions dacapo/experiments/trainers/dummy_trainer_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
import attr

from .dummy_trainer import DummyTrainer
from dacapo.experiments.datasplits.datasets import DatasetConfig
from dacapo.experiments.tasks.predictors import Predictor
from .trainer_config import TrainerConfig

from funlib.geometry import Roi, Coordinate

from typing import Tuple

import numpy as np
import torch


class GeneratorDataset(torch.utils.data.IterableDataset):
    """
    Adapt a generator function to the torch ``IterableDataset`` interface.

    The callable and its call arguments are stored at construction time.
    Each call to ``__iter__`` invokes the callable again with those stored
    arguments and returns whatever it produces, so the callable should
    return an iterator (as a generator function does).

    Args:
        generator: Callable invoked as ``generator(*args, **kwargs)`` on
            every ``__iter__`` call.
        *args: Positional arguments forwarded to ``generator``.
        **kwargs: Keyword arguments forwarded to ``generator``.
    """

    def __init__(self, generator, *args, **kwargs):
        # Only record the callable and its arguments; nothing is invoked yet.
        self.generator, self.args, self.kwargs = generator, args, kwargs

    def __iter__(self):
        # Build the iterator on demand from the stored callable/arguments.
        make_iterator = self.generator
        return make_iterator(*self.args, **self.kwargs)


@attr.s
class DummyTrainerConfig(TrainerConfig):
Expand All @@ -21,18 +41,31 @@ class DummyTrainerConfig(TrainerConfig):
"""

trainer_type = DummyTrainer
dummy_attr: bool = attr.ib(metadata={"help_text": "Dummy attribute."})

mirror_augment: bool = attr.ib(metadata={"help_text": "Dummy attribute."})
def iterable_dataset(
self,
datasets: list[DatasetConfig],
input_shape: Coordinate,
output_shape: Coordinate,
predictor: Predictor | None = None,
):
in_roi = Roi(input_shape * 0, input_shape)
out_roi = Roi(output_shape * 0, output_shape)
in_voxel_size = datasets[0].raw.voxel_size
raw = torch.from_numpy(
datasets[0].raw[in_roi * in_voxel_size].astype(np.float32)
)
out_raw = torch.from_numpy(
datasets[0].raw[out_roi * in_voxel_size].astype(np.float32)
)

def verify(self) -> Tuple[bool, str]:
"""
Verify the DummyTrainerConfig object.
def generator():
while True:
yield {
"raw": raw,
"target": out_raw,
"weight": torch.ones_like(out_raw),
}

Returns:
Tuple[bool, str]: A tuple containing a boolean value indicating whether the DummyTrainerConfig object is valid
and a string containing the reason why the object is invalid.
Examples:
>>> valid, reason = trainer_config.verify()
"""
return False, "This is a DummyTrainerConfig and is never valid"
return GeneratorDataset(generator)
Loading

0 comments on commit d523f36

Please sign in to comment.