From 77940a137d7b4095927310502440e6cc682af79e Mon Sep 17 00:00:00 2001 From: Robin Cole Date: Fri, 6 Oct 2023 15:04:56 +0100 Subject: [PATCH] Update info on LightningCLI (#1628) * Add info * Address comments --- README.md | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c33f4a55034..6424dbbf47b 100644 --- a/README.md +++ b/README.md @@ -146,12 +146,66 @@ trainer.fit(model=task, datamodule=datamodule) Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset -In our GitHub repo, we provide `train.py` and `evaluate.py` scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the [conf](https://github.com/microsoft/torchgeo/blob/main/conf) directory for example configuration files that can be customized for different training runs. +TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways: ```console -$ python train.py config_file=conf/landcoverai.yaml +# If torchgeo has been installed +torchgeo +# If torchgeo has been installed, or if it has been cloned to the current directory +python3 -m torchgeo ``` +It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages: + +```console +# See valid stages +torchgeo --help +# See valid trainer options +torchgeo fit --help +# See valid model options +torchgeo fit --model.help ClassificationTask +# See valid data options +torchgeo fit --data.help EuroSAT100DataModule +``` + +Using the following config file: +```yaml +trainer: + max_epochs: 20 +model: + class_path: ClassificationTask + init_args: + model: "resnet18" + in_channels: 13 + num_classes: 10 +data: + class_path: EuroSAT100DataModule + init_args: + batch_size: 8 + dict_kwargs: + download: true +``` + +we can see the script in action: +```console +# Train and validate a model +torchgeo fit --config config.yaml +# Validate-only +torchgeo validate --config config.yaml +# Calculate and report test accuracy +torchgeo test --config config.yaml --trainer.ckpt_path=... +``` + +It can also be imported and used in a Python script if you need to extend it to add new features: + +```python +from torchgeo.main import main + +main(["fit", "--config", "config.yaml"]) +``` + +See the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for more details. + ## Citation If you use this software in your work, please cite our [paper](https://dl.acm.org/doi/10.1145/3557915.3560953):