Commit 6d58061 1 parent fa51465 commit 6d58061 Copy full SHA for 6d58061
File tree 1 file changed +4
-1
lines changed
1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -96,6 +96,7 @@ class VSUNet(LightningModule):
96
96
:param float lr: learning rate in training, defaults to 1e-3
97
97
:param Literal['WarmupCosine', 'Constant'] schedule:
98
98
learning rate scheduler, defaults to "Constant"
99
+ :param str chkpt_path: path to the checkpoint to load weights, defaults to None
99
100
:param int log_batches_per_epoch:
100
101
number of batches to log each training/validation epoch,
101
102
has to be smaller than steps per epoch, defaults to 8
@@ -146,7 +147,9 @@ def __init__(
146
147
self .log_batches_per_epoch = log_batches_per_epoch
147
148
self .log_samples_per_batch = log_samples_per_batch
148
149
if chkpt_path is not None :
149
- self .model .load_state_dict (torch .load (chkpt_path )["state_dict" ])
150
+ self .model .load_state_dict (
151
+ torch .load (chkpt_path )["state_dict" ], strict = False
152
+ ) # loading only weights
150
153
self .training_step_outputs = []
151
154
self .validation_step_outputs = []
152
155
# required to log the graph
You can’t perform that action at this time.
0 commit comments