diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 1db0154e1d1..67a59b7ba80 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -3,7 +3,7 @@ """NWPU VHR-10 datamodule.""" -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Union import kornia.augmentation as K import torch @@ -24,7 +24,7 @@ class _AugPipe(Module): """Pipeline for applying augmentations sequentially on select data keys.""" def __init__( - self, augs: Callable[[Dict[str, Any]], Dict[str, Any]], batch_size: int + self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int ) -> None: """Initialize a new _AugPipe instance. @@ -36,7 +36,7 @@ def __init__( self.augs = augs self.batch_size = batch_size - def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Apply the augmentation. Args: @@ -67,7 +67,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: return batch -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: +def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]: """Custom object detection collate fn to handle variable boxes. Args: @@ -76,7 +76,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: Returns: batch dict output """ - output: Dict[str, Any] = {} + output: dict[str, Any] = {} output["image"] = [sample["image"] for sample in batch] output["boxes"] = [sample["boxes"] for sample in batch] output["labels"] = [sample["labels"] for sample in batch] @@ -93,7 +93,7 @@ class VHR10DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, - patch_size: Union[Tuple[int, int], int] = 512, + patch_size: Union[tuple[int, int], int] = 512, num_workers: int = 0, val_split_pct: float = 0.2, test_split_pct: float = 0.2,