Skip to content

Commit

Permalink
Add support for saving a list of iterations on snapshot arrays
Browse files: browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent f5c26fd commit bd84157
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
1 change: 0 additions & 1 deletion dacapo/experiments/datasplits/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,5 +4,4 @@
from .dummy_datasplit_config import DummyDataSplitConfig
from .train_validate_datasplit import TrainValidateDataSplit
from .train_validate_datasplit_config import TrainValidateDataSplitConfig
from .datasplit_generator import DataSplitGenerator, DatasetSpec
from .simple_config import SimpleDataSplitConfig
16 changes: 7 additions & 9 deletions dacapo/experiments/run_config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -591,17 +591,13 @@ def resume_training(self, stats_store, weights_store) -> int:
trained_until = latest_weights_iteration
self.training_stats.delete_after(trained_until)
self.validation_scores.delete_after(trained_until)
weights = weights_store.retrieve_weights(
self, iteration=trained_until
)
weights = weights_store.retrieve_weights(self, iteration=trained_until)

# perfectly in sync. We can continue training
elif latest_weights_iteration == trained_until:
logger.info(f"Resuming training from iteration {trained_until}")

weights = weights_store.retrieve_weights(
self, iteration=trained_until
)
weights = weights_store.retrieve_weights(self, iteration=trained_until)

# weights are stored past the stored training stats, log this inconsistency
# but keep training
Expand Down Expand Up @@ -693,7 +689,7 @@ def save_snapshot(
shape=(0, *v.shape),
offset=v.roi.offset,
voxel_size=v.voxel_size,
axis_names=("iteration^", *v.axis_names),
axis_names=("iterations^", *v.axis_names),
dtype=v.dtype if v.dtype != bool else np.uint8,
mode="w",
)
Expand All @@ -708,8 +704,10 @@ def save_snapshot(

# add an extra dimension so that the shapes match
array._source_data.append(data[None, :])
iterations = array.attrs.setdefault("iterations", list())
iterations.append(iteration)
if "iterations" not in array.attrs:
print(f"No iterations found in array at iteration {iteration}")
array.attrs["iterations"] = list()
array.attrs["iterations"] = array.attrs["iterations"] + [iteration]


def from_yaml(config_yaml: dict) -> torch.nn.Module:
Expand Down
6 changes: 5 additions & 1 deletion tests/operations/test_mini.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_mini(
trainer_config=trainer_config,
datasplit_config=data_config,
repetition=0,
num_iterations=1,
num_iterations=2,
snapshot_interval=1,
batch_size=2,
num_workers=int(multiprocessing),
Expand All @@ -99,5 +99,9 @@ def test_mini(
"volumes/mask",
]
)
iterations_list = zarr.open(snapshot_container)["volumes/raw"].attrs[
"iterations"
]
assert iterations_list == [0, 1], iterations_list
elif func == "validate":
validate_run(run, 1)

0 comments on commit bd84157

Please sign in to comment.