@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
1517
1517
message_model = "model" if not load_backbone else "model's backbone"
1518
1518
logger .info ("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix )
1519
1519
1520
- if (isinstance (net , HasPredict )) and load_processing_params :
1521
- if "processing_params" not in checkpoint .keys ():
1522
- raise ValueError ("Can't load processing params - could not find any stored in checkpoint file." )
1523
- try :
1524
- net .set_dataset_processing_params (** checkpoint ["processing_params" ])
1525
- except Exception as e :
1526
- logger .warning (
1527
- f"Could not set preprocessing pipeline from the checkpoint dataset: { e } . Before calling"
1528
- "predict make sure to call set_dataset_processing_params."
1529
- )
1520
+ _maybe_load_preprocessing_params (net , checkpoint )
1530
1521
1531
1522
if load_weights_only or load_backbone :
1532
1523
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
1549
1540
def load_pretrained_weights (model : torch .nn .Module , architecture : str , pretrained_weights : str ):
1550
1541
"""
1551
1542
Loads pretrained weights from the MODEL_URLS dictionary to model
1552
- :param architecture: name of the model's architecture
1553
- :param model: model to load pretrinaed weights for
1554
- :param pretrained_weights: name for the pretrianed weights (i.e imagenet)
1555
- :return: None
1543
+
1544
+ :param architecture: name of the model's architecture
1545
+ :param model: model to load pretrinaed weights for
1546
+ :param pretrained_weights: name for the pretrianed weights (i.e imagenet)
1547
+
1548
+ :return: None
1556
1549
"""
1557
1550
from super_gradients .common .object_names import Models
1558
1551
@@ -1569,19 +1562,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
1569
1562
"By downloading the pre-trained weight files you agree to comply with these terms."
1570
1563
)
1571
1564
1572
- unique_filename = url .split ("https://sghub.deci.ai/models/" )[1 ].replace ("/" , "_" ).replace (" " , "_" )
1573
- map_location = torch .device ("cpu" )
1574
- with wait_for_the_master (get_local_rank ()):
1575
- pretrained_state_dict = load_state_dict_from_url (url = url , map_location = map_location , file_name = unique_filename )
1576
- _load_weights (architecture , model , pretrained_state_dict )
1577
-
1565
+ # Basically this check allows settings pretrained weights from local path using file:///path/to/weights scheme
1566
+ # which is a valid URI scheme for local files
1567
+ # Supporting local files and file URI allows us modification of pretrained weights dics in unit tests
1568
+ if url .startswith ("file://" ) or os .path .exists (url ):
1569
+ pretrained_state_dict = torch .load (url .replace ("file://" , "" ), map_location = "cpu" )
1570
+ else :
1571
+ unique_filename = url .split ("https://sghub.deci.ai/models/" )[1 ].replace ("/" , "_" ).replace (" " , "_" )
1572
+ map_location = torch .device ("cpu" )
1573
+ with wait_for_the_master (get_local_rank ()):
1574
+ pretrained_state_dict = load_state_dict_from_url (url = url , map_location = map_location , file_name = unique_filename )
1578
1575
1579
- def _load_weights (architecture , model , pretrained_state_dict ):
1580
- if "ema_net" in pretrained_state_dict .keys ():
1581
- pretrained_state_dict ["net" ] = pretrained_state_dict ["ema_net" ]
1582
- solver = YoloXCheckpointSolver () if "yolox" in architecture else DefaultCheckpointSolver ()
1583
- adaptive_load_state_dict (net = model , state_dict = pretrained_state_dict , strict = StrictLoad .NO_KEY_MATCHING , solver = solver )
1584
- logger .info (f"Successfully loaded pretrained weights for architecture { architecture } " )
1576
+ _load_weights (architecture , model , pretrained_state_dict )
1577
+ _maybe_load_preprocessing_params (model , pretrained_state_dict )
1585
1578
1586
1579
1587
1580
def load_pretrained_weights_local (model : torch .nn .Module , architecture : str , pretrained_weights : str ):
@@ -1598,3 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre
1598
1591
1599
1592
pretrained_state_dict = torch .load (pretrained_weights , map_location = map_location )
1600
1593
_load_weights (architecture , model , pretrained_state_dict )
1594
+ _maybe_load_preprocessing_params (model , pretrained_state_dict )
1595
+
1596
+
1597
def _load_weights(architecture, model, pretrained_state_dict):
    """Load a pretrained checkpoint dict into ``model`` via the matching checkpoint solver.

    When the checkpoint carries EMA weights (``"ema_net"``), those take precedence:
    they are copied over the regular ``"net"`` entry before loading. YoloX
    architectures use a dedicated solver; every other architecture uses the default.

    :param architecture:          Architecture name (e.g. "yolox_s"); selects the solver.
    :param model:                 Model instance that receives the weights.
    :param pretrained_state_dict: Full checkpoint dict (mutated in place when EMA weights exist).
    """
    if "ema_net" in pretrained_state_dict:
        pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
    if "yolox" in architecture:
        solver = YoloXCheckpointSolver()
    else:
        solver = DefaultCheckpointSolver()
    adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
    logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
1603
+
1604
+
1605
def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
    """
    Try to load preprocessing params from the checkpoint into the model.

    The function never raises: when the params cannot be applied, a warning is
    logged and ``False`` is returned so checkpoint loading can proceed.

    :param model:      Instance of nn.Module (possibly wrapped, e.g. by DDP/EMA — it is unwrapped first).
    :param checkpoint: Entire checkpoint dict (not the state_dict with model weights).
    :return:           True if the preprocessing params were loaded successfully, False otherwise.
    """
    model = unwrap_model(model)
    checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
    model_has_predict = isinstance(model, HasPredict)
    logger.debug(
        f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
        f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
    )

    if model_has_predict and checkpoint_has_preprocessing_params:
        try:
            model.set_dataset_processing_params(**checkpoint["processing_params"])
            logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
            return True
        except Exception as e:
            # Best-effort by design: a stale/incompatible processing_params payload
            # must not abort checkpoint loading.
            # Fix: the two adjacent literals previously concatenated to
            # "...Before callingpredict..." — a separating space was missing.
            logger.warning(
                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling "
                "predict make sure to call set_dataset_processing_params."
            )
    # Fix: previously fell through returning None (both when the predicate above is
    # false and on failure), contradicting the documented bool contract. None is
    # falsy, so existing truthiness-based callers keep working.
    return False
0 commit comments