Skip to content

Commit

Permalink
save a list of iterations that are saved in a snapshot array
Browse files: browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent 61f3e3f commit a9c6eda
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions dacapo/experiments/run_config.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -244,7 +244,7 @@ def lr_scheduler(self) -> torch.optim.lr_scheduler.LRScheduler:
self.optimizer,
start_factor=0.01,
end_factor=1.0,
total_iters=10,
total_iters=self.num_iterations // self.validation_interval,
last_epoch=-1,
)
return self._lr_scheduler
Expand Down Expand Up @@ -597,7 +597,7 @@ def resume_training(self, stats_store, weights_store) -> int:

# perfectly in sync. We can continue training
elif latest_weights_iteration == trained_until:
print(f"Resuming training from epoch {trained_until}")
print(f"Resuming training from iteration {trained_until}")

weights = weights_store.retrieve_weights(
self, iteration=trained_until
Expand Down Expand Up @@ -645,6 +645,7 @@ def train_step(self, raw: torch.Tensor, target: torch.Tensor, weight: torch.Tens

def save_snapshot(
self,
iteration: int,
batch: dict[str, torch.Tensor],
batch_out: dict[str, torch.Tensor],
snapshot_container: LocalContainerIdentifier,
Expand Down Expand Up @@ -696,7 +697,7 @@ def save_snapshot(
shape=(0, *v.shape),
offset=v.roi.offset,
voxel_size=v.voxel_size,
axis_names=("epoch^", *v.axis_names),
axis_names=("iteration^", *v.axis_names),
dtype=v.dtype if v.dtype != bool else np.uint8,
mode="w",
)
Expand All @@ -711,6 +712,8 @@ def save_snapshot(

# add an extra dimension so that the shapes match
array._source_data.append(data[None, :])
iterations = array.attrs.setdefault("iterations", list())
iterations.append(iteration)


def from_yaml(config_yaml: dict) -> torch.nn.Module:
Expand Down
4 changes: 2 additions & 2 deletions dacapo/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -109,7 +109,7 @@ def train_run(run: RunConfig, validate: bool = True, save_snapshots: bool = Fals
# save snapshot. We save the snapshot at the start of every
# {snapshot_interval} iterations. This is for debugging
# purposes so you get snapshots quickly.
run.save_snapshot(batch, batch_out, snapshot_container)
run.save_snapshot(i, batch, batch_out, snapshot_container)

if i % run.validation_interval == run.validation_interval - 1 or i == run.num_iterations - 1:
# run "end of epoch steps" such as stepping the learning rate
Expand All @@ -122,7 +122,7 @@ def train_run(run: RunConfig, validate: bool = True, save_snapshots: bool = Fals
logger.warning(w)
pass

# Store epoch checkpoint and training stats
# Store checkpoint and training stats
stats_store.store_training_stats(run.name, run.training_stats)
weights_store.store_weights(run, i + 1)

Expand Down

0 comments on commit a9c6eda

Please sign in to comment.