Skip to content

Commit 15802c5

Browse files
authored
Explicitly remove the stride keys from the checkpoint if they are present which should fix the issue with DeciDet checkpoints (#1386)
1 parent 67f7a4e commit 15802c5

File tree

1 file changed

+10
-0
lines changed
  • src/super_gradients/training/models/detection_models

1 file changed

+10
-0
lines changed

src/super_gradients/training/models/detection_models/yolo_base.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import math
23
import warnings
34
from typing import Union, Type, List, Tuple, Optional
@@ -590,6 +591,15 @@ def forward(self, x):
590591

591592
def load_state_dict(self, state_dict, strict=True):
592593
try:
594+
keys_dropped_in_sg_320 = {
595+
"stride",
596+
"_head.anchors._stride",
597+
"_head.anchors._anchors",
598+
"_head.anchors._anchor_grid",
599+
"_head._modules_list.14.stride",
600+
}
601+
state_dict = collections.OrderedDict([(k, v) for k, v in state_dict.items() if k not in keys_dropped_in_sg_320])
602+
593603
super().load_state_dict(state_dict, strict)
594604
except RuntimeError as e:
595605
raise RuntimeError(

0 commit comments

Comments
 (0)