Skip to content

Commit

Permalink
Update README with paragraph about available pretrained weights (#1716)
Browse files Browse the repository at this point in the history
* paragraph about pretrained models

* paragraph about pretrained models

* suggested changes

* Update README.md

Co-authored-by: Adam J. Stewart <[email protected]>

---------

Co-authored-by: Adam J. Stewart <[email protected]>
  • Loading branch information
nilsleh and adamjstewart committed Nov 6, 2023
1 parent 7a91236 commit 31b422e
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ for batch in dataloader:

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch `Tensor`.

### Pre-trained Weights

Pre-trained weights have proven to be tremendously beneficial for transfer learning tasks in computer vision. Practitioners usually utilize models pre-trained on the ImageNet dataset, containing RGB images. However, remote sensing data often goes beyond RGB with additional multispectral channels that can vary across sensors. TorchGeo is the first library to support models pre-trained on different multispectral sensors, and adopts torchvision's [multi-weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). A summary of currently available weights can be seen in the [docs](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights). To create a [timm](https://github.com/huggingface/pytorch-image-models) Resnet-18 model with weights that have been pretrained on Sentinel-2 imagery, you can do the following:

```python
import timm
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
```

These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).

### Reproducibility with Lightning

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.
Expand Down

0 comments on commit 31b422e

Please sign in to comment.