diff --git a/config.toml b/config.toml index a9f3721..9426a9c 100644 --- a/config.toml +++ b/config.toml @@ -8,3 +8,4 @@ model-dir = "models" patch-size = 1024 overlap = 128 batch-size = 2 +reflection = 32 diff --git a/darts-segmentation/src/darts_segmentation/segment.py b/darts-segmentation/src/darts_segmentation/segment.py index 8820057..ded5ab7 100644 --- a/darts-segmentation/src/darts_segmentation/segment.py +++ b/darts-segmentation/src/darts_segmentation/segment.py @@ -70,7 +70,7 @@ def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor: Respects the input combination from the config. Returns: - A torch tensor for the full tile consisting of the bands specified in `self.band_combination`. + A torch tensor for the full tile consisting of the bands specified in `self.band_combination`. """ bands = [] @@ -92,7 +92,7 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor: Respects the the input combination from the config. Returns: - A torch tensor for the full tile consisting of the bands specified in `self.band_combination`. + A torch tensor for the full tile consisting of the bands specified in `self.band_combination`. """ bands = [] @@ -107,19 +107,20 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor: return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape) def segment_tile( - self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8 + self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0 ) -> xr.Dataset: """Run inference on a tile. Args: - tile: The input tile, containing preprocessed, harmonized data. - patch_size (int): The size of the patches. Defaults to 1024. - overlap (int): The size of the overlap. Defaults to 16. - batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles. + tile: The input tile, containing preprocessed, harmonized data. + patch_size (int): The size of the patches. Defaults to 1024. + overlap (int): The size of the overlap. Defaults to 16. + batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles. Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8. + reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0. Returns: - Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1]. + Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1]. """ # Convert the tile to a tensor @@ -129,7 +130,7 @@ def segment_tile( tensor_tile = tensor_tile.unsqueeze(0) probabilities = predict_in_patches( - self.model, tensor_tile, patch_size, overlap, batch_size, self.device + self.model, tensor_tile, patch_size, overlap, batch_size, reflection, self.device ).squeeze(0) # Highly sophisticated DL-based predictor @@ -139,19 +140,25 @@ def segment_tile( return tile def segment_tile_batched( - self, tiles: list[xr.Dataset], patch_size: int = 1024, overlap: int = 16, batch_size: int = 8 + self, + tiles: list[xr.Dataset], + patch_size: int = 1024, + overlap: int = 16, + batch_size: int = 8, + reflection: int = 0, ) -> list[xr.Dataset]: """Run inference on a list of tiles. Args: - tiles: The input tiles, containing preprocessed, harmonized data. - patch_size (int): The size of the patches. Defaults to 1024. - overlap (int): The size of the overlap. Defaults to 16. - batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles. + tiles: The input tiles, containing preprocessed, harmonized data. + patch_size (int): The size of the patches. Defaults to 1024. + overlap (int): The size of the overlap. Defaults to 16. + batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles. Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8. + reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0. Returns: - A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1]. + A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1]. """ # Convert the tiles to tensors @@ -162,7 +169,9 @@ def segment_tile_batched( # Create a batch dimension, because predict expects it tensor_tiles = torch.stack(tensor_tiles, dim=0) - probabilities = predict_in_patches(self.model, tensor_tiles, patch_size, overlap, batch_size, self.device) + probabilities = predict_in_patches( + self.model, tensor_tiles, patch_size, overlap, batch_size, reflection, self.device + ) # Highly sophisticated DL-based predictor for tile, probs in zip(tiles, probabilities): @@ -175,11 +184,11 @@ def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr """Run inference on a single tile or a list of tiles. Args: - input: A single tile or a list of tiles. + input: A single tile or a list of tiles. Returns: - A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input. - Each `probability` has type float32 and range [0, 1]. + A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input. + Each `probability` has type float32 and range [0, 1]. Raises: ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset diff --git a/darts-segmentation/src/darts_segmentation/utils.py b/darts-segmentation/src/darts_segmentation/utils.py index 51b238a..98b3877 100644 --- a/darts-segmentation/src/darts_segmentation/utils.py +++ b/darts-segmentation/src/darts_segmentation/utils.py @@ -96,6 +96,7 @@ def predict_in_patches( patch_size: int, overlap: int, batch_size: int, + reflection: int, device=torch.device, return_weights: bool = False, ) -> torch.Tensor: @@ -108,6 +109,7 @@ def predict_in_patches( overlap (int): The size of the overlap. batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles. Tensor will be sliced into patches and these again will be infered in batches. + reflection (int): Reflection-Padding which will be applied to the edges of the tensor. device (torch.device): The device to use for the prediction. return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False. @@ -121,8 +123,9 @@ def predict_in_patches( f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}" ) assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}" - # Add a 1px border to avoid pixel loss when applying the soft margin - tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect") + # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge-artefacts + p = 1 + reflection + tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect") bs, c, h, w = tensor_tiles.shape step_size = patch_size - overlap nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size) @@ -173,7 +176,7 @@ def predict_in_patches( prediction = prediction / weights # Remove the 1px border and the padding - prediction = prediction[:, 1:-1, 1:-1] + prediction = prediction[:, p:-p, p:-p] logger.debug(f"Predicting took {time.time() - start_time:.2f}s") if return_weights: diff --git a/darts/src/darts/native.py b/darts/src/darts/native.py index 24fccb5..f8c85da 100644 --- a/darts/src/darts/native.py +++ b/darts/src/darts/native.py @@ -10,6 +10,7 @@ def run_native_orthotile_pipeline( patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, + reflection: int = 0, ): """Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them. @@ -20,6 +21,7 @@ def run_native_orthotile_pipeline( patch_size (int, optional): The patch size to use for inference. Defaults to 1024. overlap (int, optional): The overlap to use for inference. Defaults to 16. batch_size (int, optional): The batch size to use for inference. Defaults to 8. + reflection (int, optional): The reflection padding to use for inference. Defaults to 0. Todo: Document the structure of the input data dir. @@ -41,7 +43,9 @@ def run_native_orthotile_pipeline( tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path) model = SMPSegmenter(model_dir / "RTS_v6_notcvis.pt") - tile = model.segment_tile(tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size) + tile = model.segment_tile( + tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection + ) tile = prepare_export(tile) outpath.mkdir(parents=True, exist_ok=True)