Add reflection padding to the tile borders
relativityhd committed Oct 23, 2024
1 parent c491bfe commit e77b7ea
Showing 4 changed files with 40 additions and 23 deletions.
1 change: 1 addition & 0 deletions config.toml
@@ -8,3 +8,4 @@ model-dir = "models"
patch-size = 1024
overlap = 128
batch-size = 2
+ reflection = 32
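The new `reflection` key controls how many mirrored pixels are added around each tile before inference. As a standalone illustration of why that helps (not project code; the kernel size, tile size, and reflection width below are arbitrary assumptions), reflection padding lets a convolution see plausible image content at the border instead of zeros, which is what suppresses the edge artefacts this commit targets:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.rand(1, 1, 64, 64)                             # toy single-band tile
    conv = torch.nn.Conv2d(1, 1, kernel_size=9, padding=4)   # stand-in for a segmentation network

    plain = conv(x)                                          # zero-padded borders
    r = 16                                                   # hypothetical reflection width
    padded = F.pad(x, (r, r, r, r), mode="reflect")          # mirror real pixels outward
    reflected = conv(padded)[:, :, r:-r, r:-r]               # crop back to the original tile size

    interior = (plain - reflected).abs()[..., 4:-4, 4:-4].max().item()
    border = (plain - reflected).abs().max().item()
    print(f"interior diff ~{interior:.1e}, border diff ~{border:.1e}")  # interior is ~identical, border differs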
47 changes: 28 additions & 19 deletions darts-segmentation/src/darts_segmentation/segment.py
@@ -70,7 +70,7 @@ def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
Respects the input combination from the config.
Returns:
A torch tensor for the full tile consisting of the bands specified in `self.band_combination`.
"""
bands = []
@@ -92,7 +92,7 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
Respects the input combination from the config.
Returns:
A torch tensor for the full tile consisting of the bands specified in `self.band_combination`.
"""
bands = []
@@ -107,19 +107,20 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)

def segment_tile(
- self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8
+ self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
) -> xr.Dataset:
"""Run inference on a tile.
Args:
tile: The input tile, containing preprocessed, harmonized data.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these will be inferred in batches. Defaults to 8.
+ reflection (int): Reflection padding which will be applied to the edges of the tensor. Defaults to 0.
Returns:
Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].
"""
# Convert the tile to a tensor
@@ -129,7 +130,7 @@ def segment_tile(
tensor_tile = tensor_tile.unsqueeze(0)

probabilities = predict_in_patches(
- self.model, tensor_tile, patch_size, overlap, batch_size, self.device
+ self.model, tensor_tile, patch_size, overlap, batch_size, reflection, self.device
).squeeze(0)

# Highly sophisticated DL-based predictor
@@ -139,19 +140,25 @@ def segment_tile(
return tile
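For orientation, a hypothetical call of the extended method, mirroring the pipeline code in darts/src/darts/native.py further down in this diff. The import paths and file paths are placeholders; only the keyword arguments and the values from config.toml (1024 / 128 / 2 / 32) come from this commit:

    from pathlib import Path

    # Import paths are assumptions; the diff only shows the class and helper being used.
    from darts_preprocessing import load_and_preprocess_planet_scene
    from darts_segmentation.segment import SMPSegmenter

    model_dir = Path("models")
    tile = load_and_preprocess_planet_scene(
        Path("data/scene"), Path("data/elevation.tif"), Path("data/slope.tif")  # placeholder paths
    )
    model = SMPSegmenter(model_dir / "RTS_v6_notcvis.pt")
    # reflection mirrors 32 px of context onto the tile border before patching
    tile = model.segment_tile(tile, patch_size=1024, overlap=128, batch_size=2, reflection=32)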

def segment_tile_batched(
- self, tiles: list[xr.Dataset], patch_size: int = 1024, overlap: int = 16, batch_size: int = 8
+ self,
+ tiles: list[xr.Dataset],
+ patch_size: int = 1024,
+ overlap: int = 16,
+ batch_size: int = 8,
+ reflection: int = 0,
) -> list[xr.Dataset]:
"""Run inference on a list of tiles.
Args:
tiles: The input tiles, containing preprocessed, harmonized data.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these will be inferred in batches. Defaults to 8.
+ reflection (int): Reflection padding which will be applied to the edges of the tensor. Defaults to 0.
Returns:
A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].
"""
# Convert the tiles to tensors
@@ -162,7 +169,9 @@ def segment_tile_batched(
# Create a batch dimension, because predict expects it
tensor_tiles = torch.stack(tensor_tiles, dim=0)

- probabilities = predict_in_patches(self.model, tensor_tiles, patch_size, overlap, batch_size, self.device)
+ probabilities = predict_in_patches(
+     self.model, tensor_tiles, patch_size, overlap, batch_size, reflection, self.device
+ )

# Highly sophisticated DL-based predictor
for tile, probs in zip(tiles, probabilities):
@@ -175,11 +184,11 @@ def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr.Dataset]:
"""Run inference on a single tile or a list of tiles.
Args:
input: A single tile or a list of tiles.
Returns:
A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
Each `probability` has type float32 and range [0, 1].
Raises:
ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset
9 changes: 6 additions & 3 deletions darts-segmentation/src/darts_segmentation/utils.py
@@ -96,6 +96,7 @@ def predict_in_patches(
patch_size: int,
overlap: int,
batch_size: int,
reflection: int,
device=torch.device,
return_weights: bool = False,
) -> torch.Tensor:
@@ -108,6 +109,7 @@
overlap (int): The size of the overlap.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these will be inferred in batches.
+ reflection (int): Reflection padding which will be applied to the edges of the tensor.
device (torch.device): The device to use for the prediction.
return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.
@@ -121,8 +123,9 @@
f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
)
assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to have shape (BS, C, H, W), got {tensor_tiles.shape}"
- # Add a 1px border to avoid pixel loss when applying the soft margin
- tensor_tiles = torch.nn.functional.pad(tensor_tiles, (1, 1, 1, 1), mode="reflect")
+ # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge artefacts
+ p = 1 + reflection
+ tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
bs, c, h, w = tensor_tiles.shape
step_size = patch_size - overlap
nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)
@@ -173,7 +176,7 @@ def predict_in_patches(
prediction = prediction / weights

# Remove the 1px border and the padding
- prediction = prediction[:, 1:-1, 1:-1]
+ prediction = prediction[:, p:-p, p:-p]
logger.debug(f"Predicting took {time.time() - start_time:.2f}s")

if return_weights:
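A minimal standalone sketch of the pad-then-crop pattern introduced above; the batch and tile shapes and the reflection value are arbitrary assumptions, and the patched model inference is replaced by a dummy reduction:

    import torch

    reflection = 32
    p = 1 + reflection                                    # 1 px soft-margin border plus the reflection width
    tensor_tiles = torch.rand(2, 5, 256, 256)             # (BS, C, H, W)

    padded = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    assert padded.shape == (2, 5, 256 + 2 * p, 256 + 2 * p)

    # ... patched inference and soft-margin merging happen on the padded tensor ...
    prediction = padded.mean(dim=1)                       # dummy stand-in for the merged output, shape (BS, H, W)

    prediction = prediction[:, p:-p, p:-p]                # strip the border again
    assert prediction.shape == (2, 256, 256)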
6 changes: 5 additions & 1 deletion darts/src/darts/native.py
@@ -10,6 +10,7 @@ def run_native_orthotile_pipeline(
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
+ reflection: int = 0,
):
"""Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them.
@@ -20,6 +21,7 @@
patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
overlap (int, optional): The overlap to use for inference. Defaults to 16.
batch_size (int, optional): The batch size to use for inference. Defaults to 8.
+ reflection (int, optional): The reflection padding to use for inference. Defaults to 0.
Todo:
Document the structure of the input data dir.
@@ -41,7 +43,9 @@
tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path)

model = SMPSegmenter(model_dir / "RTS_v6_notcvis.pt")
- tile = model.segment_tile(tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size)
+ tile = model.segment_tile(
+     tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
+ )
tile = prepare_export(tile)

outpath.mkdir(parents=True, exist_ok=True)
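A sketch of how the new setting could be wired from config.toml into this entry point, assuming the keys shown above sit at the top level of that file; the data and model directory arguments are omitted because they are not part of this diff:

    import tomllib  # Python 3.11+

    with open("config.toml", "rb") as f:
        cfg = tomllib.load(f)

    # Hypothetical forwarding of the values; the remaining required arguments
    # (input/output/model directories) would come from the same config or the CLI.
    # run_native_orthotile_pipeline(
    #     ...,
    #     patch_size=cfg["patch-size"],
    #     overlap=cfg["overlap"],
    #     batch_size=cfg["batch-size"],
    #     reflection=cfg["reflection"],
    # )
    print(cfg["patch-size"], cfg["overlap"], cfg["batch-size"], cfg["reflection"])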
