Skip to content

Commit 6d58061

Browse files
committed
fixing partial loading
1 parent fa51465 commit 6d58061

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

viscy/light/engine.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class VSUNet(LightningModule):
9696
:param float lr: learning rate in training, defaults to 1e-3
9797
:param Literal['WarmupCosine', 'Constant'] schedule:
9898
learning rate scheduler, defaults to "Constant"
99+
:param str chkpt_path: path to the checkpoint to load weights, defaults to None
99100
:param int log_batches_per_epoch:
100101
number of batches to log each training/validation epoch,
101102
has to be smaller than steps per epoch, defaults to 8
@@ -146,7 +147,9 @@ def __init__(
146147
self.log_batches_per_epoch = log_batches_per_epoch
147148
self.log_samples_per_batch = log_samples_per_batch
148149
if chkpt_path is not None:
149-
self.model.load_state_dict(torch.load(chkpt_path)["state_dict"])
150+
self.model.load_state_dict(
151+
torch.load(chkpt_path)["state_dict"], strict=False
152+
) # loading only weights
150153
self.training_step_outputs = []
151154
self.validation_step_outputs = []
152155
# required to log the graph

0 commit comments

Comments
 (0)