diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 9b75c60ac08..fd444f1389f 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -16,6 +16,7 @@ import torch from ultralytics.engine.model import Model +from ultralytics.utils.downloads import attempt_download_asset from ultralytics.utils.torch_utils import model_info, smart_inference_mode from .predict import NASPredictor @@ -56,7 +57,7 @@ def _load(self, weights: str, task: str): suffix = Path(weights).suffix if suffix == ".pt": - self.model = torch.load(weights) + self.model = torch.load(attempt_download_asset(weights)) elif suffix == "": self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") # Standardize model