diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml
index de376238bd8..81d0e83893f 100644
--- a/conf/bigearthnet.yaml
+++ b/conf/bigearthnet.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "bigearthnet"
   module:
     loss: "bce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml
index 4f046aa4847..5abde6d4592 100644
--- a/conf/eurosat.yaml
+++ b/conf/eurosat.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "eurosat"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml
index f8c70e9961a..4dc34b13c0a 100644
--- a/conf/resisc45.yaml
+++ b/conf/resisc45.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "resisc45"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml
index c54259004c1..4caf2a01ba5 100644
--- a/conf/so2sat.yaml
+++ b/conf/so2sat.yaml
@@ -7,7 +7,7 @@ experiment:
   task: "so2sat"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
     weights: "random"
diff --git a/conf/ucmerced.yaml b/conf/ucmerced.yaml
index 9975a7c4821..4ab6612d1ae 100644
--- a/conf/ucmerced.yaml
+++ b/conf/ucmerced.yaml
@@ -2,7 +2,7 @@ experiment:
   task: "ucmerced"
   module:
     loss: "ce"
-    classification_model: "resnet18"
+    model: "resnet18"
     weights: null
     learning_rate: 1e-3
     learning_rate_schedule_patience: 6
diff --git a/evaluate.py b/evaluate.py
index 29220b6d4cc..de6767cfa0b 100755
--- a/evaluate.py
+++ b/evaluate.py
@@ -163,7 +163,7 @@ def main(args: argparse.Namespace) -> None:
     if issubclass(TASK, ClassificationTask):
         val_row = {
             "split": "val",
-            "classification_model": hparams["classification_model"],
+            "model": hparams["model"],
             "learning_rate": hparams["learning_rate"],
             "weights": hparams["weights"],
             "loss": hparams["loss"],
@@ -171,7 +171,7 @@ def main(args: argparse.Namespace) -> None:
         test_row = {
             "split": "test",
-            "classification_model": hparams["classification_model"],
+            "model": hparams["model"],
             "learning_rate": hparams["learning_rate"],
             "weights": hparams["weights"],
             "loss": hparams["loss"],
diff --git a/experiments/run_resisc45_experiments.py b/experiments/run_resisc45_experiments.py
index c149a630532..bd7f79a013a 100755
--- a/experiments/run_resisc45_experiments.py
+++ b/experiments/run_resisc45_experiments.py
@@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
             + f" experiment.module.weights={weights}"
diff --git a/experiments/run_so2sat_byol_experiments.py b/experiments/run_so2sat_byol_experiments.py
index fe72b59a700..e0b132f7d9c 100755
--- a/experiments/run_so2sat_byol_experiments.py
+++ b/experiments/run_so2sat_byol_experiments.py
@@ -52,7 +52,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
             "python train.py"
             + f" config_file={config_file}"
             + f" experiment.name={experiment_name}"
-            + f" experiment.module.classification_model={model}"
+            + f" experiment.module.model={model}"
             + f" experiment.module.learning_rate={lr}"
             + f" experiment.module.loss={loss}"
+ f" experiment.module.weights={weights}" diff --git a/experiments/run_so2sat_experiments.py b/experiments/run_so2sat_experiments.py index bdc7e5a7a30..6d1ea0f8eae 100755 --- a/experiments/run_so2sat_experiments.py +++ b/experiments/run_so2sat_experiments.py @@ -50,7 +50,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/experiments/run_so2sat_seed_experiments.py b/experiments/run_so2sat_seed_experiments.py index 5585e07a1d3..4f861926801 100755 --- a/experiments/run_so2sat_seed_experiments.py +++ b/experiments/run_so2sat_seed_experiments.py @@ -51,7 +51,7 @@ def do_work(work: "Queue[str]", gpu_idx: int) -> bool: "python train.py" + f" config_file={config_file}" + f" experiment.name={experiment_name}" - + f" experiment.module.classification_model={model}" + + f" experiment.module.model={model}" + f" experiment.module.learning_rate={lr}" + f" experiment.module.loss={loss}" + f" experiment.module.weights={weights}" diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index e6534c888a5..ed9c68d39a0 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index 720dd95109a..6c16bb4e7e0 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index 78ea30eb91a..74876350e8f 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -2,7 +2,7 @@ experiment: task: "bigearthnet" module: loss: "bce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index 96cb403e962..e865c7af8e9 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -2,7 +2,7 @@ experiment: task: "eurosat" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 11482c1b677..89f7b8072c4 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -2,7 +2,7 @@ experiment: task: "resisc45" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_supervised.yaml index f85801fa81a..476644ffe59 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_supervised.yaml @@ -2,7 +2,7 @@ experiment: task: "so2sat" module: loss: "focal" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git 
a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_unsupervised.yaml index 7f6854dc8f1..e7aeda2547d 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_unsupervised.yaml @@ -2,7 +2,7 @@ experiment: task: "so2sat" module: loss: "jaccard" - classification_model: "resnet18" + model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: "random" diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index 892e049aa02..fe39f579c35 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -2,7 +2,7 @@ experiment: task: "ucmerced" module: loss: "ce" - classification_model: "resnet18" + model: "resnet18" weights: "random" learning_rate: 1e-3 learning_rate_schedule_patience: 6 diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index ad708afa246..792703cdaab 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -31,11 +31,11 @@ def state_dict(model: Module) -> Dict[str, Tensor]: return model.state_dict() -@pytest.fixture(params=["classification_model", "encoder_name"]) +@pytest.fixture(params=["model", "encoder_name"]) def checkpoint( state_dict: Dict[str, Tensor], request: SubRequest, tmp_path: Path ) -> str: - if request.param == "classification_model": + if request.param == "model": state_dict = OrderedDict({"model." + k: v for k, v in state_dict.items()}) else: state_dict = OrderedDict( diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index fa5fdb7003f..7059e59cd63 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -85,7 +85,7 @@ def test_no_logger(self) -> None: @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { - "classification_model": "resnet18", + "model": "resnet18", "in_channels": 13, "loss": "ce", "num_classes": 10, @@ -101,7 +101,7 @@ def test_invalid_pretrained( self, model_kwargs: Dict[Any, Any], checkpoint: str ) -> None: model_kwargs["weights"] = checkpoint - model_kwargs["classification_model"] = "resnet50" + model_kwargs["model"] = "resnet50" match = "Trying to load resnet18 weights into a resnet50" with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) @@ -113,7 +113,7 @@ def test_invalid_loss(self, model_kwargs: Dict[Any, Any]) -> None: ClassificationTask(**model_kwargs) def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["classification_model"] = "invalid_model" + model_kwargs["model"] = "invalid_model" match = "Model type 'invalid_model' is not a valid timm model." with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) @@ -189,7 +189,7 @@ def test_no_logger(self) -> None: @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { - "classification_model": "resnet18", + "model": "resnet18", "in_channels": 14, "loss": "bce", "num_classes": 19, diff --git a/tests/trainers/test_utils.py b/tests/trainers/test_utils.py index 33ede32dc5a..3a6dc0bc27d 100644 --- a/tests/trainers/test_utils.py +++ b/tests/trainers/test_utils.py @@ -56,10 +56,7 @@ def test_extract_encoder_unsupported_model(tmp_path: Path) -> None: checkpoint = {"hyper_parameters": {"some_unsupported_model": "resnet18"}} path = os.path.join(str(tmp_path), "dummy.ckpt") torch.save(checkpoint, path) - err = ( - "Unknown checkpoint task. Only encoder or classification_model" - " extraction is supported" - ) + err = "Unknown checkpoint task. 
Only encoder or model extraction is supported" with pytest.raises(ValueError, match=err): extract_encoder(path) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 4494309379d..675595c767b 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -38,19 +38,19 @@ class ClassificationTask(pl.LightningModule): Supports any available `Timm model `_ - as an architecture choice. To see a list of available pretrained + as an architecture choice. To see a list of available models, you can do: .. code-block:: python import timm - print(timm.list_models(pretrained=True)) + print(timm.list_models()) """ def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" in_channels = self.hyperparams["in_channels"] - classification_model = self.hyperparams["classification_model"] + model = self.hyperparams["model"] imagenet_pretrained = False custom_pretrained = False @@ -68,26 +68,24 @@ def config_model(self) -> None: custom_pretrained = True # Create the model - valid_models = timm.list_models(pretrained=True) - if classification_model in valid_models: + valid_models = timm.list_models(pretrained=imagenet_pretrained) + if model in valid_models: self.model = timm.create_model( - classification_model, + model, num_classes=self.hyperparams["num_classes"], in_chans=in_channels, pretrained=imagenet_pretrained, ) else: - raise ValueError( - f"Model type '{classification_model}' is not a valid timm model." - ) + raise ValueError(f"Model type '{model}' is not a valid timm model.") if custom_pretrained: name, state_dict = utils.extract_encoder(self.hyperparams["weights"]) - if self.hyperparams["classification_model"] != name: + if self.hyperparams["model"] != name: raise ValueError( f"Trying to load {name} weights into a " - f"{self.hyperparams['classification_model']}" + f"{self.hyperparams['model']}" ) self.model = utils.load_state_dict(self.model, state_dict) @@ -108,13 +106,16 @@ def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: - classification_model: Name of the classification model use - loss: Name of the loss function + model: Name of the classification model use + loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal' weights: Either "random" or "imagenet" num_classes: Number of prediction classes in_channels: Number of input channels to model learning_rate: Learning rate for optimizer learning_rate_schedule_patience: Patience for learning rate scheduler + + .. versionchanged:: 0.4 + The *classification_model* parameter was renamed to *model*. """ super().__init__() @@ -313,13 +314,16 @@ def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: - classification_model: Name of the classification model use + model: Name of the classification model use loss: Name of the loss function, currently only supports 'bce' weights: Either "random" or 'imagenet' num_classes: Number of prediction classes in_channels: Number of input channels to model learning_rate: Learning rate for optimizer learning_rate_schedule_patience: Patience for learning rate scheduler + + .. versionchanged:: 0.4 + The *classification_model* parameter was renamed to *model*. 
""" super().__init__(**kwargs) diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index bebdafd2715..9f1b48a8868 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -28,13 +28,12 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: tuple containing model name and state dict Raises: - ValueError: if 'classification_model' or 'encoder' not in + ValueError: if 'model' or 'encoder' not in checkpoint['hyper_parameters'] """ checkpoint = torch.load(path, map_location=torch.device("cpu")) - - if "classification_model" in checkpoint["hyper_parameters"]: - name = checkpoint["hyper_parameters"]["classification_model"] + if "model" in checkpoint["hyper_parameters"]: + name = checkpoint["hyper_parameters"]["model"] state_dict = checkpoint["state_dict"] state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k}) state_dict = OrderedDict( @@ -51,8 +50,7 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: ) else: raise ValueError( - "Unknown checkpoint task. Only encoder or classification_model" - " extraction is supported" + "Unknown checkpoint task. Only encoder or model extraction is supported" ) return name, state_dict