From 25980f8718958e6ef0d4e81cd1ff7b853004d9d2 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 13:41:56 +0100 Subject: [PATCH 1/7] name change --- conf/bigearthnet.yaml | 2 +- conf/eurosat.yaml | 2 +- conf/resisc45.yaml | 2 +- conf/so2sat.yaml | 2 +- conf/ucmerced.yaml | 2 +- evaluate.py | 4 ++-- experiments/run_resisc45_experiments.py | 2 +- experiments/run_so2sat_byol_experiments.py | 2 +- experiments/run_so2sat_experiments.py | 2 +- experiments/run_so2sat_seed_experiments.py | 2 +- tests/conf/bigearthnet_all.yaml | 2 +- tests/conf/bigearthnet_s1.yaml | 2 +- tests/conf/bigearthnet_s2.yaml | 2 +- tests/conf/eurosat.yaml | 2 +- tests/conf/resisc45.yaml | 2 +- tests/conf/so2sat_supervised.yaml | 2 +- tests/conf/so2sat_unsupervised.yaml | 2 +- tests/conf/ucmerced.yaml | 2 +- tests/trainers/conftest.py | 4 ++-- tests/trainers/test_classification.py | 8 ++++---- tests/trainers/test_utils.py | 5 +---- torchgeo/trainers/classification.py | 24 +++++++++++++--------- torchgeo/trainers/utils.py | 10 ++++----- 23 files changed, 44 insertions(+), 45 deletions(-) diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml index de376238bd8..81d0e83893f 100644 --- a/conf/bigearthnet.yaml +++ b/conf/bigearthnet.yaml @@ -7,7 +7,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml index 4f046aa4847..5abde6d4592 100644 --- a/conf/eurosat.yaml +++ b/conf/eurosat.yaml @@ -2,7 +2,7 @@ experiment: task: "eurosat" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml index f8c70e9961a..4dc34b13c0a 100644 --- a/conf/resisc45.yaml +++ b/conf/resisc45.yaml @@ -7,7 +7,7 @@ experiment: task: "resisc45" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index c54259004c1..4caf2a01ba5 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -7,7 +7,7 @@ experiment: task: "so2sat" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/conf/ucmerced.yaml b/conf/ucmerced.yaml index 9975a7c4821..4ab6612d1ae 100644 --- a/conf/ucmerced.yaml +++ b/conf/ucmerced.yaml @@ -2,7 +2,7 @@ experiment: task: "ucmerced" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 diff --git a/evaluate.py b/evaluate.py index 01358052c9d..f947382ec0f 100755 --- a/evaluate.py +++ b/evaluate.py @@ -161,7 +161,7 @@ def main(args: argparse.Namespace) -> None: if issubclass(TASK, ClassificationTask): val_row: Dict[str, Union[str, float]] = { "split": "val", - "classification_model": model.hparams["classification_model"], + "model": model.hparams["model"], "learning_rate": model.hparams["learning_rate"], "weights": model.hparams["weights"], "loss": model.hparams["loss"], @@ -169,7 +169,7 @@ def main(args: argparse.Namespace) -> None: test_row: Dict[str, Union[str, float]] = { "split": "test", - "classification_model": model.hparams["classification_model"], + "model": model.hparams["model"], "learning_rate": model.hparams["learning_rate"], "weights": model.hparams["weights"], "loss": model.hparams["loss"], diff --git a/experiments/run_resisc45_experiments.py b/experiments/run_resisc45_experiments.py index c149a630532..bd7f79a013a 100755 --- a/experiments/run_resisc45_experiments.py +++ b/experiments/run_resisc45_experiments.py @@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/experiments/run_so2sat_byol_experiments.py b/experiments/run_so2sat_byol_experiments.py index fe72b59a700..e0b132f7d9c 100755 --- a/experiments/run_so2sat_byol_experiments.py +++ b/experiments/run_so2sat_byol_experiments.py @@ -52,7 +52,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/experiments/run_so2sat_experiments.py b/experiments/run_so2sat_experiments.py index bdc7e5a7a30..6d1ea0f8eae 100755 --- a/experiments/run_so2sat_experiments.py +++ b/experiments/run_so2sat_experiments.py @@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/experiments/run_so2sat_seed_experiments.py b/experiments/run_so2sat_seed_experiments.py index 5585e07a1d3..4f861926801 100755 --- a/experiments/run_so2sat_seed_experiments.py +++ b/experiments/run_so2sat_seed_experiments.py @@ -51,7 +51,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index e6534c888a5..ed9c68d39a0 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index 720dd95109a..6c16bb4e7e0 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 78ea30eb91a..74876350e8f 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index 96cb403e962..e865c7af8e9 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -2,7 +2,7 @@ experiment: task: "eurosat" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 11482c1b677..89f7b8072c4 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -2,7 +2,7 @@ experiment: task: "resisc45" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_supervised.yaml index f85801fa81a..476644ffe59 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_supervised.yaml @@ -2,7 +2,7 @@ experiment: task: "so2sat" module: loss: "focal" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_unsupervised.yaml index 7f6854dc8f1..e7aeda2547d 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_unsupervised.yaml @@ -2,7 +2,7 @@ experiment: task: "so2sat" module: loss: "jaccard" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 892e049aa02..fe39f579c35 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -2,7 +2,7 @@ experiment: task: "ucmerced" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" weights: "random" learning_rate: 1e-3 learning_rate_schedule_patience: 6 diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index ad708afa246..792703cdaab 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -31,11 +31,11 @@ def state_dict(model: Module) -> Dict[str, Tensor]: return model.state_dict() -@pytest.fixture(params=["classification_model", "encoder_name"]) +@pytest.fixture(params=["model", "encoder_name"]) def checkpoint( state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path ) -> str: - if request.param == "classification_model": + if request.param == "model": state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()}) else: state_dict = OrderedDict( diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index d24b723aaaf..91d4e85729d 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -85,7 +85,7 @@ def test_no_logger(self) -> None: @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { - "classification_model": "resnet18", + "model": "resnet18", "in_channels": 1, "loss": "ce", "num_classes": 2, @@ -101,7 +101,7 @@ def test_invalid_pretrained( self, model_kwargs: Dict[Any, Any], checkpoint: str ) -> None: model_kwargs["weights"] = checkpoint - model_kwargs["classification_model"] = "resnet50" + model_kwargs["model"] = "resnet50" match = "Trying to load resnet18 weights into a resnet50" with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) @@ -113,7 +113,7 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: ClassificationTask(**model_kwargs) def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["classification_model"] = "invalid_model" + model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not a valid timm model." with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) @@ -178,7 +178,7 @@ def test_no_logger(self) -> None: @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { - "classification_model": "resnet18", + "model": "resnet18", "in_channels": 1, "loss": "ce", "num_classes": 1, diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 33ede32dc5a..831e8f1a1d4 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -56,10 +56,7 @@ def test_extract_encoder_unsupported_model(tmp_path: Path) -> None: checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}} path = os.path.join(str(tmp_path), "dummy.ckpt") torch.save(checkpoint, path) - err = ( - "Unknown checkpoint task. Only encoder or classification_model" - " extraction is supported" - ) + err = "Unknown checkpoint task. Only encoder or model" " extraction is supported" with pytest.raises(ValueError, match=err): extract_encoder(path) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 154aadac566..2ef75fb7345 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -39,7 +39,7 @@ class ClassificationTask(pl.LightningModule): def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" in_channels = self.hyperparams["in_channels"] - classification_model = self.hyperparams["classification_model"] + model = self.hyperparams["model"] imagenet_pretrained = False custom_pretrained = False @@ -58,25 +58,23 @@ def config_model(self) -> None: # Create the model valid_models = timm.list_models(pretrained=True) - if classification_model in valid_models: + if model in valid_models: self.model = timm.create_model( - classification_model, + model, num_classes=self.hyperparams["num_classes"], in_chans=in_channels, pretrained=imagenet_pretrained, ) else: - raise ValueError( - f"Model type '{classification_model}' is not a valid timm model." - ) + raise ValueError(f"Model type '{model}' is not a valid timm model.") if custom_pretrained: name, state_dict = utils.extract_encoder(self.hyperparams["weights"]) - if self.hyperparams["classification_model"] != name: + if self.hyperparams["model"] != name: raise ValueError( f"Trying to load {name} weights into a " - f"{self.hyperparams['classification_model']}" + f"{self.hyperparams['model']}" ) self.model = utils.load_state_dict(self.model, state_dict) @@ -97,10 +95,13 @@ def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: - classification_model: Name of the classification model use + model: Name of the classification model use loss: Name of the loss function weights: Either "random", "imagenet_only", "imagenet_and_random", or "random_rgb" + + .. versionchanged:: 0.4 + The *classification_model* parameter was renamed to *model*. """ super().__init__() @@ -299,10 +300,13 @@ def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: - classification_model: Name of the classification model use + model: Name of the classification model use loss: Name of the loss function weights: Either "random", "imagenet_only", "imagenet_and_random", or "random_rgb" + + .. versionchanged:: 0.4 + The *classification_model* parameter was renamed to *model*. """ super().__init__(**kwargs) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index d0f0bd2b029..cc416913287 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -28,15 +28,14 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: tuple containing model name and state dict Raises: - ValueError: if 'classification_model' or 'encoder' not in + ValueError: if 'model' or 'encoder' not in checkpoint['hyper_parameters'] """ checkpoint = torch.load( # type: ignore[no-untyped-call] path, map_location=torch.device("cpu") ) - - if "classification_model" in checkpoint["hyper_parameters"]: - name = checkpoint["hyper_parameters"]["classification_model"] + if "model" in checkpoint["hyper_parameters"]: + name = checkpoint["hyper_parameters"]["model"] state_dict = checkpoint["state_dict"] state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k}) state_dict = OrderedDict( @@ -53,8 +52,7 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: ) else: raise ValueError( - "Unknown checkpoint task. Only encoder or classification_model" - " extraction is supported" + "Unknown checkpoint task. Only encoder or model" " extraction is supported" ) return name, state_dict From 0c4f904cfb9d196b2edf3b8ae1bc03b2a2d87dc1 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 14:35:44 +0100 Subject: [PATCH 2/7] fix failing test --- tests/trainers/test_classification.py | 8 ++++---- torchgeo/trainers/utils.py | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 3b88a8ab57f..7059e59cd63 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -86,7 +86,7 @@ def test_no_logger(self) -> None: def model_kwargs(self) -> Dict[Any, Any]: return { "model": "resnet18", - "in_channels": 1, + "in_channels": 13, "loss": "ce", "num_classes": 10, "weights": "random", @@ -190,9 +190,9 @@ def test_no_logger(self) -> None: def model_kwargs(self) -> Dict[Any, Any]: return { "model": "resnet18", - "in_channels": 1, - "loss": "ce", - "num_classes": 1, + "in_channels": 14, + "loss": "bce", + "num_classes": 19, "weights": "random", } diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index cc416913287..8a0b957a4b5 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -31,9 +31,7 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: ValueError: if 'model' or 'encoder' not in checkpoint['hyper_parameters'] """ - checkpoint = torch.load( # type: ignore[no-untyped-call] - path, map_location=torch.device("cpu") - ) + checkpoint = torch.load(path, map_location=torch.device("cpu")) if "model" in checkpoint["hyper_parameters"]: name = checkpoint["hyper_parameters"]["model"] state_dict = checkpoint["state_dict"] From 867520f51d4855642d311eae0270c5a7bb56d6a8 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 18:31:24 +0100 Subject: [PATCH 3/7] expose all available timm models --- torchgeo/trainers/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 2e5c62e2d3d..a7726d3b5f5 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -68,7 +68,7 @@ def config_model(self) -> None: custom_pretrained = True # Create the model - valid_models = timm.list_models(pretrained=True) + valid_models = timm.list_models(pretrained=False) if model in valid_models: self.model = timm.create_model( model, From 45e0b64144733bb21111cc378e4c1aaba916c1e2 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 19:02:48 +0100 Subject: [PATCH 4/7] chmod --- evaluate.py | 16 ++++++++-------- torchgeo/trainers/classification.py | 18 ++++++++++++------ torchgeo/trainers/utils.py | 2 +- 3 files changed, 21 insertions(+), 15 deletions(-) mode change 100644 => 100755 evaluate.py diff --git a/evaluate.py b/evaluate.py old mode 100644 new mode 100755 index 6880ad81e5e..de6767cfa0b --- a/evaluate.py +++ b/evaluate.py @@ -163,18 +163,18 @@ def main(args: argparse.Namespace) -> None: if issubclass(TASK, ClassificationTask): val_row = { "split": "val", - "model": model.hparams["model"], - "learning_rate": model.hparams["learning_rate"], - "weights": model.hparams["weights"], - "loss": model.hparams["loss"], + "model": hparams["model"], + "learning_rate": hparams["learning_rate"], + "weights": hparams["weights"], + "loss": hparams["loss"], } test_row = { "split": "test", - "model": model.hparams["model"], - "learning_rate": model.hparams["learning_rate"], - "weights": model.hparams["weights"], - "loss": model.hparams["loss"], + "model": hparams["model"], + "learning_rate": hparams["learning_rate"], + "weights": hparams["weights"], + "loss": hparams["loss"], } elif issubclass(TASK, SemanticSegmentationTask): val_row = { diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index a7726d3b5f5..be6dddf0968 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -107,9 +107,12 @@ def __init__(self, **kwargs: Any) -> None: Keyword Args: model: Name of the classification model use - loss: Name of the loss function - weights: Either "random", "imagenet_only", "imagenet_and_random", or - "random_rgb" + loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal' + weights: Either "random" or "imagenet" + num_classes: Number of prediction classes + in_channels: Number of input channels to model + learning_rate: Learning rate for optimizer + learning_rate_schedule_patience: Patience for learning rate scheduler .. versionchanged:: 0.4 The *classification_model* parameter was renamed to *model*. @@ -312,9 +315,12 @@ def __init__(self, **kwargs: Any) -> None: Keyword Args: model: Name of the classification model use - loss: Name of the loss function - weights: Either "random", "imagenet_only", "imagenet_and_random", or - "random_rgb" + loss: Name of the loss function, currently only supports 'bce' + weights: Either "random" or 'imagenet' + num_classes: Number of prediction classes + in_channels: Number of input channels to model + learning_rate: Learning rate for optimizer + learning_rate_schedule_patience: Patience for learning rate scheduler .. versionchanged:: 0.4 The *classification_model* parameter was renamed to *model*. diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 8a0b957a4b5..9f1b48a8868 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -50,7 +50,7 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: ) else: raise ValueError( - "Unknown checkpoint task. Only encoder or model" " extraction is supported" + "Unknown checkpoint task. Only encoder or model extraction is supported" ) return name, state_dict From 8a768454db778c7c7caf2d35cd1153352a438433 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 19:05:16 +0100 Subject: [PATCH 5/7] imagenet pretrained flag --- torchgeo/trainers/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index be6dddf0968..a1f934b3173 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -68,7 +68,7 @@ def config_model(self) -> None: custom_pretrained = True # Create the model - valid_models = timm.list_models(pretrained=False) + valid_models = timm.list_models(pretrained=imagenet_pretrained) if model in valid_models: self.model = timm.create_model( model, From d58beb1ca449d027d36014f60317f52102986e23 Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Sun, 27 Nov 2022 20:27:56 +0100 Subject: [PATCH 6/7] remove extra --- tests/trainers/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 831e8f1a1d4..3a6dc0bc27d 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -56,7 +56,7 @@ def test_extract_encoder_unsupported_model(tmp_path: Path) -> None: checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}} path = os.path.join(str(tmp_path), "dummy.ckpt") torch.save(checkpoint, path) - err = "Unknown checkpoint task. Only encoder or model" " extraction is supported" + err = "Unknown checkpoint task. Only encoder or model extraction is supported" with pytest.raises(ValueError, match=err): extract_encoder(path) From 1b2d0154a2ae21513ce96fa4bf91416cf1ed643f Mon Sep 17 00:00:00 2001 From: Nils Lehmann Date: Mon, 28 Nov 2022 08:38:42 +0100 Subject: [PATCH 7/7] docstring list_models --- torchgeo/trainers/classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index a1f934b3173..675595c767b 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -38,13 +38,13 @@ class ClassificationTask(pl.LightningModule): Supports any available `Timm model `_ - as an architecture choice. To see a list of available pretrained + as an architecture choice. To see a list of available models, you can do: .. code-block:: python import timm - print(timm.list_models(pretrained=True)) + print(timm.list_models()) """ def config_model(self) -> None: