From 31b422ea4f262cc8c8996dc57dfa7eb87a5b6978 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Mon, 6 Nov 2023 14:31:45 +0100 Subject: [PATCH] Update README with paragraph about available pretrained weights (#1716) * paragraph about pretrained models * paragraph about pretrained models * suggested changes * Update README.md Co-authored-by: Adam J. Stewart --------- Co-authored-by: Adam J. Stewart --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 6424dbbf47b..d54190f5b26 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,21 @@ for batch in dataloader: All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch `Tensor`. +### Pre-trained Weights + +Pre-trained weights have proven to be tremendously beneficial for transfer learning tasks in computer vision. Practitioners usually utilize models pre-trained on the ImageNet dataset, containing RGB images. However, remote sensing data often goes beyond RGB with additional multispectral channels that can vary across sensors. TorchGeo is the first library to support models pre-trained on different multispectral sensors, and adopts torchvision's [multi-weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). A summary of currently available weights can be seen in the [docs](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights). To create a [timm](https://github.com/huggingface/pytorch-image-models) Resnet-18 model with weights that have been pretrained on Sentinel-2 imagery, you can do the following: + +```python +import timm +from torchgeo.models import ResNet18_Weights + +weights = ResNet18_Weights.SENTINEL2_ALL_MOCO +model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10) +model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False) +``` + +These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html). + ### Reproducibility with Lightning In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.