From f98fe7bedaef955418df80eaafaf4f058b4eb218 Mon Sep 17 00:00:00 2001 From: Arthur Thuy <57416568+arthur-thuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 22:27:29 +0200 Subject: [PATCH] fix: weights=None in vgg16 (#276) --- experiments/ssl_experiments/pimodel_cifar10.py | 2 +- experiments/ssl_experiments/pimodel_mcdropout_cifar10.py | 2 +- experiments/vgg_mcdropout_cifar10.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/experiments/ssl_experiments/pimodel_cifar10.py b/experiments/ssl_experiments/pimodel_cifar10.py index 0ce7f349..4f7e8d2b 100644 --- a/experiments/ssl_experiments/pimodel_cifar10.py +++ b/experiments/ssl_experiments/pimodel_cifar10.py @@ -289,7 +289,7 @@ def add_model_specific_args(parent_parser): print("Active set length: {}".format(len(active_set))) print("Pool set length: {}".format(len(active_set.pool))) - net = vgg11(pretrained=False, num_classes=10) + net = vgg11(weights=None, num_classes=10) weights = load_state_dict_from_url("https://download.pytorch.org/models/vgg11-bbd30ac9.pth") weights = {k: v for k, v in weights.items() if "classifier.6" not in k} diff --git a/experiments/ssl_experiments/pimodel_mcdropout_cifar10.py b/experiments/ssl_experiments/pimodel_mcdropout_cifar10.py index 5da5c146..7f3ae2ad 100644 --- a/experiments/ssl_experiments/pimodel_mcdropout_cifar10.py +++ b/experiments/ssl_experiments/pimodel_mcdropout_cifar10.py @@ -82,7 +82,7 @@ def add_model_specific_args(parent_parser): print("Pool set length: {}".format(len(active_set.pool))) heuristic = get_heuristic(params.heuristic) - model = vgg16(pretrained=False, num_classes=10) + model = vgg16(weights=None, num_classes=10) weights = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth") weights = {k: v for k, v in weights.items() if "classifier.6" not in k} model.load_state_dict(weights, strict=False) diff --git a/experiments/vgg_mcdropout_cifar10.py b/experiments/vgg_mcdropout_cifar10.py index ad67a0d0..d4bdec8d 100644 --- a/experiments/vgg_mcdropout_cifar10.py +++ b/experiments/vgg_mcdropout_cifar10.py @@ -84,7 +84,7 @@ def main(): heuristic = get_heuristic(hyperparams["heuristic"], hyperparams["shuffle_prop"]) criterion = CrossEntropyLoss() - model = vgg16(pretrained=False, num_classes=10) + model = vgg16(weights=None, num_classes=10) weights = load_state_dict_from_url("https://download.pytorch.org/models/vgg16-397923af.pth") weights = {k: v for k, v in weights.items() if "classifier.6" not in k} model.load_state_dict(weights, strict=False)