From 6f7ce55aedfb7103a7f1e07ef408f07ac787d698 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 26 Sep 2024 13:32:58 -0700 Subject: [PATCH 01/58] wind turbine: compare new split vs old split performance --- .../wind_turbine/assign_old_splits.py | 49 ++++++ .../wind_turbine/config.yaml | 8 +- .../wind_turbine/config_flip.yaml | 9 +- .../wind_turbine/config_flip_oldsplit.yaml | 164 ++++++++++++++++++ 4 files changed, 222 insertions(+), 8 deletions(-) create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py b/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py new file mode 100644 index 00000000..0e77f638 --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py @@ -0,0 +1,49 @@ +"""Assign properties matching the old splits from multisat.""" + +import json +import multiprocessing +import sys + +import tqdm +from rslearn.dataset import Dataset +from upath import UPath + +in_fnames = { + "train": "/multisat/mosaic/splits/turbine_naip_supervision/train.json", + "val": "/multisat/mosaic/splits/turbine_naip_supervision/val.json", +} + + +def process(job): + window, split = job + if "old_split" in window.options and window.options["old_split"] == split: + return + window.options["old_split"] = split + window.save() + + +def assign_split(ds_root: str, workers: int = 32): + ds_path = UPath(ds_root) + dataset = Dataset(ds_path) + windows = dataset.load_windows(show_progress=True, workers=workers) + windows_by_name = {window.name: window for window in windows} + + jobs = [] + for split, fname in in_fnames.items(): + with open(fname) as f: + for col, row in json.load(f): + expected_window_name = f"{col*512}_{row*512}" + if expected_window_name not in windows_by_name: + continue + jobs.append((windows_by_name[expected_window_name], split)) + + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(process, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs), desc="Assign old split"): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + assign_split(ds_root=sys.argv[1]) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml index d96ef035..8a887aaf 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml @@ -122,17 +122,17 @@ data: - class_path: rslp.transforms.mask.Mask groups: ["label", "naip"] tags: - split: train + old_split: train val_config: patch_size: 256 groups: ["label", "naip"] tags: - split: val + old_split: val test_config: patch_size: 256 groups: ["label", "naip"] tags: - split: val + old_split: val predict_config: groups: ["predict"] load_all_patches: true @@ -156,4 +156,4 @@ trainer: monitor: val_detect/mAP mode: max rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_satlaspretrainold_patch256_noflip_satlasbands3000_3image_02 +rslp_experiment: debug_20240806_oldsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_00 diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index 4cd6ffa4..a63a9829 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -27,10 +27,11 @@ model: num_classes: 2 anchor_sizes: [[32], [64], [128], [256]] lr: 0.0001 - plateau_factor: 0.1 - plateau_patience: 10 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 plateau_min_lr: 0 - plateau_cooldown: 0 + plateau_cooldown: 10 restore_config: restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth remap_prefixes: @@ -160,4 +161,4 @@ trainer: monitor: val_detect/mAP mode: max rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_01 +rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_02 diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml new file mode 100644 index 00000000..f875da8c --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml @@ -0,0 +1,164 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.0001 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: INT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: INT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: INT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 256 + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + old_split: train + val_config: + patch_size: 256 + groups: ["label", "naip"] + tags: + old_split: val + test_config: + patch_size: 256 + groups: ["label", "naip"] + tags: + old_split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max +rslp_project: satlas_wind_turbine +rslp_experiment: debug_20240806_oldsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_02 From c2ec853e1a3195b43d5b62ea1a4dbb2af2ab55e4 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 27 Sep 2024 09:17:10 -0700 Subject: [PATCH 02/58] try with 384x384 patches plus freeze the model for the first couple epochs --- .../wind_turbine/config_flip.yaml | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index a63a9829..31e7799c 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -26,7 +26,7 @@ model: num_channels: 128 num_classes: 2 anchor_sizes: [[32], [64], [128], [256]] - lr: 0.0001 + lr: 0.00002 plateau: true plateau_factor: 0.2 plateau_patience: 2 @@ -105,9 +105,10 @@ data: init_args: mean: 0 std: 3000 + valid_range: [0, 3000] - class_path: rslp.transforms.mask.Mask train_config: - patch_size: 256 + patch_size: 384 transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: @@ -120,6 +121,7 @@ data: init_args: mean: 0 std: 3000 + valid_range: [0, 3000] - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip init_args: @@ -129,12 +131,12 @@ data: tags: split: train val_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: split: val test_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: split: val @@ -160,5 +162,9 @@ trainer: save_last: true monitor: val_detect/mAP mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_02 +rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_freeze_00 From 90b26fe02c01a566699a0bf9d78d4692c3248164 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 27 Sep 2024 11:46:38 -0700 Subject: [PATCH 03/58] fix freezing code --- .../wind_turbine/config_flip.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index 31e7799c..d2aad685 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -167,4 +167,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_freeze_00 +rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_freeze_01 From 726ad8e1601a9b1725321e7b12de3a7dc727596d Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 30 Sep 2024 09:50:38 -0700 Subject: [PATCH 04/58] remove unused old config for wind turbine training --- .../wind_turbine/config.yaml | 159 ------------------ 1 file changed, 159 deletions(-) delete mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml deleted file mode 100644 index 8a887aaf..00000000 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml +++ /dev/null @@ -1,159 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.0001 - plateau_factor: 0.1 - plateau_patience: 10 - plateau_min_lr: 0 - plateau_cooldown: 0 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: INT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: INT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: INT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 256 - transforms: - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - - class_path: rslp.transforms.mask.Mask - groups: ["label", "naip"] - tags: - old_split: train - val_config: - patch_size: 256 - groups: ["label", "naip"] - tags: - old_split: val - test_config: - patch_size: 256 - groups: ["label", "naip"] - tags: - old_split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max -rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_oldsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_00 From 9b0d4bb82572fe931533f8fb5714e64175718d27 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 30 Sep 2024 10:37:06 -0700 Subject: [PATCH 05/58] add webmercator version of the wind turbine dataset --- .../create_webmercator_rslearn_dataset.py | 182 ++++++++++++++++++ .../wind_turbine/webmercator_config.json | 162 ++++++++++++++++ .../wind_turbine/webmercator_config.yaml | 155 +++++++++++++++ 3 files changed, 499 insertions(+) create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py b/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py new file mode 100644 index 00000000..47e6a19a --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py @@ -0,0 +1,182 @@ +"""Create WebMercator version of the wind turbine dataset as rslearn dataset. + +This is just used to make sure performance is the same when training in rslearn versus +training in multisat. + +The projection is not set correctly since it's just for testing. +""" + +import argparse +import json +import multiprocessing +import os +import shutil + +import shapely +import tqdm +from rasterio.crs import CRS +from rslearn.dataset import Window +from rslearn.utils import Feature, Projection, STGeometry +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +BANDS = ["tci", "b05", "b06", "b07", "b08", "b11", "b12"] + +# How many images model will be trained with, so can't have examples with fewer than +# this many images. +REQUIRED_IMAGES = 4 + + +def process_example( + ds_path: UPath, + label_dir: str, + image_dir: str, + tile: tuple[int, int], + image_ids: list[str], + split: str, +): + projection = Projection(CRS.from_epsg(3857), 10, -10) + + window_name = f"{tile[0]}_{tile[1]}" + group = "default" + window_root = ds_path / "windows" / group / window_name + window = Window( + path=window_root, + group=group, + name=window_name, + projection=projection, + bounds=[0, 0, 512, 512], + time_range=None, + options={"split": split}, + ) + window.save() + + # Image layers. + for idx, image_id in enumerate(image_ids): + if idx == 0: + layer_name = "sentinel2" + else: + layer_name = f"sentinel2.{idx}" + + for band in BANDS: + src_fname = os.path.join( + image_dir, image_id, band, f"{tile[0]}_{tile[1]}.png" + ) + if band == "tci": + dst_band_name = "R_G_B" + else: + dst_band_name = band + dst_fname = ( + window_root / "layers" / layer_name / dst_band_name / "image.png" + ) + dst_fname.parent.mkdir(parents=True, exist_ok=True) + with open(src_fname, "rb") as src: + with dst_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) + + (window_root / "layers" / layer_name / "completed").touch() + + # Label layer. + features = [] + with open(os.path.join(label_dir, fname)) as f: + for x1, y1, x2, y2, category in json.load(f): + geom = STGeometry(projection, shapely.box(x1, y1, x2, y2), None) + props = dict(category=category) + features.append(Feature(geom, props)) + layer_dir = window_root / "layers" / "label" + GeojsonVectorFormat().encode_vector(layer_dir, projection, features) + (layer_dir / "completed").touch() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--label_dir", + help="The multisat label directory", + type=str, + default="/multisat/labels/renewable_infra_point_naip_supervision", + ) + parser.add_argument( + "--image_dir", + help="The multisat image directory", + type=str, + default="/multisat/mosaic/turbine_naip_supervision/", + ) + parser.add_argument( + "--ds_path", + help="The path to write output rslearn dataset", + type=str, + default="gs://rslearn-eai/datasets/wind_turbine/webmercator_dataset/20240927/", + ) + parser.add_argument( + "--train_split", + help="The JSON file containing train split", + type=str, + default="/multisat/mosaic/splits/turbine_naip_supervision/train.json", + ) + parser.add_argument( + "--workers", + help="Number of parallel workers", + type=int, + default=32, + ) + args = parser.parse_args() + + # Create map from tile to image IDs that have it available. + tile_to_image_ids = {} + for image_id in os.listdir(args.image_dir): + for fname in os.listdir(os.path.join(args.image_dir, image_id, BANDS[0])): + # Make sure the other bands exist. + bands_exist = True + for band in BANDS: + if os.path.exists(os.path.join(args.image_dir, image_id, band, fname)): + continue + bands_exist = False + break + if not bands_exist: + continue + + parts = fname.split(".")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + if tile not in tile_to_image_ids: + tile_to_image_ids[tile] = [] + tile_to_image_ids[tile].append(image_id) + + # Identify tiles in train split. + train_split = set() + with open(args.train_split) as f: + for col, row in json.load(f): + train_split.add((col, row)) + + ds_path = UPath(args.ds_path) + + jobs = [] + for fname in os.listdir(args.label_dir): + parts = fname.split(".")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + image_ids = tile_to_image_ids.get(tile, []) + if len(image_ids) < REQUIRED_IMAGES: + continue + + if tile in train_split: + split = "train" + else: + split = "val" + + jobs.append( + dict( + ds_path=ds_path, + label_dir=args.label_dir, + image_dir=args.image_dir, + tile=tile, + image_ids=image_ids, + split=split, + ) + ) + + p = multiprocessing.Pool(args.workers) + outputs = star_imap_unordered(p, process_example, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs)): + pass + p.close() diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json new file mode 100644 index 00000000..91265e41 --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json @@ -0,0 +1,162 @@ +{ + "layers": { + "sentinel2": { + "type": "raster", + "band_sets": [{ + "dtype": "uint8", + "bands": ["R", "G", "B"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b05"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b06"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b07"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b08"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b11"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b12"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }] + }, + "sentinel2.1": { + "type": "raster", + "band_sets": [{ + "dtype": "uint8", + "bands": ["R", "G", "B"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b05"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b06"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b07"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b08"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b11"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b12"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }] + }, + "sentinel2.2": { + "type": "raster", + "band_sets": [{ + "dtype": "uint8", + "bands": ["R", "G", "B"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b05"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b06"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b07"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b08"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b11"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b12"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }] + }, + "sentinel2.3": { + "type": "raster", + "band_sets": [{ + "dtype": "uint8", + "bands": ["R", "G", "B"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b05"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b06"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b07"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b08"], + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b11"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }, { + "dtype": "uint8", + "bands": ["b12"], + "zoom_offset": -1, + "format": {"name": "single_image", "format": "png"} + }] + }, + "label": { + "type": "vector" + }, + "output": { + "type": "vector" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml new file mode 100644 index 00000000..caf6a743 --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml @@ -0,0 +1,155 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/webmercator_dataset/20240927/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] + passthrough: true + dtype: INT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] + passthrough: true + dtype: INT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] + passthrough: true + dtype: INT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] + passthrough: true + dtype: INT32 + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 4 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 255 + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 255 + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20240927_satlaspretrainold_patch512_flip_4image_01 From 91dd7f97da40a7df9d69f6ba45cd1d0e56f61cf9 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 30 Sep 2024 11:14:57 -0700 Subject: [PATCH 06/58] Add script to add the bounds metadata for layers using SingleImageRasterFormat --- .../set_single_image_metadata.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py diff --git a/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py b/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py new file mode 100644 index 00000000..97044b00 --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py @@ -0,0 +1,33 @@ +import json +import multiprocessing +import sys + +import tqdm +from rslearn.dataset import Dataset, Window +from upath import UPath + + +def handle_window(window: Window): + image_fnames = window.path.glob("layers/*/*/image.png") + for image_fname in image_fnames: + metadata_fname = image_fname.parent / "metadata.json" + if metadata_fname.exists(): + continue + with metadata_fname.open("w") as f: + json.dump({"bounds": window.bounds}, f) + + +def set_single_image_metadata(ds_root: str, workers: int = 32): + ds_path = UPath(ds_root) + dataset = Dataset(ds_path) + windows = dataset.load_windows(show_progress=True, workers=workers) + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(handle_window, windows) + for _ in tqdm.tqdm(outputs, total=len(windows), desc="Set single image metadata"): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + set_single_image_metadata(ds_root=sys.argv[1]) From f03249ce2241940ee0b3b3cb6954bb596b37c7a1 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 30 Sep 2024 11:15:45 -0700 Subject: [PATCH 07/58] update name for the 384x384 experiment --- .../wind_turbine/config_flip.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index d2aad685..67c434fe 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -167,4 +167,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_newsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_freeze_01 +rslp_experiment: data_20240806_newsplits_satlaspretrainold_patch384_satlasbands3000_3image_flip_freeze_02 From 036bf0bc214a4135a351128c27fbc48f37d2fedc Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 30 Sep 2024 15:26:02 -0700 Subject: [PATCH 08/58] Fix wind turbine webmercator training (labels were not being populated) --- .../wind_turbine/create_webmercator_rslearn_dataset.py | 2 +- .../wind_turbine/webmercator_config.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py b/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py index 47e6a19a..2a1e3257 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py @@ -79,7 +79,7 @@ def process_example( # Label layer. features = [] - with open(os.path.join(label_dir, fname)) as f: + with open(os.path.join(label_dir, f"{tile[0]}_{tile[1]}.json")) as f: for x1, y1, x2, y2, category in json.load(f): geom = STGeometry(projection, shapely.box(x1, y1, x2, y2), None) props = dict(category=category) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml index caf6a743..478579f4 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml @@ -77,7 +77,7 @@ data: class_path: rslearn.train.tasks.detection.DetectionTask init_args: property_name: "category" - classes: ["unknown", "turbine"] + classes: ["unknown", "wind_turbine"] box_size: 15 remap_values: [[0, 1], [0, 255]] exclude_by_center: true @@ -152,4 +152,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: data_20240927_satlaspretrainold_patch512_flip_4image_01 +rslp_experiment: data_20240927_satlaspretrainold_patch512_flip_4image_02 From d6f8b0287001418fc4bfaebc1e426f53ce68f2f6 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Wed, 16 Oct 2024 11:43:35 -0400 Subject: [PATCH 09/58] Use six Sentinel-2 images from diverse months of the year --- .../wind_turbine/config.json | 87 +++++++++++++++++-- .../wind_turbine/config_flip.yaml | 80 +++++++++++++---- .../wind_turbine/config_flip_oldsplit.yaml | 36 ++++++-- 3 files changed, 171 insertions(+), 32 deletions(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json index 261f95b0..1193e9b3 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json @@ -1,7 +1,8 @@ { "layers": { - "sentinel2": { + "sentinel2_a": { "type": "raster", + "alias": "sentinel2", "band_sets": [{ "dtype": "uint16", "bands": ["B02", "B03", "B04", "B08"] @@ -17,17 +18,93 @@ "data_source": { "name": "rslearn.data_sources.gcp_public_data.Sentinel2", "modality": "L1C", - "index_cache_dir": "/data/favyenb/rslearn_datasets_satlas/solar_farm/cache/sentinel2", + "index_cache_dir": "cache/sentinel2", + "use_rtree_index": false, "max_time_delta": "1d", "query_config": { - "max_matches": 3, + "max_matches": 2, "space_mode": "CONTAINS" }, "sort_by": "cloud_cover", "harmonize": true } }, - "sentinel2.1": { + "sentinel2_b": { + "type": "raster", + "alias": "sentinel2", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }], + "data_source": { + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "modality": "L1C", + "index_cache_dir": "cache/sentinel2", + "use_rtree_index": false, + "max_time_delta": "1d", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "harmonize": true, + "time_offset": "-90d" + } + }, + "sentinel2_c": { + "type": "raster", + "alias": "sentinel2", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }], + "data_source": { + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "modality": "L1C", + "index_cache_dir": "cache/sentinel2", + "use_rtree_index": false, + "max_time_delta": "1d", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "harmonize": true, + "time_offset": "-180d" + } + }, + "sentinel2_a.1": { + "type": "raster", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }] + }, + "sentinel2_b.1": { "type": "raster", "band_sets": [{ "dtype": "uint16", @@ -42,7 +119,7 @@ "zoom_offset": -2 }] }, - "sentinel2.2": { + "sentinel2_c.1": { "type": "raster", "band_sets": [{ "dtype": "uint16", diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index 67c434fe..73d0538a 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -39,26 +39,44 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ inputs: image1: data_type: "raster" - layers: ["sentinel2"] + layers: ["sentinel2_a"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image2: data_type: "raster" - layers: ["sentinel2.1"] + layers: ["sentinel2_a.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image3: data_type: "raster" - layers: ["sentinel2.2"] + layers: ["sentinel2_b"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 mask: data_type: "raster" layers: ["mask"] @@ -94,34 +112,58 @@ data: num_workers: 32 default_config: transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] + image5: [] + image6: [] output_selector: image - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 3000] - class_path: rslp.transforms.mask.Mask train_config: patch_size: 384 transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] + image5: [] + image6: [] output_selector: image - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 3000] - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip init_args: @@ -153,7 +195,7 @@ trainer: logging_interval: "epoch" - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ output_layer: output selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -167,4 +209,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: data_20240806_newsplits_satlaspretrainold_patch384_satlasbands3000_3image_flip_freeze_02 +rslp_experiment: data_20241002_satlaspretrainold_patch384_01 diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml index f875da8c..b2b819dd 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml @@ -26,7 +26,7 @@ model: num_channels: 128 num_classes: 2 anchor_sizes: [[32], [64], [128], [256]] - lr: 0.0001 + lr: 0.00002 plateau: true plateau_factor: 0.2 plateau_patience: 2 @@ -46,19 +46,19 @@ data: layers: ["sentinel2"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image2: data_type: "raster" layers: ["sentinel2.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image3: data_type: "raster" layers: ["sentinel2.2"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 mask: data_type: "raster" layers: ["mask"] @@ -105,9 +105,17 @@ data: init_args: mean: 0 std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] - class_path: rslp.transforms.mask.Mask train_config: - patch_size: 256 + patch_size: 384 transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: @@ -120,6 +128,14 @@ data: init_args: mean: 0 std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip init_args: @@ -129,12 +145,12 @@ data: tags: old_split: train val_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: old_split: val test_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: old_split: val @@ -160,5 +176,9 @@ trainer: save_last: true monitor: val_detect/mAP mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_oldsplits_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_02 +rslp_experiment: data_20240806_oldsplits_satlaspretrainold_patch384_satlasnorm_3image_flip_freeze_03 From 8cc663a60e1c692408bdcdbe78d07501c96cd826 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 18 Oct 2024 13:50:46 -0700 Subject: [PATCH 10/58] Add marine infrastructure dataset config and model config. --- .../marine_infra/config.json | 79 ++++++++ .../marine_infra/config.yaml | 186 ++++++++++++++++++ 2 files changed, 265 insertions(+) create mode 100644 convert_satlas_webmercator_to_rslearn/marine_infra/config.json create mode 100644 convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/config.json b/convert_satlas_webmercator_to_rslearn/marine_infra/config.json new file mode 100644 index 00000000..0611096c --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/marine_infra/config.json @@ -0,0 +1,79 @@ +{ + "layers": { + "sentinel2": { + "type": "raster", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }], + "data_source": { + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "modality": "L1C", + "index_cache_dir": "/data/favyenb/rslearn_datasets_satlas/solar_farm/cache/sentinel2", + "max_time_delta": "1d", + "query_config": { + "max_matches": 3, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "harmonize": true + } + }, + "sentinel2.1": { + "type": "raster", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }] + }, + "sentinel2.2": { + "type": "raster", + "band_sets": [{ + "dtype": "uint16", + "bands": ["B02", "B03", "B04", "B08"] + }, { + "dtype": "uint16", + "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], + "zoom_offset": -1 + }, { + "dtype": "uint16", + "bands": ["B01", "B09", "B10"], + "zoom_offset": -2 + }] + }, + "label": { + "type": "vector" + }, + "mask": { + "type": "raster", + "band_sets": [{ + "dtype": "uint8", + "bands": ["mask"], + "format": {"name": "single_image", "format": "png"} + }] + }, + "output": { + "type": "vector" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml b/convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml new file mode 100644 index 00000000..4a08d39c --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml @@ -0,0 +1,186 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241002_satlaspretrainold_patch512_satlasnorm_3image_01 From 2b299493e99aeaa716c275a8b36a6c6337e17db3 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 18 Oct 2024 13:51:22 -0700 Subject: [PATCH 11/58] add files needed by the model config for training --- rslp/satlas_marine_infra/__init__.py | 1 + rslp/satlas_marine_infra/train.py | 49 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 rslp/satlas_marine_infra/__init__.py create mode 100644 rslp/satlas_marine_infra/train.py diff --git a/rslp/satlas_marine_infra/__init__.py b/rslp/satlas_marine_infra/__init__.py new file mode 100644 index 00000000..ab71e155 --- /dev/null +++ b/rslp/satlas_marine_infra/__init__.py @@ -0,0 +1 @@ +"""Satlas marine infrastructure.""" diff --git a/rslp/satlas_marine_infra/train.py b/rslp/satlas_marine_infra/train.py new file mode 100644 index 00000000..170298d5 --- /dev/null +++ b/rslp/satlas_marine_infra/train.py @@ -0,0 +1,49 @@ +"""Remaps categories for marine infrastructure training.""" + +from typing import Any + +import torch +from rslearn.train.tasks.detection import DetectionTask +from rslearn.utils import Feature + +CATEGORY_MAPPING = { + "power": "platform", +} + + +class MarineInfraTask(DetectionTask): + """Marine infrastructure detection task. + + We just add a remapping pre-processing. + """ + + def process_inputs( + self, + raw_inputs: dict[str, torch.Tensor | list[Feature]], + metadata: dict[str, Any], + load_targets: bool = True, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Processes the data into targets. + + Args: + raw_inputs: raster or vector data to process + metadata: metadata about the patch being read + load_targets: whether to load the targets or only inputs + + Returns: + tuple (input_dict, target_dict) containing the processed inputs and targets + that are compatible with both metrics and loss functions + """ + if not load_targets: + return {}, {} + + for feat in raw_inputs["targets"]: + if self.property_name not in feat["properties"]: + continue + properties = feat["properties"] + category = properties[self.property_name] + if category not in CATEGORY_MAPPING: + continue + properties[self.property_name] = CATEGORY_MAPPING[category] + + return super().process_inputs(raw_inputs, metadata, load_targets) From 82a9149c14d1743e0098d0691fe16fe851363526 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 22 Oct 2024 13:59:00 -0700 Subject: [PATCH 12/58] maybe close to same performance --- .../compare_webmercator_utm_together.py | 28 +++++++++ .../wind_turbine/config_flip.yaml | 2 +- .../wind_turbine/config_flip_oldsplit.yaml | 62 ++++++++++++++----- 3 files changed, 74 insertions(+), 18 deletions(-) create mode 100644 convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py b/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py new file mode 100644 index 00000000..6bbdc281 --- /dev/null +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py @@ -0,0 +1,28 @@ +"""Simple script to copy the images. + +Just need to run `model predict` with visualize_dir set to utm and webm respectively. +""" + +import os +import shutil + +webm_images = {} +for fname in os.listdir("webm"): + if not fname.endswith("_gt.png"): + continue + parts = fname.split("_gt.png")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + webm_images[tile] = fname +utm_images = {} +for fname in os.listdir("utm"): + if not fname.endswith("_gt.png"): + continue + parts = fname.split("_gt.png")[0].split("_") + tile = (int(parts[0]) // 512, int(parts[1]) // 512) + utm_images[tile] = fname +good_keys = set(webm_images.keys()).intersection(utm_images.keys()) +for tile in good_keys: + shutil.copyfile( + f"webm/{webm_images[tile]}", f"out/{tile[0]}_{tile[1]}_webmercator.png" + ) + shutil.copyfile(f"utm/{utm_images[tile]}", f"out/{tile[0]}_{tile[1]}_utm.png") diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index 73d0538a..d6d022bb 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -209,4 +209,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: data_20241002_satlaspretrainold_patch384_01 +rslp_experiment: data_20241002_satlaspretrainold_patch384_03 diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml index b2b819dd..1648b771 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml @@ -39,23 +39,41 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ inputs: image1: data_type: "raster" - layers: ["sentinel2"] + layers: ["sentinel2_a"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image2: data_type: "raster" - layers: ["sentinel2.1"] + layers: ["sentinel2_a.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image3: data_type: "raster" - layers: ["sentinel2.2"] + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 @@ -94,48 +112,58 @@ data: num_workers: 32 default_config: transforms: - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 3000 valid_range: [0, 1] bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 8160 valid_range: [0, 1] bands: [3, 4, 5, 6, 7, 8] - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] + image5: [] + image6: [] output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 3000 valid_range: [0, 1] bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 8160 valid_range: [0, 1] bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip init_args: @@ -167,7 +195,7 @@ trainer: logging_interval: "epoch" - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ output_layer: output selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -181,4 +209,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: data_20240806_oldsplits_satlaspretrainold_patch384_satlasnorm_3image_flip_freeze_03 +rslp_experiment: data_20241002_satlaspretrainold_patch384_oldsplits_03 From 24e609364854a08be56b8e3ecac914c89d755bc8 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 22 Oct 2024 14:03:51 -0700 Subject: [PATCH 13/58] json formatting --- .../wind_turbine/config.json | 388 +++++++++----- .../wind_turbine/webmercator_config.json | 494 ++++++++++++------ 2 files changed, 591 insertions(+), 291 deletions(-) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json index 866cb416..eee8273f 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json @@ -1,141 +1,261 @@ { - "layers": { - "sentinel2_a": { - "type": "raster", - "alias": "sentinel2", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }], - "data_source": { - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "modality": "L1C", - "index_cache_dir": "cache/sentinel2", - "use_rtree_index": false, - "max_time_delta": "1d", - "query_config": { - "max_matches": 2, - "space_mode": "CONTAINS" - }, - "sort_by": "cloud_cover", - "harmonize": true - } + "layers": { + "label": { + "type": "vector" + }, + "sentinel2_a": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" }, - "sentinel2_b": { - "type": "raster", - "alias": "sentinel2", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }], - "data_source": { - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "modality": "L1C", - "index_cache_dir": "cache/sentinel2", - "use_rtree_index": false, - "max_time_delta": "1d", - "query_config": { - "max_matches": 2, - "space_mode": "CONTAINS" - }, - "sort_by": "cloud_cover", - "harmonize": true, - "time_offset": "-90d" - } + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 }, - "sentinel2_c": { - "type": "raster", - "alias": "sentinel2", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }], - "data_source": { - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "modality": "L1C", - "index_cache_dir": "cache/sentinel2", - "use_rtree_index": false, - "max_time_delta": "1d", - "query_config": { - "max_matches": 2, - "space_mode": "CONTAINS" - }, - "sort_by": "cloud_cover", - "harmonize": true, - "time_offset": "-180d" - } + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" }, - "sentinel2_a.1": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }] - }, - "sentinel2_b.1": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }] - }, - "sentinel2_c.1": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }] - }, - "label": { - "type": "vector" + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + }, + "sentinel2_a.1": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "type": "raster" + }, + "sentinel2_b": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-90d", + "use_rtree_index": false + }, + "type": "raster" + }, + "sentinel2_b.1": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "type": "raster" + }, + "sentinel2_c": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-180d", + "use_rtree_index": false + }, + "type": "raster" + }, + "sentinel2_c.1": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "type": "raster" } }, "tile_store": { diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json index 91265e41..707e5593 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json @@ -1,162 +1,342 @@ { - "layers": { - "sentinel2": { - "type": "raster", - "band_sets": [{ - "dtype": "uint8", - "bands": ["R", "G", "B"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b05"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b06"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b07"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b08"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b11"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b12"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }] - }, - "sentinel2.1": { - "type": "raster", - "band_sets": [{ - "dtype": "uint8", - "bands": ["R", "G", "B"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b05"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b06"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b07"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b08"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b11"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b12"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }] - }, - "sentinel2.2": { - "type": "raster", - "band_sets": [{ - "dtype": "uint8", - "bands": ["R", "G", "B"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b05"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b06"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b07"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b08"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b11"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b12"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }] - }, - "sentinel2.3": { - "type": "raster", - "band_sets": [{ - "dtype": "uint8", - "bands": ["R", "G", "B"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b05"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b06"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b07"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b08"], - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b11"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }, { - "dtype": "uint8", - "bands": ["b12"], - "zoom_offset": -1, - "format": {"name": "single_image", "format": "png"} - }] - }, - "label": { - "type": "vector" - }, - "output": { - "type": "vector" + "layers": { + "label": { + "type": "vector" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + }, + "sentinel2.1": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 } + ], + "type": "raster" }, - "tile_store": { - "name": "file", - "root_dir": "tiles" + "sentinel2.2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + }, + "sentinel2.3": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } } From cb04f2d0b62d9b3a1acfc2ee87caa4c6d1a81ddb Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 22 Oct 2024 14:06:00 -0700 Subject: [PATCH 14/58] fix missing sections in config.json --- .../wind_turbine/config.json | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json index eee8273f..3cace116 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json @@ -3,6 +3,24 @@ "label": { "type": "vector" }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, "sentinel2_a": { "alias": "sentinel2", "band_sets": [ From 91274c163dbe99d6c61b295c898ea74418560ed6 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Wed, 23 Oct 2024 14:50:57 -0700 Subject: [PATCH 15/58] Add job launcher and prediction pipeline for Satlas applications. --- .../marine_infra/config.json | 79 ------ data/satlas/marine_infra/config.json | 146 +++++++++++ .../satlas}/marine_infra/config.yaml | 25 +- rslp/main.py | 33 +++ rslp/satlas/__init__.py | 14 ++ rslp/satlas/job_launcher.py | 237 ++++++++++++++++++ rslp/satlas/predict_pipeline.py | 127 ++++++++++ rslp/{satlas_marine_infra => satlas}/train.py | 4 +- rslp/satlas_marine_infra/__init__.py | 1 - 9 files changed, 583 insertions(+), 83 deletions(-) delete mode 100644 convert_satlas_webmercator_to_rslearn/marine_infra/config.json create mode 100644 data/satlas/marine_infra/config.json rename {convert_satlas_webmercator_to_rslearn => data/satlas}/marine_infra/config.yaml (87%) create mode 100644 rslp/satlas/__init__.py create mode 100644 rslp/satlas/job_launcher.py create mode 100644 rslp/satlas/predict_pipeline.py rename rslp/{satlas_marine_infra => satlas}/train.py (93%) delete mode 100644 rslp/satlas_marine_infra/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/config.json b/convert_satlas_webmercator_to_rslearn/marine_infra/config.json deleted file mode 100644 index 0611096c..00000000 --- a/convert_satlas_webmercator_to_rslearn/marine_infra/config.json +++ /dev/null @@ -1,79 +0,0 @@ -{ - "layers": { - "sentinel2": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }], - "data_source": { - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "modality": "L1C", - "index_cache_dir": "/data/favyenb/rslearn_datasets_satlas/solar_farm/cache/sentinel2", - "max_time_delta": "1d", - "query_config": { - "max_matches": 3, - "space_mode": "CONTAINS" - }, - "sort_by": "cloud_cover", - "harmonize": true - } - }, - "sentinel2.1": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }] - }, - "sentinel2.2": { - "type": "raster", - "band_sets": [{ - "dtype": "uint16", - "bands": ["B02", "B03", "B04", "B08"] - }, { - "dtype": "uint16", - "bands": ["B05", "B06", "B07", "B8A", "B11", "B12"], - "zoom_offset": -1 - }, { - "dtype": "uint16", - "bands": ["B01", "B09", "B10"], - "zoom_offset": -2 - }] - }, - "label": { - "type": "vector" - }, - "mask": { - "type": "raster", - "band_sets": [{ - "dtype": "uint8", - "bands": ["mask"], - "format": {"name": "single_image", "format": "png"} - }] - }, - "output": { - "type": "vector" - } - }, - "tile_store": { - "name": "file", - "root_dir": "tiles" - } -} diff --git a/data/satlas/marine_infra/config.json b/data/satlas/marine_infra/config.json new file mode 100644 index 00000000..fa9f2cf9 --- /dev/null +++ b/data/satlas/marine_infra/config.json @@ -0,0 +1,146 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 3 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + }, + "sentinel2.1": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "type": "raster" + }, + "sentinel2.2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "type": "raster" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml similarity index 87% rename from convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml rename to data/satlas/marine_infra/config.yaml index 4a08d39c..598f6752 100644 --- a/convert_satlas_webmercator_to_rslearn/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -65,6 +65,7 @@ data: bands: ["mask"] passthrough: true dtype: FLOAT32 + is_target: true targets: data_type: "vector" layers: ["label"] @@ -74,7 +75,7 @@ data: init_args: tasks: detect: - class_path: rslearn.train.tasks.detection.DetectionTask + class_path: rslp.satlas.train.MarineInfraTask init_args: property_name: "category" classes: ["unknown", "platform", "turbine"] @@ -157,6 +158,28 @@ data: tags: split: val predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image groups: ["predict"] load_all_patches: true skip_targets: true diff --git a/rslp/main.py b/rslp/main.py index da0cd0e8..787a5eaa 100644 --- a/rslp/main.py +++ b/rslp/main.py @@ -5,13 +5,39 @@ import logging import multiprocessing import sys +from datetime import datetime import dotenv import jsonargparse +import jsonargparse.typing logging.basicConfig() +def datetime_serializer(v: datetime) -> str: + """Serialize datetime for jsonargparse. + + Args: + v: the datetime object. + + Returns: + the datetime encoded to string + """ + return v.isoformat() + + +def datetime_deserializer(v: str) -> datetime: + """Deserialize datetime for jsonargparse. + + Args: + v: the encoded datetime. + + Returns: + the decoded datetime object + """ + return datetime.fromisoformat(v) + + def main() -> None: """Main entrypoint function for rslp.""" dotenv.load_dotenv() @@ -22,6 +48,13 @@ def main() -> None: module = importlib.import_module(f"rslp.{args.project}") workflow_fn = module.workflows[args.workflow] + + # Setup jsonargparse. + jsonargparse.typing.register_type( + datetime, datetime_serializer, datetime_deserializer + ) + + # Parse arguments and run function. jsonargparse.CLI(workflow_fn, args=sys.argv[3:]) diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py new file mode 100644 index 00000000..756cb533 --- /dev/null +++ b/rslp/satlas/__init__.py @@ -0,0 +1,14 @@ +"""Satlas batch jobs. + +Specifically, training and inference for these fine-tuned models on satlas.allen.ai: +- Marine infrastructure +- On-shore wind turbines +- Solar farms +- Tree cover +""" + +from .predict_pipeline import predict_pipeline + +workflows = { + "predict": predict_pipeline, +} diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py new file mode 100644 index 00000000..9af927c6 --- /dev/null +++ b/rslp/satlas/job_launcher.py @@ -0,0 +1,237 @@ +"""Launch Satlas prediction jobs on Beaker.""" + +import json +import multiprocessing +import os +import random +import uuid +from datetime import datetime + +import rslearn.utils.get_utm_ups_crs +import shapely +import tqdm +from beaker import ( + Beaker, + Constraints, + DataMount, + DataSource, + EnvVar, + ExperimentSpec, + Priority, + TaskResources, +) +from rasterio.crs import CRS +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import PixelBounds, Projection, STGeometry + +from .predict_pipeline import Application, get_output_fname + +WORKSPACE = "ai2/earth-systems" +BUDGET = "ai2/d5" +IMAGE_NAME = "favyen/rslearn" +TILE_SIZE = 16384 +RESOLUTION = 10 + + +class Task: + """Represents a task that will correspond to one Beaker job.""" + + def __init__( + self, + application: Application, + projection: Projection, + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + out_path: str, + ) -> None: + """Create a new Task. + + Args: + application: the application to run + projection: the projection of the tile + bounds: the bounds of the tile + time_range: the time range to process + out_path: where to write outputs + """ + self.application = application + self.projection = projection + self.bounds = bounds + self.time_range = time_range + self.out_path = out_path + + +def launch_job(task: Task) -> None: + """Launch job for this task. + + Args: + task: the Task object for which to create a job. + """ + beaker = Beaker.from_env(default_workspace=WORKSPACE) + + with beaker.session(): + env_vars = [ + EnvVar( + name="GOOGLE_APPLICATION_CREDENTIALS", # nosec + value="/etc/credentials/gcp_credentials.json", # nosec + ), + EnvVar( + name="GCLOUD_PROJECT", # nosec + value="skylight-proto-1", # nosec + ), + EnvVar( + name="RSLP_BUCKET", + value=os.environ["RSLP_BUCKET"], + ), + EnvVar( + name="MKL_THREADING_LAYER", + value="GNU", + ), + ] + + experiment_name = ( + f"satlas_{task.application.value}_{str(task.projection.crs)}_" + + f"{task.bounds[0]}_{task.bounds[1]}" + ) + + spec = ExperimentSpec.new( + budget=BUDGET, + description=experiment_name, + beaker_image=IMAGE_NAME, + priority=Priority.low, + command=["python", "-m", "rslp.main"], + arguments=[ + "satlas", + "predict", + task.application, + json.dumps(task.projection.serialize()), + json.dumps(task.bounds), + json.dumps( + [task.time_range[0].isoformat(), task.time_range[1].isoformat()] + ), + task.out_path, + "/tmp/scratch/", + ], + constraints=Constraints( + cluster=[ + "ai2/jupiter-cirrascale-2", + "ai2/neptune-cirrascale", + "ai2/saturn-cirrascale", + "ai2/pluto-cirrascale", + "ai2/general-cirrascale", + "ai2/prior-cirrascale", + "ai2/prior-elanding", + ] + ), + preemptible=True, + datasets=[ + DataMount( + source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec + mount_path="/etc/credentials/gcp_credentials.json", # nosec + ), + ], + env_vars=env_vars, + resources=TaskResources(gpu_count=1), + ) + unique_id = str(uuid.uuid4())[0:8] + beaker.experiment.create(experiment_name + "_" + unique_id, spec) + + +def check_task_done(task: Task) -> tuple[Task, bool]: + """Checks whether this task is done processing already. + + It is determined based on existence of output file for the task. + + Args: + task: the task. + + Returns: + whether the task was completed + """ + out_fname = get_output_fname( + task.application, task.out_path, task.projection, task.bounds + ) + return task, out_fname.exists() + + +def launch_jobs( + application: Application, + time_range: tuple[datetime, datetime], + out_path: str, + epsg_code: int | None = None, + wgs84_bounds: PixelBounds | None = None, + count: int | None = None, +) -> None: + """Launch Beaker jobs for Satlas prediction. + + Args: + application: which application to run. + time_range: the time range to run within. Must have timezone. + out_path: the output path. It should be specific to the time range. + epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default + run in all UTM zones. + wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. + count: only run up to this many tasks. + """ + # Generate tasks. + if epsg_code: + utm_zones = [CRS.from_epsg(epsg_code)] + else: + for epsg_code in range(32601, 32661): + utm_zones.append(CRS.from_epsg(epsg_code)) + for epsg_code in range(32701, 32761): + utm_zones.append(CRS.from_epsg(epsg_code)) + + tasks: list[Task] = [] + for utm_zone in utm_zones: + zone_bounds = rslearn.utils.get_utm_ups_crs.get_proj_bounds(utm_zone) + projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) + for col in range(zone_bounds[0], zone_bounds[2], TILE_SIZE): + for row in range(zone_bounds[1], zone_bounds[3], TILE_SIZE): + if wgs84_bounds is not None: + # Check if the longitude/latitude of this task is in wgs84_bounds. + src_geom = STGeometry(projection, shapely.Point(col, row), None) + wgs84_point = src_geom.to_projection(WGS84_PROJECTION).shp + if wgs84_point.x < wgs84_bounds[0]: + continue + if wgs84_point.x >= wgs84_bounds[2]: + continue + if wgs84_point.y < wgs84_bounds[1]: + continue + if wgs84_point.y >= wgs84_bounds[3]: + continue + + tasks.append( + Task( + application=application, + projection=projection, + bounds=(col, row, col + TILE_SIZE, row + TILE_SIZE), + time_range=time_range, + out_path=out_path, + ) + ) + + # See which tasks are not done yet. + p = multiprocessing.Pool(32) + outputs = p.imap_unordered(check_task_done, tasks) + + pending_tasks: list[Task] = [] + for task, is_done in tqdm.tqdm( + outputs, desc="Check which tasks are completed", total=len(tasks) + ): + if is_done: + continue + pending_tasks.append(task) + + p.close() + + # Run up to count of them. + if count is not None and len(pending_tasks) > count: + run_tasks = random.sample(pending_tasks, count) + else: + run_tasks = pending_tasks + + print( + f"Got {len(tasks)} total tasks, {len(pending_tasks)} pending, running {len(run_tasks)} of them" + ) + for task in tqdm.tqdm(run_tasks, desc="Starting Beaker jobs"): + launch_job(task) diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py new file mode 100644 index 00000000..bd9e5ada --- /dev/null +++ b/rslp/satlas/predict_pipeline.py @@ -0,0 +1,127 @@ +"""Prediction pipeline for Satlas models.""" + +import json +import shutil +from datetime import datetime +from enum import Enum + +from rslearn.dataset import Window +from rslearn.utils.geometry import PixelBounds, Projection +from upath import UPath + +from rslp.utils.rslearn import materialize_dataset, run_model_predict + +DATASET_CONFIG_FNAME = "convert_satlas_webmercator_to_rslearn/{application}/config.json" +MODEL_CONFIG_FNAME = "convert_satlas_webmercator_to_rslearn/{application}/config.yaml" + + +class Application(Enum): + """Specifies the various Satlas applications.""" + + SOLAR_FARM = "solar_farm" + WIND_TURBINE = "wind_turbine" + MARINE_INFRA = "marine_infra" + TREE_COVER = "tree_cover" + + +APP_IS_RASTER = { + Application.SOLAR_FARM: True, + Application.WIND_TURBINE: False, + Application.MARINE_INFRA: False, + Application.SOLAR_FARM: True, +} + + +def get_output_fname( + application: Application, out_path: str, projection: Projection, bounds: PixelBounds +) -> UPath: + """Get output filename to use for this application and task. + + Args: + application: the application. + out_path: the output path. + projection: the projection of this task. + bounds: the bounds of this task. + + Returns: + the output filename. + """ + if APP_IS_RASTER[application]: + out_fname = ( + UPath(out_path) / f"{str(projection.crs)}_{bounds[0]}_{bounds[1]}.tif" + ) + else: + out_fname = ( + UPath(out_path) / f"{str(projection.crs)}_{bounds[0]}_{bounds[1]}.geojson" + ) + return out_fname + + +def predict_pipeline( + application: Application, + projection_json: str, + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + out_path: str, + scratch_path: str, +) -> None: + """Compute outputs of a Satlas model on this tile. + + The tile is one part of a UTM zone. + + Args: + application: the application for which to compute outputs. + projection_json: JSON-encoded projection, normally a UTM zone with 10 m/pixel + resolution. + bounds: pixel coordinates within the projection on which to compute outputs. + time_range: time range to apply model on. + out_path: where to write the outputs. It will either be a GeoTIFF or GeoJSON, + named based on the bounds. + scratch_path: where to store the dataset. + """ + dataset_config_fname = DATASET_CONFIG_FNAME.format(application=application.value) + model_config_fname = MODEL_CONFIG_FNAME.format(application=application.value) + + # Check if the output was already computed. + projection = Projection.deserialize(json.loads(projection_json)) + out_fname = get_output_fname(application, out_path, projection, bounds) + if out_fname.exists(): + print(f"output file {out_fname} already exists") + return + + # Initialize an rslearn dataset. + ds_path = UPath(scratch_path) + ds_path.mkdir(parents=True, exist_ok=True) + with open(dataset_config_fname) as f: + ds_cfg = json.load(f) + with (ds_path / "config.json").open("w") as f: + json.dump(ds_cfg, f) + + # Create a window corresponding to the specified projection and bounds. + group = "predict" + window_path = ds_path / "windows" / group / "default" + window = Window( + path=window_path, + group=group, + name="default", + projection=projection, + bounds=bounds, + time_range=time_range, + ) + window.save() + + # Populate the window. + print("materialize dataset") + materialize_dataset(ds_path, group=group) + + # Run the model. + run_model_predict(model_config_fname, ds_path) + + if APP_IS_RASTER[application]: + src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" + else: + src_fname = window_path / "layers" / "output" / "data.geojson" + + with src_fname.open("rb") as src: + with out_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) diff --git a/rslp/satlas_marine_infra/train.py b/rslp/satlas/train.py similarity index 93% rename from rslp/satlas_marine_infra/train.py rename to rslp/satlas/train.py index 170298d5..e6c905e6 100644 --- a/rslp/satlas_marine_infra/train.py +++ b/rslp/satlas/train.py @@ -1,4 +1,4 @@ -"""Remaps categories for marine infrastructure training.""" +"""Satlas custom training code.""" from typing import Any @@ -14,7 +14,7 @@ class MarineInfraTask(DetectionTask): """Marine infrastructure detection task. - We just add a remapping pre-processing. + We just add a category remapping pre-processing. """ def process_inputs( diff --git a/rslp/satlas_marine_infra/__init__.py b/rslp/satlas_marine_infra/__init__.py deleted file mode 100644 index ab71e155..00000000 --- a/rslp/satlas_marine_infra/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Satlas marine infrastructure.""" From ac8eb3299c2669b3b34e69fb95079378ffbbd848 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Wed, 6 Nov 2024 14:49:58 -0800 Subject: [PATCH 16/58] marine infra updates --- data/satlas/marine_infra/config.json | 87 +++- data/satlas/marine_infra/config.yaml | 35 +- rslp/satlas/README.md | 7 + rslp/satlas/__init__.py | 7 +- rslp/satlas/job_launcher.py | 143 +++--- rslp/satlas/predict_pipeline.py | 285 +++++++++++- .../scripts/smooth_point_labels_viterbi.go | 419 ++++++++++++++++++ rslp/satlas/train.py | 7 +- 8 files changed, 891 insertions(+), 99 deletions(-) create mode 100644 rslp/satlas/README.md create mode 100644 rslp/satlas/scripts/smooth_point_labels_viterbi.go diff --git a/data/satlas/marine_infra/config.json b/data/satlas/marine_infra/config.json index fa9f2cf9..c72ff59a 100644 --- a/data/satlas/marine_infra/config.json +++ b/data/satlas/marine_infra/config.json @@ -19,9 +19,14 @@ "type": "raster" }, "output": { + "format": { + "coordinate_mode": "pixel", + "name": "geojson" + }, "type": "vector" }, - "sentinel2": { + "sentinel2_a": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -55,20 +60,67 @@ } ], "data_source": { + "duration": "30d", "harmonize": true, "index_cache_dir": "cache/sentinel2", - "max_time_delta": "1d", + "max_time_delta": "0d", "modality": "L1C", "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "query_config": { - "max_matches": 3 + "sort_by": "cloud_cover", + "time_offset": "60d", + "use_rtree_index": false + }, + "type": "raster" + }, + "sentinel2_b": { + "alias": "sentinel2", + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "duration": "30d", + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", "sort_by": "cloud_cover", + "time_offset": "30d", "use_rtree_index": false }, "type": "raster" }, - "sentinel2.1": { + "sentinel2_c": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -101,9 +153,21 @@ "zoom_offset": -2 } ], + "data_source": { + "duration": "30d", + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "sort_by": "cloud_cover", + "time_offset": "0d", + "use_rtree_index": false + }, "type": "raster" }, - "sentinel2.2": { + "sentinel2_d": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -136,6 +200,17 @@ "zoom_offset": -2 } ], + "data_source": { + "duration": "30d", + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "sort_by": "cloud_cover", + "time_offset": "-30d", + "use_rtree_index": false + }, "type": "raster" } }, diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml index 598f6752..3092de7e 100644 --- a/data/satlas/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -39,23 +39,29 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ inputs: image1: data_type: "raster" - layers: ["sentinel2"] + layers: ["sentinel2_a"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image2: data_type: "raster" - layers: ["sentinel2.1"] + layers: ["sentinel2_b"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image3: data_type: "raster" - layers: ["sentinel2.2"] + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_d"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 @@ -92,7 +98,7 @@ data: input_mapping: detect: targets: "targets" - batch_size: 8 + batch_size: 4 num_workers: 32 default_config: transforms: @@ -102,20 +108,21 @@ data: std: 3000 valid_range: [0, 1] bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 8160 valid_range: [0, 1] bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] output_selector: image - class_path: rslp.transforms.mask.Mask train_config: @@ -127,20 +134,21 @@ data: std: 3000 valid_range: [0, 1] bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 8160 valid_range: [0, 1] bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] output_selector: image - class_path: rslp.transforms.mask.Mask - class_path: rslearn.train.transforms.flip.Flip @@ -165,20 +173,21 @@ data: std: 3000 valid_range: [0, 1] bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 8160 valid_range: [0, 1] bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] + selectors: ["image1", "image2", "image3", "image4"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] output_selector: image groups: ["predict"] load_all_patches: true @@ -192,7 +201,7 @@ trainer: logging_interval: "epoch" - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ output_layer: output selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -206,4 +215,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_marine_infra -rslp_experiment: data_20241002_satlaspretrainold_patch512_satlasnorm_3image_01 +rslp_experiment: data_20241030_satlaspretrainold_patch512_00 diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md new file mode 100644 index 00000000..2c5c881f --- /dev/null +++ b/rslp/satlas/README.md @@ -0,0 +1,7 @@ +## Marine Infrastructure + +Inference: + + PYTHONPATH=~/rslearn:. python -m rslp.main satlas launch MARINE_INFRA '["2024-01-01T00:00:00+00:00", "2024-04-01T00:00:00+00:00"]' gs://rslearn-eai/projects/satlas/marine_infra/version-20241030/2024-01/ + +Post-processing: diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index 756cb533..ef31cdb0 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -7,8 +7,13 @@ - Tree cover """ -from .predict_pipeline import predict_pipeline +from .job_launcher import launch_jobs +from .postprocess import postprocess_points +from .predict_pipeline import predict_multi, predict_pipeline workflows = { "predict": predict_pipeline, + "predict_multi": predict_multi, + "launch": launch_jobs, + "postprocess_points": postprocess_points, } diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py index 9af927c6..3a61a4db 100644 --- a/rslp/satlas/job_launcher.py +++ b/rslp/satlas/job_launcher.py @@ -2,12 +2,10 @@ import json import multiprocessing -import os import random import uuid from datetime import datetime -import rslearn.utils.get_utm_ups_crs import shapely import tqdm from beaker import ( @@ -23,13 +21,13 @@ from rasterio.crs import CRS from rslearn.const import WGS84_PROJECTION from rslearn.utils.geometry import PixelBounds, Projection, STGeometry +from rslearn.utils.get_utm_ups_crs import get_proj_bounds -from .predict_pipeline import Application, get_output_fname +from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars -WORKSPACE = "ai2/earth-systems" -BUDGET = "ai2/d5" -IMAGE_NAME = "favyen/rslearn" -TILE_SIZE = 16384 +from .predict_pipeline import Application, PredictTaskArgs, get_output_fname + +TILE_SIZE = 32768 RESOLUTION = 10 @@ -60,36 +58,40 @@ def __init__( self.out_path = out_path -def launch_job(task: Task) -> None: +def launch_job(batch: list[Task]) -> None: """Launch job for this task. Args: - task: the Task object for which to create a job. + batch: list of Task objects for which to create a job. """ - beaker = Beaker.from_env(default_workspace=WORKSPACE) + beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) + + # Convert tasks to PredictTask. + # These just set projection/bounds/time range, so the application and output path + # come from the first task. + predict_tasks = [] + for task in batch: + predict_tasks.append( + PredictTaskArgs( + projection_json=task.projection.serialize(), + bounds=task.bounds, + time_range=task.time_range, + ) + ) with beaker.session(): - env_vars = [ - EnvVar( - name="GOOGLE_APPLICATION_CREDENTIALS", # nosec - value="/etc/credentials/gcp_credentials.json", # nosec - ), + env_vars = get_base_env_vars(use_weka_prefix=False) + env_vars.append( EnvVar( - name="GCLOUD_PROJECT", # nosec - value="skylight-proto-1", # nosec - ), - EnvVar( - name="RSLP_BUCKET", - value=os.environ["RSLP_BUCKET"], - ), - EnvVar( - name="MKL_THREADING_LAYER", - value="GNU", - ), - ] + name="RSLEARN_LOGLEVEL", + value="DEBUG", + ) + ) + # Name the job based on the first task. + task = batch[0] experiment_name = ( - f"satlas_{task.application.value}_{str(task.projection.crs)}_" + f"satlas_{task.application.value}_{task.projection.crs.to_epsg()}_" + f"{task.bounds[0]}_{task.bounds[1]}" ) @@ -101,15 +103,13 @@ def launch_job(task: Task) -> None: command=["python", "-m", "rslp.main"], arguments=[ "satlas", - "predict", - task.application, - json.dumps(task.projection.serialize()), - json.dumps(task.bounds), + "predict_multi", + task.application.value.upper(), + task.out_path, + f"/data/favyenb/marine_infra/scratch/{experiment_name}/", json.dumps( - [task.time_range[0].isoformat(), task.time_range[1].isoformat()] + [predict_task.serialize() for predict_task in predict_tasks] ), - task.out_path, - "/tmp/scratch/", ], constraints=Constraints( cluster=[ @@ -120,6 +120,7 @@ def launch_job(task: Task) -> None: "ai2/general-cirrascale", "ai2/prior-cirrascale", "ai2/prior-elanding", + "ai2/augusta-google-1", ] ), preemptible=True, @@ -128,9 +129,13 @@ def launch_job(task: Task) -> None: source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec mount_path="/etc/credentials/gcp_credentials.json", # nosec ), + DataMount( + source=DataSource(host_path="/data"), # nosec + mount_path="/data", # nosec + ), ], env_vars=env_vars, - resources=TaskResources(gpu_count=1), + resources=TaskResources(gpu_count=1, shared_memory="256GiB"), ) unique_id = str(uuid.uuid4())[0:8] beaker.experiment.create(experiment_name + "_" + unique_id, spec) @@ -158,8 +163,9 @@ def launch_jobs( time_range: tuple[datetime, datetime], out_path: str, epsg_code: int | None = None, - wgs84_bounds: PixelBounds | None = None, + wgs84_bounds: tuple[float, float, float, float] | None = None, count: int | None = None, + batch_size: int = 1, ) -> None: """Launch Beaker jobs for Satlas prediction. @@ -171,40 +177,68 @@ def launch_jobs( run in all UTM zones. wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. count: only run up to this many tasks. + batch_size: how many tasks to run in each Beaker job. """ # Generate tasks. if epsg_code: utm_zones = [CRS.from_epsg(epsg_code)] else: + utm_zones = [] for epsg_code in range(32601, 32661): utm_zones.append(CRS.from_epsg(epsg_code)) for epsg_code in range(32701, 32761): utm_zones.append(CRS.from_epsg(epsg_code)) tasks: list[Task] = [] - for utm_zone in utm_zones: - zone_bounds = rslearn.utils.get_utm_ups_crs.get_proj_bounds(utm_zone) + for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"): + # get_proj_bounds returns bounds in CRS units so we need to convert to pixel + # units. + crs_bbox = STGeometry( + Projection(utm_zone, 1, 1), + shapely.box(*get_proj_bounds(utm_zone)), + None, + ) projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) - for col in range(zone_bounds[0], zone_bounds[2], TILE_SIZE): - for row in range(zone_bounds[1], zone_bounds[3], TILE_SIZE): - if wgs84_bounds is not None: - # Check if the longitude/latitude of this task is in wgs84_bounds. - src_geom = STGeometry(projection, shapely.Point(col, row), None) - wgs84_point = src_geom.to_projection(WGS84_PROJECTION).shp - if wgs84_point.x < wgs84_bounds[0]: + pixel_bbox = crs_bbox.to_projection(projection) + zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) + + user_bounds_in_proj: PixelBounds | None = None + if wgs84_bounds is not None: + dst_geom = STGeometry( + WGS84_PROJECTION, shapely.box(*wgs84_bounds), None + ).to_projection(projection) + user_bounds_in_proj = ( + int(dst_geom.shp.bounds[0]), + int(dst_geom.shp.bounds[1]), + int(dst_geom.shp.bounds[2]), + int(dst_geom.shp.bounds[3]), + ) + + for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): + for row in range( + zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1 + ): + if user_bounds_in_proj is not None: + # Check if this task intersects the bounds specified by the user. + if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: continue - if wgs84_point.x >= wgs84_bounds[2]: + if col * TILE_SIZE >= user_bounds_in_proj[2]: continue - if wgs84_point.y < wgs84_bounds[1]: + if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: continue - if wgs84_point.y >= wgs84_bounds[3]: + if row * TILE_SIZE >= user_bounds_in_proj[3]: continue tasks.append( Task( application=application, projection=projection, - bounds=(col, row, col + TILE_SIZE, row + TILE_SIZE), + bounds=( + col * TILE_SIZE, + row * TILE_SIZE, + (col + 1) * TILE_SIZE, + (row + 1) * TILE_SIZE, + ), time_range=time_range, out_path=out_path, ) @@ -233,5 +267,10 @@ def launch_jobs( print( f"Got {len(tasks)} total tasks, {len(pending_tasks)} pending, running {len(run_tasks)} of them" ) - for task in tqdm.tqdm(run_tasks, desc="Starting Beaker jobs"): - launch_job(task) + + batches = [] + for i in range(0, len(run_tasks), batch_size): + batches.append(run_tasks[i : i + batch_size]) + + for batch in tqdm.tqdm(batches, desc="Starting Beaker jobs"): + launch_job(batch) diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index bd9e5ada..21c9040c 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -1,18 +1,29 @@ """Prediction pipeline for Satlas models.""" import json +import os import shutil from datetime import datetime from enum import Enum +from typing import Any +from rslearn.const import WGS84_PROJECTION from rslearn.dataset import Window from rslearn.utils.geometry import PixelBounds, Projection from upath import UPath +from rslp.log_utils import get_logger from rslp.utils.rslearn import materialize_dataset, run_model_predict -DATASET_CONFIG_FNAME = "convert_satlas_webmercator_to_rslearn/{application}/config.json" -MODEL_CONFIG_FNAME = "convert_satlas_webmercator_to_rslearn/{application}/config.yaml" +DATASET_CONFIG_FNAME = "data/satlas/{application}/config.json" +MODEL_CONFIG_FNAME = "data/satlas/{application}/config.yaml" +SENTINEL2_LAYER = "sentinel2" +PATCH_SIZE = 2048 + +# Layers not to use when seeing which patches are valid. +VALIDITY_EXCLUDE_LAYERS = ["mask", "output", "label"] + +logger = get_logger(__name__) class Application(Enum): @@ -75,8 +86,8 @@ def predict_pipeline( resolution. bounds: pixel coordinates within the projection on which to compute outputs. time_range: time range to apply model on. - out_path: where to write the outputs. It will either be a GeoTIFF or GeoJSON, - named based on the bounds. + out_path: directory to write the outputs. It will either be a GeoTIFF or + GeoJSON, named based on the bounds. scratch_path: where to store the dataset. """ dataset_config_fname = DATASET_CONFIG_FNAME.format(application=application.value) @@ -91,26 +102,95 @@ def predict_pipeline( # Initialize an rslearn dataset. ds_path = UPath(scratch_path) - ds_path.mkdir(parents=True, exist_ok=True) + ds_path.mkdir(parents=True) with open(dataset_config_fname) as f: ds_cfg = json.load(f) with (ds_path / "config.json").open("w") as f: json.dump(ds_cfg, f) - # Create a window corresponding to the specified projection and bounds. + # Create windows corresponding to the specified projection and bounds. + # Each window is PATCH_SIZE x PATCH_SIZE, we create multiple of smaller patch size + # than the bounds instead of one big window for better parallelism and memory + # usage. + # It also helps with creating the mosaic -- depending on the dataset configuration, + # if there are portions of large window that are not fully covered by scenes, then + # only one mosaic layer would be created. (TODO: actually that seems like an issue + # with the match_candidate_items_to_window logic. Maybe we should build all the + # mosaics simultaneously instead of discarding scenes that don't match with the + # current mosaic.) + # Note that bounds must be multiple of patch size. + for value in bounds: + assert value % PATCH_SIZE == 0 group = "predict" - window_path = ds_path / "windows" / group / "default" - window = Window( - path=window_path, - group=group, - name="default", - projection=projection, - bounds=bounds, - time_range=time_range, - ) - window.save() - - # Populate the window. + tile_to_window = {} + for tile_col in range(bounds[0] // PATCH_SIZE, bounds[2] // PATCH_SIZE): + for tile_row in range(bounds[1] // PATCH_SIZE, bounds[3] // PATCH_SIZE): + window_name = f"{tile_col}_{tile_row}" + window_bounds = ( + tile_col * PATCH_SIZE, + tile_row * PATCH_SIZE, + (tile_col + 1) * PATCH_SIZE, + (tile_row + 1) * PATCH_SIZE, + ) + window_path = ds_path / "windows" / group / window_name + window = Window( + path=window_path, + group=group, + name=window_name, + projection=projection, + bounds=window_bounds, + time_range=time_range, + ) + + # Skip if the window is too close to 0 longitude. + # Or if it crosses it. + epsilon = 1e-4 + wgs84_geom = window.get_geometry().to_projection(WGS84_PROJECTION) + wgs84_bounds = wgs84_geom.shp.bounds + if wgs84_bounds[0] <= -180 + epsilon or wgs84_bounds[2] >= 180 - epsilon: + logger.debug( + "skipping window at column %d row %d because it is out of bounds (wgs84_bounds=%s)", + tile_col, + tile_row, + wgs84_bounds, + ) + continue + if wgs84_bounds[0] < -90 and wgs84_bounds[2] > 90: + logger.debug( + "skipping window at column %d row %d because it seems to cross 0 longitude (wgs84_bounds=%s)", + tile_col, + tile_row, + wgs84_bounds, + ) + continue + + window.save() + tile_to_window[(tile_col, tile_row)] = window + + # Create the cache that will be needed to run "prepare" step in parallel. + """dataset = Dataset(ds_path) + needed_cell_years = set() + wgs84_geometry = window.get_geometry().to_projection(WGS84_PROJECTION) + for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds): + for year in range( + wgs84_geometry.time_range[0].year, + wgs84_geometry.time_range[1].year + 1, + ): + needed_cell_years.add((cell_id, year)) + cache_jobs = [] + for cell_id, year in needed_cell_years: + cache_jobs.append(dict( + cell=cell_id, + year=year, + dataset=dataset, + )) + p = multiprocessing.Pool(32) + outputs = star_imap_unordered(p, cache_cell, cache_jobs) + for _ in tqdm.tqdm(outputs, total=len(cache_jobs), desc="Caching Sentinel-2 metadata"): + pass + p.close()""" + + # Populate the windows. print("materialize dataset") materialize_dataset(ds_path, group=group) @@ -118,10 +198,169 @@ def predict_pipeline( run_model_predict(model_config_fname, ds_path) if APP_IS_RASTER[application]: - src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" + raise NotImplementedError + """src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" + + with src_fname.open("rb") as src: + with out_fname.open("wb") as dst: + shutil.copyfileobj(src, dst)""" + else: - src_fname = window_path / "layers" / "output" / "data.geojson" + # Merge the features across the windows. + # Here we also add valid patches attribute indicating which windows (patches) + # were non-zero. This is used to distinguish a point not being detected because + # it wasn't there vs not being detected just because there was no image + # available there. + fc = None + valid_patches = [] + for window in tile_to_window.values(): + window_output_fname = window.path / "layers" / "output" / "data.geojson" + + if not window_output_fname.exists(): + continue + + with window_output_fname.open() as f: + cur_fc = json.load(f) + + if fc is None: + fc = cur_fc + else: + fc["features"].extend(cur_fc["features"]) + + valid_patches.append( + (window.bounds[0] // PATCH_SIZE, window.bounds[1] // PATCH_SIZE) + ) + + if fc is None: + # So there was no image here. + # We still want to write an empty GeoJSON so the job is marked completed. + fc = { + "type": "FeatureCollection", + "features": [], + } + + """ + # Add a list specifying which patches are valid vs invalid to the GeoJSON. + # Valid means that none of the input layers are completely zero at the patch. + # This is so that when we smooth the predictions over time, we can distinguish + # a point not being detected because it wasn't there vs not being detected just + # because there was no image available there. + check_images = window_path.glob("layers/*/B02_B03_B04_B08/geotiff.tif") + valid_patches = set() + for check_image in check_images: + path_parts = check_image.path.split("/") + if path_parts[-3] in VALIDITY_EXCLUDE_LAYERS: + continue + + with check_image.open("rb") as f: + with rasterio.open(f) as raster: + valid_mask = raster.read().max(axis=0) > 0 + + for tile_col in range(bounds[0] // PATCH_SIZE, bounds[2] // PATCH_SIZE): + for tile_row in range(bounds[1] // PATCH_SIZE, bounds[3] // PATCH_SIZE): + cur_patch_id = (tile_col, tile_row) + cur_offset = (tile_col * PATCH_SIZE, tile_row * PATCH_SIZE) + + if cur_patch_id in valid_patches: + continue + + # Read from the window that contains this patch. + window = tile_to_window[cur_patch_id] + + + patch_valid = np.zeros((VALIDITY_PATCH_SIZE, VALIDITY_PATCH_SIZE)) + copy_spatial_array(valid_mask, patch_valid, bounds[0:2], cur_offset) + if valid_mask.max() is False: + continue + + valid_patches.add(cur_patch_id) + """ - with src_fname.open("rb") as src: - with out_fname.open("wb") as dst: - shutil.copyfileobj(src, dst) + if "properties" not in fc: + fc["properties"] = {} + fc["properties"]["valid_patches"] = { + str(projection.crs): list(valid_patches), + } + + # The object detector predicts bounding boxes but we want to make all features + # just points. + for feat in fc["features"]: + assert feat["geometry"]["type"] == "Polygon" + coords = feat["geometry"]["coordinates"][0] + xs = [coord[0] for coord in coords] + ys = [coord[1] for coord in coords] + feat["geometry"] = { + "type": "Point", + "coordinates": [ + (min(xs) + max(xs)) / 2, + (min(ys) + max(ys)) / 2, + ], + } + + with out_fname.open("w") as f: + json.dump(fc, f) + + +class PredictTaskArgs: + """Represents one prediction task among a set that shares application and paths.""" + + def __init__( + self, + projection_json: dict[str, Any], + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + ): + """Create a new PredictTaskArgs. + + Args: + projection_json: serialized projection. + bounds: the bounds of this task. + time_range: the time range of this task. + """ + self.projection_json = projection_json + self.bounds = bounds + self.time_range = time_range + + def serialize(self) -> dict[str, Any]: + """Serialize the task to a dictionary. + + Returns: + JSON-encodable dictionary. + """ + return dict( + projection_json=self.projection_json, + bounds=json.dumps(self.bounds), + time_range=json.dumps( + (self.time_range[0].isoformat(), self.time_range[1].isoformat()) + ), + ) + + +def predict_multi( + application: Application, + out_path: str, + scratch_path: str, + tasks: list[PredictTaskArgs], +) -> None: + """Run multiple prediction tasks. + + Args: + application: the application. + out_path: directory to write outputs. + scratch_path: local directory to use for scratch space. + tasks: list of tasks to execute. + """ + if os.path.exists(scratch_path): + shutil.rmtree(scratch_path) + + for task in tasks: + predict_pipeline( + application=application, + projection_json=json.dumps(task.projection_json), + bounds=task.bounds, + time_range=task.time_range, + out_path=out_path, + scratch_path=scratch_path, + ) + if os.path.exists(scratch_path): + shutil.rmtree(scratch_path) diff --git a/rslp/satlas/scripts/smooth_point_labels_viterbi.go b/rslp/satlas/scripts/smooth_point_labels_viterbi.go new file mode 100644 index 00000000..696365a0 --- /dev/null +++ b/rslp/satlas/scripts/smooth_point_labels_viterbi.go @@ -0,0 +1,419 @@ +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "math" + "os" + "strconv" + "strings" + + "github.com/mitroadmaps/gomapinfer/common" +) + +const FUTURE_LABEL = "2030-01" + +type Tile struct { + Projection string + Column int + Row int +} + +type Point struct { + Geometry struct { + Type string `json:"type"` + Coordinates [2]float64 `json:"coordinates"` + } + label string + Properties struct { + Category string `json:"category"` + Score float64 `json:"score"` + Projection string `json:"projection,omitempty"` + Column int `json:"column,omitempty"` + Row int `json:"row,omitempty"` + Start string `json:"start,omitempty"` + End string `json:"end,omitempty"` + } `json:"properties"` +} + +type PointData struct { + Type string `json:"type"` + Features []Point `json:"features"` + Properties struct { + ValidPatches map[string][][2]int `json:"valid_patches"` + } `json:"properties"` +} + +type Group []Point + +func (g Group) Center() [2]int { + var sum [2]int + for _, p := range g { + sum[0] += p.Properties.Column + sum[1] += p.Properties.Row + } + return [2]int{ + sum[0] / len(g), + sum[1] / len(g), + } +} + +func decrementLabel(label string) string { + parts := strings.Split(label, "-") + year, _ := strconv.Atoi(parts[0]) + month, _ := strconv.Atoi(parts[1]) + month -= 1 + if month == 0 { + year -= 1 + month = 12 + } + return fmt.Sprintf("%04d-%02d", year, month) +} + +// Returns the end label for idx. +// If idx >= len(labels), we return a label far in the future. +// Otherwise, we just return labels[idx]. +func getEndLabel(labels []string, idx int) string { + if idx >= len(labels) { + return FUTURE_LABEL + } + return labels[idx] +} + +// Grid size which must be larger than the maximum expected distance threshold. +// This is currently in zoom 13 pixel coordinates so each unit is about 10 m. +const GridSize float64 = 256 + +// Factor to divide by to convert meters to point units. +const MetersPerPixel = 10 + +func main() { + labels := flag.String("labels", "", "Comma-separated list of labels") + pointFname := flag.String("fname", "", "Point filename with LABEL placeholder like in/LABEL.geojson") + outFname := flag.String("out", "", "Output filename with LABEL placeholder like out/LABEL.geojson") + histFname := flag.String("hist", "", "Merged history output filename") + distanceThreshold := flag.Float64("max_dist", 200, "Matching distance threshold in meters") + numThreads := flag.Int("threads", 32, "Number of threads") + flag.Parse() + + labelList := strings.Split(*labels, ",") + + // Read points beginning with the most recent set + // (which is likely the one that covers the most points). + var groups []Group + // Keep track of map from tiles to labels in which the tile is valid. + tileLabelValidity := make(map[Tile][]string) + for labelIdx := len(labelList) - 1; labelIdx >= 0; labelIdx-- { + label := labelList[labelIdx] + + fname := strings.ReplaceAll(*pointFname, "LABEL", label) + if _, err := os.Stat(fname); os.IsNotExist(err) { + continue + } + bytes, err := ioutil.ReadFile(fname) + if err != nil { + panic(err) + } + + var data PointData + if err := json.Unmarshal(bytes, &data); err != nil { + panic(err) + } + + // Build grid index from the current features. + curPoints := data.Features + for idx := range curPoints { + curPoints[idx].label = label + } + gridIndexes := make(map[string]*common.GridIndex) + for idx, point := range curPoints { + projection := point.Properties.Projection + col := float64(point.Properties.Column) + row := float64(point.Properties.Row) + if gridIndexes[projection] == nil { + gridIndexes[projection] = common.NewGridIndex(GridSize) + } + gridIndexes[projection].Insert(idx, common.Rectangle{ + Min: common.Point{col, row}, + Max: common.Point{col, row}, + }) + } + + log.Printf("matching %d groups with %d features at %v", len(groups), len(curPoints), label) + + // Match existing groups to the new points. + matchedIndices := make(map[int]bool) + for groupIdx, group := range groups { + projection := group[0].Properties.Projection + center := group.Center() + indices := gridIndexes[projection].Search(common.Rectangle{ + Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, + Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, + }) + var closestIdx int = -1 + var closestDistance float64 + for _, idx := range indices { + if matchedIndices[idx] { + continue + } + if group[0].Properties.Category != curPoints[idx].Properties.Category { + continue + } + + dx := center[0] - curPoints[idx].Properties.Column + dy := center[1] - curPoints[idx].Properties.Row + distance := math.Sqrt(float64(dx*dx + dy*dy)) + + if distance > *distanceThreshold/MetersPerPixel { + continue + } + if closestIdx == -1 || distance < closestDistance { + closestIdx = idx + closestDistance = distance + } + } + + if closestIdx == -1 { + continue + } + + matchedIndices[closestIdx] = true + groups[groupIdx] = append(groups[groupIdx], curPoints[closestIdx]) + } + + // Add unmatched points in the current time as new groups. + for idx, point := range curPoints { + if matchedIndices[idx] { + continue + } + groups = append(groups, Group{point}) + } + + // Also update valid tiles/labels. + for projection, patches := range data.Properties.ValidPatches { + for _, patch := range patches { + tile := Tile{ + Projection: projection, + Column: patch[0], + Row: patch[1], + } + tileLabelValidity[tile] = append(tileLabelValidity[tile], label) + } + } + } + + // Apply Viterbi algorithm in each group. + initialProbs := []float64{0.5, 0.5} + transitionProbs := [][]float64{ + {0.95, 0.05}, + {0.01, 0.99}, + } + emissionProbs := [][]float64{ + {0.8, 0.2}, + {0.2, 0.8}, + } + // Convert as observation history to a list of ranges when the state is non-zero. + applyViterbi := func(history []int) [][2]int { + probs := make([]float64, len(initialProbs)) + copy(probs, initialProbs) + var pointers [][]int + + // Forward pass. + for _, observation := range history { + newProbs := make([]float64, len(probs)) + curPointers := make([]int, len(probs)) + // For each new state, take max over probability resulting from different prev states. + for newState := range probs { + for prevState, prevProb := range probs { + prob := prevProb * transitionProbs[prevState][newState] * emissionProbs[newState][observation] + if prob > newProbs[newState] { + newProbs[newState] = prob + curPointers[newState] = prevState + } + } + } + probs = newProbs + pointers = append(pointers, curPointers) + } + + // Backward pass: compute max and then follow the pointers. + var finalState int + var bestProb float64 + for state, prob := range probs { + if prob < bestProb { + continue + } + bestProb = prob + finalState = state + } + reversedStates := []int{finalState} + curState := finalState + for i := len(pointers) - 1; i > 0; i-- { + curState = pointers[i][curState] + reversedStates = append(reversedStates, curState) + } + states := make([]int, len(reversedStates)) + for i := range states { + states[i] = reversedStates[len(states)-i-1] + } + + // Convert to ranges. + var ranges [][2]int + var startIdx int = -1 + for idx, state := range states { + if state == 0 && startIdx >= 0 { + // Object was active but no longer. + ranges = append(ranges, [2]int{startIdx, idx}) + startIdx = -1 + } + if state == 1 && startIdx == -1 { + startIdx = idx + } + } + // Add last range if any. + if startIdx != -1 { + ranges = append(ranges, [2]int{startIdx, len(states)}) + } + return ranges + } + + // Pass each group through Viterbi algorithm. + // This yields time ranges where a group was present in the world. + // Usually there should just be one time range associated with each group, + // but there could be multiple if there really was a gap. + // Anyway we then collect those ranges into output data. + var historyData PointData + outFeatures := make(map[string]*PointData) + log.Println("processing groups") + ch := make(chan Group) + type Rng struct { + Group Group + StartIdx int + EndIdx int + } + donech := make(chan []Rng) + for i := 0; i < *numThreads; i++ { + go func() { + var myRngs []Rng + for group := range ch { + // Create set of labels where the point is present. + labelSet := make(map[string]bool) + for _, point := range group { + labelSet[point.label] = true + } + + // Also create label set where the tile containing the point was valid. + // If tile is invalid at a label, it implies there was no satellite image data at that location/time. + validLabelSet := make(map[string]bool) + center := group.Center() + tile := Tile{ + Projection: group[0].Properties.Projection, + Column: center[0] / 512, + Row: center[1] / 512, + } + for _, label := range tileLabelValidity[tile] { + validLabelSet[label] = true + } + + // Now make history of observations for Viterbi algorithm. + // We only include timesteps where the tile was valid. + // We also create a map from observed timesteps to original timestep index. + var observations []int + var labelIdxMap []int + for labelIdx, label := range labelList { + if !validLabelSet[label] { + continue + } + labelIdxMap = append(labelIdxMap, labelIdx) + if labelSet[label] { + observations = append(observations, 1) + } else { + observations = append(observations, 0) + } + } + + ranges := applyViterbi(observations) + for _, rng := range ranges { + startIdx := labelIdxMap[rng[0]] + var endIdx int + if rng[1] == len(observations) { + endIdx = len(labelList) + } else { + endIdx = labelIdxMap[rng[1]] + } + myRngs = append(myRngs, Rng{ + Group: group, + StartIdx: startIdx, + EndIdx: endIdx, + }) + } + } + donech <- myRngs + }() + } + for _, group := range groups { + ch <- group + } + close(ch) + for i := 0; i < *numThreads; i++ { + curRngs := <-donech + for _, rng := range curRngs { + last := rng.Group[len(rng.Group)-1] + feat := Point{} + feat.Geometry = last.Geometry + feat.Properties.Category = last.Properties.Category + feat.Properties.Score = last.Properties.Score + + // Add the feature to the monthly outputs. + for labelIdx := rng.StartIdx; labelIdx < rng.EndIdx; labelIdx++ { + label := labelList[labelIdx] + if outFeatures[label] == nil { + outFeatures[label] = &PointData{} + } + outFeatures[label].Features = append(outFeatures[label].Features, feat) + } + + // Now set start and end label for this feature correctly. + // Along with the score (computed as average of the points in the group). + // And then add it to history. + feat.Properties.Start = labelList[rng.StartIdx] + feat.Properties.End = getEndLabel(labelList, rng.EndIdx) + + var scoreSum float64 = 0 + for _, p := range rng.Group { + scoreSum += p.Properties.Score + } + feat.Properties.Score = scoreSum / float64(len(rng.Group)) + + historyData.Features = append(historyData.Features, feat) + } + } + + log.Println("writing outputs") + + if *histFname != "" { + bytes, err := json.Marshal(historyData) + if err != nil { + panic(err) + } + if err := ioutil.WriteFile(*histFname, bytes, 0644); err != nil { + panic(err) + } + } + + if *outFname != "" { + for label, data := range outFeatures { + fname := strings.ReplaceAll(*outFname, "LABEL", label) + bytes, err := json.Marshal(data) + if err != nil { + panic(err) + } + if err := ioutil.WriteFile(fname, bytes, 0644); err != nil { + panic(err) + } + } + } +} diff --git a/rslp/satlas/train.py b/rslp/satlas/train.py index e6c905e6..21c7212a 100644 --- a/rslp/satlas/train.py +++ b/rslp/satlas/train.py @@ -38,12 +38,11 @@ def process_inputs( return {}, {} for feat in raw_inputs["targets"]: - if self.property_name not in feat["properties"]: + if self.property_name not in feat.properties: continue - properties = feat["properties"] - category = properties[self.property_name] + category = feat.properties[self.property_name] if category not in CATEGORY_MAPPING: continue - properties[self.property_name] = CATEGORY_MAPPING[category] + feat.properties[self.property_name] = CATEGORY_MAPPING[category] return super().process_inputs(raw_inputs, metadata, load_targets) From 4fb0a690bfa0950c05930a9f6de5ea208b5bd7a6 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 11 Nov 2024 14:44:05 -0800 Subject: [PATCH 17/58] gcp rtree index not working after august 2024 ... --- data/satlas/marine_infra/config.yaml | 1 + rslp/satlas/job_launcher.py | 6 +-- rslp/satlas/predict_pipeline.py | 73 +++++++++++++++++----------- rslp/utils/rslearn.py | 31 ++++++++++-- 4 files changed, 73 insertions(+), 38 deletions(-) diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml index 3092de7e..f885dbdc 100644 --- a/data/satlas/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -90,6 +90,7 @@ data: exclude_by_center: true enable_map_metric: true enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] skip_unknown_categories: true f1_metric_kwargs: cmp_mode: "distance" diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py index 3a61a4db..a0ab0af7 100644 --- a/rslp/satlas/job_launcher.py +++ b/rslp/satlas/job_launcher.py @@ -106,7 +106,7 @@ def launch_job(batch: list[Task]) -> None: "predict_multi", task.application.value.upper(), task.out_path, - f"/data/favyenb/marine_infra/scratch/{experiment_name}/", + "/tmp/scratch/", json.dumps( [predict_task.serialize() for predict_task in predict_tasks] ), @@ -129,10 +129,6 @@ def launch_job(batch: list[Task]) -> None: source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec mount_path="/etc/credentials/gcp_credentials.json", # nosec ), - DataMount( - source=DataSource(host_path="/data"), # nosec - mount_path="/data", # nosec - ), ], env_vars=env_vars, resources=TaskResources(gpu_count=1, shared_memory="256GiB"), diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index 21c9040c..db53e80b 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -8,6 +8,7 @@ from typing import Any from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources.copernicus import load_sentinel2_tile_index from rslearn.dataset import Window from rslearn.utils.geometry import PixelBounds, Projection from upath import UPath @@ -20,6 +21,10 @@ SENTINEL2_LAYER = "sentinel2" PATCH_SIZE = 2048 +# Add padding to the time range specified by the user for prediction since some +# applications use images from up to this many days outside of that time range. +RTREE_TIME_PAD_DAYS = 30 + # Layers not to use when seeing which patches are valid. VALIDITY_EXCLUDE_LAYERS = ["mask", "output", "label"] @@ -97,7 +102,7 @@ def predict_pipeline( projection = Projection.deserialize(json.loads(projection_json)) out_fname = get_output_fname(application, out_path, projection, bounds) if out_fname.exists(): - print(f"output file {out_fname} already exists") + logger.info(f"output file {out_fname} already exists") return # Initialize an rslearn dataset. @@ -105,6 +110,27 @@ def predict_pipeline( ds_path.mkdir(parents=True) with open(dataset_config_fname) as f: ds_cfg = json.load(f) + + # Set the time range to use for the rtree. + # And also make sure the rtree will be cached based on the out_path. + index_cache_dir = ds_path / "index_cache_dir" + index_cache_dir.mkdir() + image_layer_names = [] + for layer_name, layer_cfg in ds_cfg["layers"].items(): + if "data_source" not in layer_cfg: + continue + layer_source_cfg = layer_cfg["data_source"] + if not layer_source_cfg["name"].endswith("gcp_public_data.Sentinel2"): + continue + layer_source_cfg["index_cache_dir"] = str(index_cache_dir) + # layer_source_cfg["rtree_cache_dir"] = str(UPath(out_path) / "index") + # layer_source_cfg["use_rtree_index"] = True + # layer_source_cfg["rtree_time_range"] = [ + # (time_range[0] - timedelta(days=RTREE_TIME_PAD_DAYS)).isoformat(), + # (time_range[1] + timedelta(days=RTREE_TIME_PAD_DAYS)).isoformat(), + # ] + image_layer_names.append(layer_name) + with (ds_path / "config.json").open("w") as f: json.dump(ds_cfg, f) @@ -167,43 +193,32 @@ def predict_pipeline( window.save() tile_to_window[(tile_col, tile_row)] = window - # Create the cache that will be needed to run "prepare" step in parallel. - """dataset = Dataset(ds_path) - needed_cell_years = set() - wgs84_geometry = window.get_geometry().to_projection(WGS84_PROJECTION) - for cell_id in rslearn.utils.mgrs.for_each_cell(wgs84_geometry.shp.bounds): - for year in range( - wgs84_geometry.time_range[0].year, - wgs84_geometry.time_range[1].year + 1, - ): - needed_cell_years.add((cell_id, year)) - cache_jobs = [] - for cell_id, year in needed_cell_years: - cache_jobs.append(dict( - cell=cell_id, - year=year, - dataset=dataset, - )) - p = multiprocessing.Pool(32) - outputs = star_imap_unordered(p, cache_cell, cache_jobs) - for _ in tqdm.tqdm(outputs, total=len(cache_jobs), desc="Caching Sentinel-2 metadata"): - pass - p.close()""" + # Before preparing, cache the Sentinel-2 tile index. + # This way it is only downloaded once here instead of many times during prepare. + # We could set use_initial_prepare_job=True in materialize_dataset call, but then + # it could take a minute or more longer than needed. + load_sentinel2_tile_index(index_cache_dir) # Populate the windows. - print("materialize dataset") - materialize_dataset(ds_path, group=group) - - # Run the model. - run_model_predict(model_config_fname, ds_path) + logger.info("materialize dataset") + materialize_dataset(ds_path, group=group, prepare_workers=128) + + # Run the model, only if at least one window has some data. + completed_fnames = ds_path.glob( + f"windows/{group}/*/layers/{image_layer_names[0]}/completed" + ) + if len(list(completed_fnames)) == 0: + logger.info("skipping prediction since no windows seem to have data") + else: + run_model_predict(model_config_fname, ds_path) if APP_IS_RASTER[application]: - raise NotImplementedError """src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" with src_fname.open("rb") as src: with out_fname.open("wb") as dst: shutil.copyfileobj(src, dst)""" + raise NotImplementedError else: # Merge the features across the windows. diff --git a/rslp/utils/rslearn.py b/rslp/utils/rslearn.py index 8e9168db..906ece09 100644 --- a/rslp/utils/rslearn.py +++ b/rslp/utils/rslearn.py @@ -15,7 +15,13 @@ def materialize_dataset( - ds_path: UPath, group: str | None = None, workers: int = 32 + ds_path: UPath, + group: str | None = None, + workers: int = 32, + initial_prepare_job: bool = False, + prepare_workers: int | None = None, + ingest_workers: int | None = None, + materialize_workers: int | None = None, ) -> None: """Materialize the specified dataset by running prepare/ingest/materialize. @@ -23,25 +29,42 @@ def materialize_dataset( ds_path: the dataset root. group: limit dataset actions to this group. workers: number of workers to use. + initial_prepare_job: set True if initial job during prepare is needed, e.g. if + the data source creates an index first. + prepare_workers: use this many workers for prepare stage (overrides workers + argument) + ingest_workers: use this many workers for ingest stage (overrides workers + argument) + materialize_workers: use this many workers for materialize stage (overrides + workers argument) """ dataset = Dataset(ds_path) + + if prepare_workers is None: + prepare_workers = workers + if ingest_workers is None: + ingest_workers = workers + if materialize_workers is None: + materialize_workers = workers + apply_on_windows( PrepareHandler(force=False), dataset, - workers=workers, + workers=prepare_workers, group=group, + use_initial_job=initial_prepare_job, ) apply_on_windows( IngestHandler(), dataset, - workers=workers, + workers=ingest_workers, group=group, use_initial_job=False, ) apply_on_windows( MaterializeHandler(), dataset, - workers=workers, + workers=materialize_workers, group=group, use_initial_job=False, ) From 7fccfb35e6aa7f30e606be49d2d33db08c0705f0 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 5 Dec 2024 09:50:06 -0800 Subject: [PATCH 18/58] latest changes --- data/satlas/marine_infra/config.json | 184 ++-------- data/satlas/marine_infra/config.yaml | 8 +- rslp/common/__init__.py | 7 + rslp/common/worker.py | 121 +++++++ rslp/launch_beaker.py | 4 +- rslp/main.py | 25 +- rslp/satlas/__init__.py | 6 +- rslp/satlas/data_sources.py | 156 ++++++++ rslp/satlas/job_launcher.py | 48 ++- rslp/satlas/job_launcher_worker.py | 339 ++++++++++++++++++ rslp/satlas/postprocess.py | 236 ++++++++++++ rslp/satlas/predict_pipeline.py | 24 +- rslp/satlas/scripts/go.mod | 12 + rslp/satlas/scripts/go.sum | 43 +++ .../scripts/smooth_point_labels_viterbi.go | 59 +-- 15 files changed, 1062 insertions(+), 210 deletions(-) create mode 100644 rslp/common/__init__.py create mode 100644 rslp/common/worker.py create mode 100644 rslp/satlas/data_sources.py create mode 100644 rslp/satlas/job_launcher_worker.py create mode 100644 rslp/satlas/postprocess.py create mode 100644 rslp/satlas/scripts/go.mod create mode 100644 rslp/satlas/scripts/go.sum diff --git a/data/satlas/marine_infra/config.json b/data/satlas/marine_infra/config.json index c72ff59a..bbb765a5 100644 --- a/data/satlas/marine_infra/config.json +++ b/data/satlas/marine_infra/config.json @@ -25,8 +25,7 @@ }, "type": "vector" }, - "sentinel2_a": { - "alias": "sentinel2", + "sentinel2": { "band_sets": [ { "bands": [ @@ -35,101 +34,15 @@ "B04", "B08" ], - "dtype": "uint16" - }, - { - "bands": [ - "B05", - "B06", - "B07", - "B8A", - "B11", - "B12" - ], - "dtype": "uint16", - "zoom_offset": -1 - }, - { - "bands": [ - "B01", - "B09", - "B10" - ], "dtype": "uint16", - "zoom_offset": -2 - } - ], - "data_source": { - "duration": "30d", - "harmonize": true, - "index_cache_dir": "cache/sentinel2", - "max_time_delta": "0d", - "modality": "L1C", - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "sort_by": "cloud_cover", - "time_offset": "60d", - "use_rtree_index": false - }, - "type": "raster" - }, - "sentinel2_b": { - "alias": "sentinel2", - "band_sets": [ - { - "bands": [ - "B02", - "B03", - "B04", - "B08" - ], - "dtype": "uint16" - }, - { - "bands": [ - "B05", - "B06", - "B07", - "B8A", - "B11", - "B12" - ], - "dtype": "uint16", - "zoom_offset": -1 - }, - { - "bands": [ - "B01", - "B09", - "B10" - ], - "dtype": "uint16", - "zoom_offset": -2 - } - ], - "data_source": { - "duration": "30d", - "harmonize": true, - "index_cache_dir": "cache/sentinel2", - "max_time_delta": "0d", - "modality": "L1C", - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "sort_by": "cloud_cover", - "time_offset": "30d", - "use_rtree_index": false - }, - "type": "raster" - }, - "sentinel2_c": { - "alias": "sentinel2", - "band_sets": [ - { - "bands": [ - "B02", - "B03", - "B04", - "B08" - ], - "dtype": "uint16" + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + } }, { "bands": [ @@ -141,6 +54,14 @@ "B12" ], "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, "zoom_offset": -1 }, { @@ -150,72 +71,41 @@ "B10" ], "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, "zoom_offset": -2 } ], "data_source": { - "duration": "30d", "harmonize": true, "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, "max_time_delta": "0d", "modality": "L1C", - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", - "sort_by": "cloud_cover", - "time_offset": "0d", - "use_rtree_index": false - }, - "type": "raster" - }, - "sentinel2_d": { - "alias": "sentinel2", - "band_sets": [ - { - "bands": [ - "B02", - "B03", - "B04", - "B08" - ], - "dtype": "uint16" - }, - { - "bands": [ - "B05", - "B06", - "B07", - "B8A", - "B11", - "B12" - ], - "dtype": "uint16", - "zoom_offset": -1 + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 4 }, - { - "bands": [ - "B01", - "B09", - "B10" - ], - "dtype": "uint16", - "zoom_offset": -2 - } - ], - "data_source": { - "duration": "30d", - "harmonize": true, - "index_cache_dir": "cache/sentinel2", - "max_time_delta": "0d", - "modality": "L1C", - "name": "rslearn.data_sources.gcp_public_data.Sentinel2", "sort_by": "cloud_cover", - "time_offset": "-30d", "use_rtree_index": false }, "type": "raster" } }, "tile_store": { - "name": "file", - "root_dir": "tiles" + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + } + } } } diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml index f885dbdc..93d2f2f9 100644 --- a/data/satlas/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -43,25 +43,25 @@ data: inputs: image1: data_type: "raster" - layers: ["sentinel2_a"] + layers: ["sentinel2"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image2: data_type: "raster" - layers: ["sentinel2_b"] + layers: ["sentinel2.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image3: data_type: "raster" - layers: ["sentinel2_c"] + layers: ["sentinel2.2"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 image4: data_type: "raster" - layers: ["sentinel2_d"] + layers: ["sentinel2.3"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true dtype: FLOAT32 diff --git a/rslp/common/__init__.py b/rslp/common/__init__.py new file mode 100644 index 00000000..6cbbd859 --- /dev/null +++ b/rslp/common/__init__.py @@ -0,0 +1,7 @@ +"""Pipelines common across projects.""" + +from .worker import worker_pipeline + +workflows = { + "worker": worker_pipeline, +} diff --git a/rslp/common/worker.py b/rslp/common/worker.py new file mode 100644 index 00000000..dbb36016 --- /dev/null +++ b/rslp/common/worker.py @@ -0,0 +1,121 @@ +"""Worker to process jobs in a list of jobs.""" + +import json +import random +from datetime import datetime, timedelta, timezone + +from google.api_core.exceptions import PreconditionFailed +from google.cloud import storage +from upath import UPath + +from rslp.log_utils import get_logger +from rslp.main import run_workflow + +logger = get_logger(__name__) + +# Maximum expected duration of a job in hours. We use this to limit how long we care +# about a pending claim that hasn't completed yet. +MAX_JOB_HOURS = 4 + + +def _get_pending_jobs( + jobs: list[list[str]], claim_bucket: storage.Bucket, claim_dir: str +) -> list[int]: + """Get the indices of jobs that haven't been claimed yet. + + Args: + jobs: the full list of jobs. + claim_bucket: bucket where files indicating completed jobs are written. + claim_dir: path within bucket. + """ + claimed = set() + # Pending claims are only valid for a few hours. + for blob in claim_bucket.list_blobs(prefix=f"{claim_dir}pending/"): + if datetime.now(timezone.utc) - blob.time_created > timedelta( + hours=MAX_JOB_HOURS + ): + # This is a stale pending claim (the job may have completed, but if so we + # will see its completed blob below). + continue + claimed.add(int(blob.name.split("/")[-1])) + # While completed files indicate that the job is done permanently. + for blob in claim_bucket.list_blobs(prefix=f"{claim_dir}completed/"): + claimed.add(int(blob.name.split("/")[-1])) + + pending = [] + for idx in range(len(jobs)): + if idx in claimed: + continue + pending.append(idx) + + return pending + + +def worker_pipeline( + project: str, + workflow: str, + job_fname: str, + claim_bucket_name: str, + claim_dir: str, +) -> None: + """Start a worker to run the specified jobs. + + Args: + project: the project that the workflow to run is in. + workflow: the workflow to run. + job_fname: file containing the full list of jobs (arguments to the workflow + function) that need to be run. + claim_bucket_name: the GCS bucket to use for claiming jobs. + claim_dir: the path within claim_bucket_name to use for claiming jobs. + """ + job_upath = UPath(job_fname) + client = storage.Client() + claim_bucket = client.bucket(claim_bucket_name) + + with job_upath.open("r") as f: + jobs: list[list[str]] = json.load(f) + + # Get the currently pending jobs. + # Our strategy will be to sample a job and attempt to claim it. + # And then if the claim fails then we refresh the pending jobs. + # This works for up to ~10000 jobs. + pending = _get_pending_jobs(jobs, claim_bucket, claim_dir) + + while len(pending) > 0: + job_idx = random.choice(pending) + pending.remove(job_idx) + pending_blob = claim_bucket.blob(f"{claim_dir}pending/{job_idx}") + completed_blob = claim_bucket.blob(f"{claim_dir}completed/{job_idx}") + + # Determine the generation of pending_blob so we can create a newer one if + # applicable. If it doesn't exist, we use 0 so that it will throw error if the + # file exists at all (the actual generation should never be 0). + pending_blob_generation = 0 + is_pending = False + if pending_blob.exists(): + pending_blob.reload() + pending_blob_generation = pending_blob.generation + if datetime.now(timezone.utc) - pending_blob.time_created < timedelta( + hours=MAX_JOB_HOURS + ): + is_pending = True + + if is_pending or completed_blob.exists(): + pending = _get_pending_jobs(jobs, claim_bucket, claim_dir) + continue + + try: + # Use generation so that it throws error if generation doesn't match. + pending_blob.upload_from_string( + "", if_generation_match=pending_blob_generation + ) + except PreconditionFailed: + # This means another worker claimed the job in between when we confirmed + # the blob doesn't exist already and when we tried to claim it. In this + # case we just try again. + continue + + logger.info("claimed job %d and running it now", job_idx) + run_workflow(project, workflow, jobs[job_idx]) + + completed_blob.upload_from_string("") diff --git a/rslp/launch_beaker.py b/rslp/launch_beaker.py index 2bf75474..9c505aad 100644 --- a/rslp/launch_beaker.py +++ b/rslp/launch_beaker.py @@ -43,8 +43,8 @@ def get_base_env_vars(use_weka_prefix: bool = False) -> list[EnvVar]: value="/etc/credentials/gcp_credentials.json", # nosec ), EnvVar( - name="GCLOUD_PROJECT", # nosec - value="prior-satlas", # nosec + name="GOOGLE_CLOUD_PROJECT", # nosec + value="skylight-proto-1", # nosec ), EnvVar( name="WEKA_ACCESS_KEY_ID", # nosec diff --git a/rslp/main.py b/rslp/main.py index 89ee6247..1319b878 100644 --- a/rslp/main.py +++ b/rslp/main.py @@ -39,6 +39,19 @@ def datetime_deserializer(v: str) -> datetime: return datetime.fromisoformat(v) +def run_workflow(project: str, workflow: str, args: list[str]) -> None: + """Run the specified workflow. + + Args: + project: the project that the workflow is in. This is the name of the module. + workflow: the workflow name. + args: arguments to pass to jsonargparse for running the workflow function. + """ + module = importlib.import_module(f"rslp.{project}") + workflow_fn = module.workflows[workflow] + jsonargparse.CLI(workflow_fn, args=args) + + def main() -> None: """Main entrypoint function for rslp.""" dotenv.load_dotenv() @@ -46,19 +59,15 @@ def main() -> None: parser.add_argument("project", help="The project to execute a workflow for.") parser.add_argument("workflow", help="The name of the workflow.") args = parser.parse_args(args=sys.argv[1:3]) + run_workflow(args.project, args.workflow, sys.argv[3:]) + - module = importlib.import_module(f"rslp.{args.project}") - workflow_fn = module.workflows[args.workflow] +if __name__ == "__main__": + init_mp() # Setup jsonargparse. jsonargparse.typing.register_type( datetime, datetime_serializer, datetime_deserializer ) - # Parse arguments and run function. - jsonargparse.CLI(workflow_fn, args=sys.argv[3:]) - - -if __name__ == "__main__": - init_mp() main() diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index ef31cdb0..fc043dc1 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -7,13 +7,15 @@ - Tree cover """ -from .job_launcher import launch_jobs +from .job_launcher_worker import launch_workers, write_jobs, write_jobs_for_year_months from .postprocess import postprocess_points from .predict_pipeline import predict_multi, predict_pipeline workflows = { "predict": predict_pipeline, "predict_multi": predict_multi, - "launch": launch_jobs, + "write_jobs": write_jobs, + "write_jobs_for_year_months": write_jobs_for_year_months, + "launch_workers": launch_workers, "postprocess_points": postprocess_points, } diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py new file mode 100644 index 00000000..fb8cdcd8 --- /dev/null +++ b/rslp/satlas/data_sources.py @@ -0,0 +1,156 @@ +"""Customized data sources for Satlas models.""" + +from datetime import timedelta +from typing import Any + +from rslearn.config import QueryConfig, RasterLayerConfig, SpaceMode +from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources.data_source import DataSource, Item +from rslearn.data_sources.gcp_public_data import Sentinel2, Sentinel2Item +from rslearn.data_sources.utils import match_candidate_items_to_window +from rslearn.tile_stores import TileStore +from rslearn.utils.geometry import STGeometry +from upath import UPath + + +class MonthlySentinel2(DataSource): + """Sentinel2 data source where each match is a mosaic from a different month. + + It looks at the geometry time range, identifies matching items within each 30-day + period, and then picks the most recent {num_matches} months that have at least one + item. + + It also imposes a scene-level cloud cover limit. + """ + + def __init__( + self, + sentinel2: Sentinel2, + max_cloud_cover: float | None = None, + period_days: int = 30, + ): + """Create a new MonthlySentinel2. + + Args: + sentinel2: the Sentinel2 data source to wrap. + max_cloud_cover: cloud cover limit for scenes. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel2 = sentinel2 + self.max_cloud_cover = max_cloud_cover + self.period_days = period_days + + @staticmethod + def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel2": + """Creates a new MonthlySentinel2 instance from a configuration dictionary.""" + sentinel2 = Sentinel2.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["max_cloud_cover", "period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlySentinel2(sentinel2, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Sentinel2Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel2.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Sentinel2Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + # This part is the same as in base Sentinel2 class. + wgs84_geometries = [ + geometry.to_projection(WGS84_PROJECTION) for geometry in geometries + ] + + if self.sentinel2.rtree_index: + candidates = self.sentinel2._get_candidate_items_index(wgs84_geometries) + else: + candidates = self.sentinel2._get_candidate_items_direct(wgs84_geometries) + + groups = [] + + for geometry, item_list in zip(wgs84_geometries, candidates): + item_list.sort(key=lambda item: item.cloud_cover) + + # Apply cloud cover limit. + if self.max_cloud_cover is not None: + item_list = [ + item + for item in item_list + if item.cloud_cover <= self.max_cloud_cover + ] + + # Find matches across the periods. + # For each period, we create an STGeometry with modified time range + # matching the period, and obtain matching mosaic. + # We start from the end of the time range because we care more about recent + # periods and so we want to make sure that they align correctly with the + # end. + cur_groups: list[Item] = [] + period_end = geometry.time_range[1] + while ( + period_end > geometry.time_range[0] + and len(cur_groups) < query_config.max_matches + ): + period_time_range = ( + period_end - timedelta(days=self.period_days), + period_end, + ) + period_end -= timedelta(self.period_days) + period_geom = STGeometry( + geometry.projection, geometry.shp, period_time_range + ) + + # We modify the QueryConfig here since caller should be asking for + # multiple mosaics, but we just want one mosaic per period. + period_groups = match_candidate_items_to_window( + period_geom, + item_list, + QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1), + ) + + # There should be zero on one groups depending on whether there were + # any items that matched. We keep the group if it is there. + if len(period_groups) == 0 or len(period_groups[0]) == 0: + # No matches for this period. + continue + cur_groups.append(period_groups[0]) + + # If there are not enough matching mosaics, then we eliminate all the + # matches since we aren't going to use this window then anyway. + if len(cur_groups) < query_config.max_matches: + cur_groups = [] + + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Sentinel2Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. + + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel2.ingest(tile_store, items, geometries) diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py index a0ab0af7..3e8fd5b8 100644 --- a/rslp/satlas/job_launcher.py +++ b/rslp/satlas/job_launcher.py @@ -4,7 +4,7 @@ import multiprocessing import random import uuid -from datetime import datetime +from datetime import datetime, timedelta, timezone import shapely import tqdm @@ -30,6 +30,12 @@ TILE_SIZE = 32768 RESOLUTION = 10 +# Days to add before a provided date. +DAYS_BEFORE = 120 + +# Days to add after a provided date. +DAYS_AFTER = 90 + class Task: """Represents a task that will correspond to one Beaker job.""" @@ -116,11 +122,9 @@ def launch_job(batch: list[Task]) -> None: "ai2/jupiter-cirrascale-2", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", - "ai2/pluto-cirrascale", - "ai2/general-cirrascale", - "ai2/prior-cirrascale", - "ai2/prior-elanding", "ai2/augusta-google-1", + # "ai2/prior-cirrascale", + # "ai2/prior-elanding", ] ), preemptible=True, @@ -270,3 +274,37 @@ def launch_jobs( for batch in tqdm.tqdm(batches, desc="Starting Beaker jobs"): launch_job(batch) + + +def launch_jobs_for_year_month( + year: int, + month: int, + application: Application, + out_path: str, + batch_size: int = 1, + count: int | None = None, +) -> None: + """Launch Satlas prediction jobs on Beaker for the given year and month. + + Args: + year: the year. + month: the month. + application: the application to run. + out_path: the output path with year and month placeholders. + batch_size: the batch size. + count: only run up to this many tasks. + """ + ts = datetime(year, month, 1, tzinfo=timezone.utc) + time_range = ( + ts - timedelta(days=DAYS_BEFORE), + ts + timedelta(days=DAYS_AFTER), + ) + cur_out_path = out_path.format(year=year, month=month) + print(f"launching jobs with time_range={time_range} and out_path={cur_out_path}") + launch_jobs( + application=application, + time_range=time_range, + out_path=cur_out_path, + batch_size=batch_size, + count=count, + ) diff --git a/rslp/satlas/job_launcher_worker.py b/rslp/satlas/job_launcher_worker.py new file mode 100644 index 00000000..2ded9d9e --- /dev/null +++ b/rslp/satlas/job_launcher_worker.py @@ -0,0 +1,339 @@ +"""Launch Satlas prediction jobs on Beaker.""" + +import json +import uuid +from datetime import datetime, timedelta, timezone + +import shapely +import tqdm +from beaker import ( + Beaker, + Constraints, + DataMount, + DataSource, + EnvVar, + ExperimentSpec, + Priority, + TaskResources, +) +from rasterio.crs import CRS +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import PixelBounds, Projection, STGeometry +from rslearn.utils.get_utm_ups_crs import get_proj_bounds +from upath import UPath + +from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars +from rslp.log_utils import get_logger + +from .predict_pipeline import Application, PredictTaskArgs + +logger = get_logger(__name__) + +TILE_SIZE = 32768 +RESOLUTION = 10 + +# Days to add before a provided date. +DAYS_BEFORE = 120 + +# Days to add after a provided date. +DAYS_AFTER = 90 + + +class Task: + """Represents a task that processes one tile at one point in time.""" + + def __init__( + self, + application: Application, + projection: Projection, + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + out_path: str, + ) -> None: + """Create a new Task. + + Args: + application: the application to run + projection: the projection of the tile + bounds: the bounds of the tile + time_range: the time range to process + out_path: where to write outputs + """ + self.application = application + self.projection = projection + self.bounds = bounds + self.time_range = time_range + self.out_path = out_path + + +class WorkerParams: + """Parameters that worker pipeline needs to know.""" + + def __init__(self, job_fname: str, claim_bucket_name: str, claim_dir: str) -> None: + """Create a new WorkerParams. + + Args: + job_fname: the filename containing list of jobs. + claim_bucket_name: the bucket where workers will claim jobs. + claim_dir: the path in the bucket to write claim files. + """ + self.job_fname = job_fname + self.claim_bucket_name = claim_bucket_name + self.claim_dir = claim_dir + + +def launch_worker(worker_params: WorkerParams) -> None: + """Launch a worker job. + + Args: + worker_params: the parameters to pass to the worker. + """ + beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) + + with beaker.session(): + env_vars = get_base_env_vars(use_weka_prefix=False) + env_vars.append( + EnvVar( + name="RSLEARN_LOGLEVEL", + value="DEBUG", + ) + ) + + spec = ExperimentSpec.new( + budget=BUDGET, + description="worker", + beaker_image=IMAGE_NAME, + priority=Priority.low, + command=["python", "-m", "rslp.main"], + arguments=[ + "common", + "worker", + "satlas", + "predict_multi", + worker_params.job_fname, + worker_params.claim_bucket_name, + worker_params.claim_dir, + ], + constraints=Constraints( + cluster=[ + "ai2/jupiter-cirrascale-2", + "ai2/neptune-cirrascale", + "ai2/saturn-cirrascale", + "ai2/augusta-google-1", + ] + ), + preemptible=True, + datasets=[ + DataMount( + source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec + mount_path="/etc/credentials/gcp_credentials.json", # nosec + ), + ], + env_vars=env_vars, + resources=TaskResources(gpu_count=1, shared_memory="256GiB"), + ) + unique_id = str(uuid.uuid4())[0:8] + beaker.experiment.create(f"worker_{unique_id}", spec) + + +def get_jobs( + application: Application, + time_range: tuple[datetime, datetime], + out_path: str, + epsg_code: int | None = None, + wgs84_bounds: tuple[float, float, float, float] | None = None, + batch_size: int = 1, +) -> list[list[str]]: + """Get batches of tasks for Satlas prediction. + + Args: + application: which application to run. + time_range: the time range to run within. Must have timezone. + out_path: the output path. It should be specific to the time range. + epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default + run in all UTM zones. + wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. + batch_size: how many tasks to run in each batch. + + Returns: + the list of worker tasks where each worker task + """ + # Generate tasks. + if epsg_code: + utm_zones = [CRS.from_epsg(epsg_code)] + else: + utm_zones = [] + for epsg_code in range(32601, 32661): + utm_zones.append(CRS.from_epsg(epsg_code)) + for epsg_code in range(32701, 32761): + utm_zones.append(CRS.from_epsg(epsg_code)) + + tasks: list[Task] = [] + for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"): + # get_proj_bounds returns bounds in CRS units so we need to convert to pixel + # units. + crs_bbox = STGeometry( + Projection(utm_zone, 1, 1), + shapely.box(*get_proj_bounds(utm_zone)), + None, + ) + projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) + pixel_bbox = crs_bbox.to_projection(projection) + zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) + + user_bounds_in_proj: PixelBounds | None = None + if wgs84_bounds is not None: + dst_geom = STGeometry( + WGS84_PROJECTION, shapely.box(*wgs84_bounds), None + ).to_projection(projection) + user_bounds_in_proj = ( + int(dst_geom.shp.bounds[0]), + int(dst_geom.shp.bounds[1]), + int(dst_geom.shp.bounds[2]), + int(dst_geom.shp.bounds[3]), + ) + + for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): + for row in range( + zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1 + ): + if user_bounds_in_proj is not None: + # Check if this task intersects the bounds specified by the user. + if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: + continue + if col * TILE_SIZE >= user_bounds_in_proj[2]: + continue + if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: + continue + if row * TILE_SIZE >= user_bounds_in_proj[3]: + continue + + tasks.append( + Task( + application=application, + projection=projection, + bounds=( + col * TILE_SIZE, + row * TILE_SIZE, + (col + 1) * TILE_SIZE, + (row + 1) * TILE_SIZE, + ), + time_range=time_range, + out_path=out_path, + ) + ) + + print(f"Got {len(tasks)} total tasks") + + jobs = [] + for i in range(0, len(tasks), batch_size): + cur_tasks = tasks[i : i + batch_size] + + # Get list of PredictTaskArgs that we can serialize. + # These just specify the projection, time range, and bounds. + predict_tasks = [] + for task in cur_tasks: + predict_tasks.append( + PredictTaskArgs( + projection_json=task.projection.serialize(), + bounds=task.bounds, + time_range=task.time_range, + ) + ) + + cur_args = [ + application.value.upper(), + out_path, + "/tmp/scratch/", + json.dumps([predict_task.serialize() for predict_task in predict_tasks]), + ] + jobs.append(cur_args) + + return jobs + + +def write_jobs( + application: Application, + time_range: tuple[datetime, datetime], + out_path: str, + job_fname: str, + epsg_code: int | None = None, + wgs84_bounds: tuple[float, float, float, float] | None = None, + batch_size: int = 1, +) -> None: + """Write jobs for the specified application and time range. + + Args: + application: which application to run. + time_range: the time range to run within. Must have timezone. + out_path: the output path. It should be specific to the time range. + job_fname: where to write the list of jobs for workers to read. + epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default + run in all UTM zones. + wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. + batch_size: how many tasks to run in each batch. + """ + jobs = get_jobs( + application, + time_range, + out_path, + epsg_code=epsg_code, + wgs84_bounds=wgs84_bounds, + batch_size=batch_size, + ) + with UPath(job_fname).open("w") as f: + json.dump(jobs, f) + + +def write_jobs_for_year_months( + year_months: list[tuple[int, int]], + application: Application, + out_path: str, + job_fname: str, + batch_size: int = 1, +) -> None: + """Write Satlas prediction jobs for the given year and month. + + Args: + year_months: list of year-month pairs. + application: the application to run. + out_path: the output path with year and month placeholders. + job_fname: where to write the list of jobs for workers to read. + worker_params: the worker parameters. + batch_size: the batch size. + """ + jobs = [] + for year, month in year_months: + ts = datetime(year, month, 1, tzinfo=timezone.utc) + time_range = ( + ts - timedelta(days=DAYS_BEFORE), + ts + timedelta(days=DAYS_AFTER), + ) + cur_out_path = out_path.format(year=year, month=month) + logger.info( + f"collecting jobs for year={year}, month={month}, time_range={time_range}, out_path={cur_out_path}" + ) + cur_jobs = get_jobs( + application=application, + time_range=time_range, + out_path=cur_out_path, + batch_size=batch_size, + ) + logger.info("got %d jobs for %04d-%02d", len(cur_jobs), year, month) + jobs.extend(cur_jobs) + + logger.info("got a total of %d jobs across year-months", len(jobs)) + with UPath(job_fname).open("w") as f: + json.dump(jobs, f) + + +def launch_workers(worker_params: WorkerParams, num_workers: int) -> None: + """Start workers for the prediction jobs. + + Args: + worker_params: the parameters for the workers, including job file where the + list of jobs has been written. + num_workers: number of workers to launch + """ + for _ in tqdm.tqdm(range(num_workers)): + launch_worker(worker_params) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py new file mode 100644 index 00000000..62fd5733 --- /dev/null +++ b/rslp/satlas/postprocess.py @@ -0,0 +1,236 @@ +"""Postprocessing outputs from Satlas models.""" + +import json +import math +import multiprocessing +import shutil +import subprocess # nosec +import tempfile +from typing import Any + +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import Projection, STGeometry +from rslearn.utils.grid_index import GridIndex +from upath import UPath + +from rslp.log_utils import get_logger + +from .predict_pipeline import Application + +# Approximate maximum meters in one degree latitude/longitude. +MAX_METERS_PER_DEGREE = 111111 + +# Threshold on Euclidean distance between lat/lon for NMS. +# We just do Euclidean distance for speed/simplicity since NMS doesn't need to be super +# exact. +NMS_DISTANCE_THRESHOLD = 100 / MAX_METERS_PER_DEGREE + +logger = get_logger(__name__) + + +def _get_fc(fname: UPath) -> dict[str, Any]: + with fname.open() as f: + return json.load(f) + + +def apply_nms( + features: list[dict[str, Any]], + distance_threshold: float, +) -> list[dict[str, Any]]: + """Apply non-maximum suppression over the points. + + Args: + features: the list of JSON Feature objects. + distance_threshold: the distance threshold to match points. + + Returns: + new Features with NMS applied. + """ + # A few multiples of the distance threshold is generally a good grid size. + grid_index = GridIndex(distance_threshold * 10) + + # Insert features into the index. + for idx, feat in enumerate(features): + coordinates = feat["geometry"]["coordinates"] + box = (coordinates[0], coordinates[1], coordinates[0], coordinates[1]) + grid_index.insert(box, idx) + + good_features = [] + for idx, feat in enumerate(features): + coordinates = feat["geometry"]["coordinates"] + # Create search box with distance threshold padding. + box = ( + coordinates[0] - distance_threshold, + coordinates[1] - distance_threshold, + coordinates[0] + distance_threshold, + coordinates[1] + distance_threshold, + ) + is_feat_okay = True + for other_idx in grid_index.query(box): + other_feat = features[other_idx] + if idx == other_idx: + continue + if feat["properties"]["score"] < other_feat["properties"]["score"]: + continue + other_coordinates = other_feat["geometry"]["coordinates"] + distance = math.sqrt( + (coordinates[0] - other_coordinates[0]) ** 2 + + (coordinates[1] - other_coordinates[1]) ** 2 + ) + if distance > distance_threshold: + continue + is_feat_okay = False + break + + if is_feat_okay: + good_features.append(feat) + + return good_features + + +def postprocess_points( + application: Application, + label: str, + predict_path: str, + merged_path: str, + smoothed_path: str, + workers: int = 32, +) -> None: + """Post-process Satlas point outputs. + + This merges the outputs across different prediction tasks for this timestamp and + spatial tile. Then it applies Viterbi smoothing that takes into account merged + outputs from previous time ranges, and uploads the results. + + Args: + application: the application. + label: YYYY-MM representation of the time range used for this prediction run. + predict_path: output path of the prediction pipeline where GeoJSONs from all + the different tasks have been written. + merged_path: folder to write merged predictions. The filename will be + YYYY-MM.geojson. + smoothed_path: folder to write smoothed predictions. The filename will be + YYYY-MM.geojson. + workers: number of worker processes. + """ + # Merge the predictions. + predict_upath = UPath(predict_path) + merged_features = [] + merged_patches: dict[str, list[tuple[int, int]]] = {} + + fnames = [fname for fname in predict_upath.iterdir() if fname.name != "index"] + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(_get_fc, fnames) + + for cur_fc in tqdm.tqdm(outputs, total=len(fnames)): + # The projection information may be missing if there are no valid patches. + if "crs" not in cur_fc["properties"]: + # Just do some sanity checks, there should be no features and no valid + # patches. + assert len(cur_fc["features"]) == 0 + patch_list = list(cur_fc["properties"]["valid_patches"].values()) + assert len(patch_list) == 1 and len(patch_list[0]) == 0 + continue + + src_projection = Projection.deserialize(cur_fc["properties"]) + crs_str = str(src_projection.crs) + + # We ultimately want to store longitude/latitude but + # smooth_point_labels_viterbi.go needs to know the projection and x/y so we + # write them as properties of the feature, while converting the geometry + # coordinates to WGS84. + for feat in cur_fc["features"]: + col, row = feat["geometry"]["coordinates"] + feat["properties"]["col"] = int(col) + feat["properties"]["row"] = int(row) + feat["properties"]["projection"] = crs_str + + src_geom = STGeometry(src_projection, shapely.Point(col, row), None) + dst_geom = src_geom.to_projection(WGS84_PROJECTION) + feat["geometry"]["coordinates"] = [dst_geom.shp.x, dst_geom.shp.y] + + merged_features.append(feat) + + # Merge the valid patches too, these indicate which portions of the world + # actually had image content for the current timestep. + assert len(cur_fc["properties"]["valid_patches"]) == 1 + if crs_str not in merged_patches: + merged_patches[crs_str] = [] + merged_patches[crs_str].extend(cur_fc["properties"]["valid_patches"][crs_str]) + + p.close() + + nms_features = apply_nms(merged_features, distance_threshold=NMS_DISTANCE_THRESHOLD) + logger.info( + "NMS filtered from %d -> %d features", len(merged_features), len(nms_features) + ) + + merged_upath = UPath(merged_path) + merged_fname = merged_upath / f"{label}.geojson" + with merged_fname.open("w") as f: + json.dump( + { + "type": "FeatureCollection", + "features": nms_features, + "properties": { + "valid_patches": merged_patches, + }, + }, + f, + ) + + # Download the merged prediction history (ending with the one we just wrote) and + # run smoothing. + smoothed_upath = UPath(smoothed_path) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_upath = UPath(tmp_dir) + tmp_merged_dir = tmp_upath / "merged" + tmp_smoothed_dir = tmp_upath / "smoothed" + tmp_hist_fname = tmp_upath / "history.geojson" + + tmp_merged_dir.mkdir() + tmp_smoothed_dir.mkdir() + + labels: list[str] = [] + for merged_fname in merged_upath.iterdir(): + # Get the label like 2024-01 from 2024-01.geojson. + if not merged_fname.name.endswith(".geojson"): + continue + label = merged_fname.name.split(".")[0] + + local_fname = tmp_merged_dir / merged_fname.name + with merged_fname.open("rb") as src: + with local_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) + labels.append(label) + + # Sort by YYYY-MM. + labels.sort() + + subprocess.check_call( + [ + "rslp/satlas/scripts/smooth_point_labels_viterbi", + "--labels", + ",".join(labels), + "--fname", + (tmp_merged_dir / "LABEL.geojson").path, + "--out", + (tmp_smoothed_dir / "LABEL.geojson").path, + "--hist", + tmp_hist_fname.path, + ], + ) # nosec + + for label in labels: + src_path = tmp_smoothed_dir / f"{label}.geojson" + dst_path = smoothed_upath / f"{label}.geojson" + with src_path.open("rb") as src: + with dst_path.open("wb") as dst: + shutil.copyfileobj(src, dst) + + dst_path = smoothed_upath / "history.geojson" + with tmp_hist_fname.open("rb") as src: + with dst_path.open("wb") as dst: + shutil.copyfileobj(src, dst) diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index db53e80b..fe2f5e6e 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -7,8 +7,8 @@ from enum import Enum from typing import Any +import rslearn.data_sources.copernicus from rslearn.const import WGS84_PROJECTION -from rslearn.data_sources.copernicus import load_sentinel2_tile_index from rslearn.dataset import Window from rslearn.utils.geometry import PixelBounds, Projection from upath import UPath @@ -21,10 +21,6 @@ SENTINEL2_LAYER = "sentinel2" PATCH_SIZE = 2048 -# Add padding to the time range specified by the user for prediction since some -# applications use images from up to this many days outside of that time range. -RTREE_TIME_PAD_DAYS = 30 - # Layers not to use when seeing which patches are valid. VALIDITY_EXCLUDE_LAYERS = ["mask", "output", "label"] @@ -120,15 +116,15 @@ def predict_pipeline( if "data_source" not in layer_cfg: continue layer_source_cfg = layer_cfg["data_source"] - if not layer_source_cfg["name"].endswith("gcp_public_data.Sentinel2"): + if not layer_source_cfg["name"].endswith("MonthlySentinel2"): continue layer_source_cfg["index_cache_dir"] = str(index_cache_dir) - # layer_source_cfg["rtree_cache_dir"] = str(UPath(out_path) / "index") - # layer_source_cfg["use_rtree_index"] = True - # layer_source_cfg["rtree_time_range"] = [ - # (time_range[0] - timedelta(days=RTREE_TIME_PAD_DAYS)).isoformat(), - # (time_range[1] + timedelta(days=RTREE_TIME_PAD_DAYS)).isoformat(), - # ] + layer_source_cfg["rtree_cache_dir"] = str(UPath(out_path) / "index") + layer_source_cfg["use_rtree_index"] = True + layer_source_cfg["rtree_time_range"] = [ + time_range[0].isoformat(), + time_range[1].isoformat(), + ] image_layer_names.append(layer_name) with (ds_path / "config.json").open("w") as f: @@ -197,11 +193,11 @@ def predict_pipeline( # This way it is only downloaded once here instead of many times during prepare. # We could set use_initial_prepare_job=True in materialize_dataset call, but then # it could take a minute or more longer than needed. - load_sentinel2_tile_index(index_cache_dir) + rslearn.data_sources.copernicus._cache_sentinel2_tile_index(index_cache_dir) # Populate the windows. logger.info("materialize dataset") - materialize_dataset(ds_path, group=group, prepare_workers=128) + materialize_dataset(ds_path, group=group, initial_prepare_job=True) # Run the model, only if at least one window has some data. completed_fnames = ds_path.glob( diff --git a/rslp/satlas/scripts/go.mod b/rslp/satlas/scripts/go.mod new file mode 100644 index 00000000..dc30d1fb --- /dev/null +++ b/rslp/satlas/scripts/go.mod @@ -0,0 +1,12 @@ +module main + +go 1.22.5 + +require github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112 + +require ( + github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect + github.com/dhconnelly/rtreego v1.1.0 // indirect + github.com/qedus/osmpbf v1.2.0 // indirect + google.golang.org/protobuf v1.26.0 // indirect +) diff --git a/rslp/satlas/scripts/go.sum b/rslp/satlas/scripts/go.sum new file mode 100644 index 00000000..69204293 --- /dev/null +++ b/rslp/satlas/scripts/go.sum @@ -0,0 +1,43 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= +github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/dhconnelly/rtreego v1.1.0 h1:ejMaqN03N1s6Bdg6peGkNgBnYYSBHzcK8yhSPCB+rHE= +github.com/dhconnelly/rtreego v1.1.0/go.mod h1:SDozu0Fjy17XH1svEXJgdYq8Tah6Zjfa/4Q33Z80+KM= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112 h1:/msesy1s0b8A/xU7MSCzcDvDcrVaqvE/7Bckq1y9kAM= +github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112/go.mod h1:60dnxKwUjhRjPfhvtjJQccZOo741GN+5WymEEc+Aa0c= +github.com/qedus/osmpbf v1.2.0 h1:yRm5ECkiUsN9sA+UN9yNnm64AVW2OYhOCb+gBa1FYCU= +github.com/qedus/osmpbf v1.2.0/go.mod h1:Cfv6JyqTZ72BjoW9FyFBQOC2DYJbL78yw+DLhBvSH+M= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= diff --git a/rslp/satlas/scripts/smooth_point_labels_viterbi.go b/rslp/satlas/scripts/smooth_point_labels_viterbi.go index 696365a0..0c4dbb18 100644 --- a/rslp/satlas/scripts/smooth_point_labels_viterbi.go +++ b/rslp/satlas/scripts/smooth_point_labels_viterbi.go @@ -4,7 +4,6 @@ import ( "encoding/json" "flag" "fmt" - "io/ioutil" "log" "math" "os" @@ -15,6 +14,7 @@ import ( ) const FUTURE_LABEL = "2030-01" +const TILE_SIZE = 2048 type Tile struct { Projection string @@ -26,16 +26,16 @@ type Point struct { Geometry struct { Type string `json:"type"` Coordinates [2]float64 `json:"coordinates"` - } + } `json:"geometry"` label string Properties struct { - Category string `json:"category"` - Score float64 `json:"score"` - Projection string `json:"projection,omitempty"` - Column int `json:"column,omitempty"` - Row int `json:"row,omitempty"` - Start string `json:"start,omitempty"` - End string `json:"end,omitempty"` + Category *string `json:"category"` + Score *float64 `json:"score"` + Projection *string `json:"projection,omitempty"` + Column *int `json:"col,omitempty"` + Row *int `json:"row,omitempty"` + Start string `json:"start,omitempty"` + End string `json:"end,omitempty"` } `json:"properties"` } @@ -43,7 +43,7 @@ type PointData struct { Type string `json:"type"` Features []Point `json:"features"` Properties struct { - ValidPatches map[string][][2]int `json:"valid_patches"` + ValidPatches map[string][][2]int `json:"valid_patches,omitempty"` } `json:"properties"` } @@ -52,8 +52,8 @@ type Group []Point func (g Group) Center() [2]int { var sum [2]int for _, p := range g { - sum[0] += p.Properties.Column - sum[1] += p.Properties.Row + sum[0] += *p.Properties.Column + sum[1] += *p.Properties.Row } return [2]int{ sum[0] / len(g), @@ -113,7 +113,7 @@ func main() { if _, err := os.Stat(fname); os.IsNotExist(err) { continue } - bytes, err := ioutil.ReadFile(fname) + bytes, err := os.ReadFile(fname) if err != nil { panic(err) } @@ -130,9 +130,9 @@ func main() { } gridIndexes := make(map[string]*common.GridIndex) for idx, point := range curPoints { - projection := point.Properties.Projection - col := float64(point.Properties.Column) - row := float64(point.Properties.Row) + projection := *point.Properties.Projection + col := float64(*point.Properties.Column) + row := float64(*point.Properties.Row) if gridIndexes[projection] == nil { gridIndexes[projection] = common.NewGridIndex(GridSize) } @@ -147,7 +147,7 @@ func main() { // Match existing groups to the new points. matchedIndices := make(map[int]bool) for groupIdx, group := range groups { - projection := group[0].Properties.Projection + projection := *group[0].Properties.Projection center := group.Center() indices := gridIndexes[projection].Search(common.Rectangle{ Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, @@ -159,12 +159,12 @@ func main() { if matchedIndices[idx] { continue } - if group[0].Properties.Category != curPoints[idx].Properties.Category { + if *group[0].Properties.Category != *curPoints[idx].Properties.Category { continue } - dx := center[0] - curPoints[idx].Properties.Column - dy := center[1] - curPoints[idx].Properties.Row + dx := center[0] - *curPoints[idx].Properties.Column + dy := center[1] - *curPoints[idx].Properties.Row distance := math.Sqrt(float64(dx*dx + dy*dy)) if distance > *distanceThreshold/MetersPerPixel { @@ -310,9 +310,9 @@ func main() { validLabelSet := make(map[string]bool) center := group.Center() tile := Tile{ - Projection: group[0].Properties.Projection, - Column: center[0] / 512, - Row: center[1] / 512, + Projection: *group[0].Properties.Projection, + Column: int(math.Floor(float64(center[0]) / TILE_SIZE)), + Row: int(math.Floor(float64(center[1]) / TILE_SIZE)), } for _, label := range tileLabelValidity[tile] { validLabelSet[label] = true @@ -371,7 +371,9 @@ func main() { for labelIdx := rng.StartIdx; labelIdx < rng.EndIdx; labelIdx++ { label := labelList[labelIdx] if outFeatures[label] == nil { - outFeatures[label] = &PointData{} + outFeatures[label] = &PointData{ + Type: "FeatureCollection", + } } outFeatures[label].Features = append(outFeatures[label].Features, feat) } @@ -384,9 +386,10 @@ func main() { var scoreSum float64 = 0 for _, p := range rng.Group { - scoreSum += p.Properties.Score + scoreSum += *p.Properties.Score } - feat.Properties.Score = scoreSum / float64(len(rng.Group)) + scoreAvg := scoreSum / float64(len(rng.Group)) + feat.Properties.Score = &scoreAvg historyData.Features = append(historyData.Features, feat) } @@ -399,7 +402,7 @@ func main() { if err != nil { panic(err) } - if err := ioutil.WriteFile(*histFname, bytes, 0644); err != nil { + if err := os.WriteFile(*histFname, bytes, 0644); err != nil { panic(err) } } @@ -411,7 +414,7 @@ func main() { if err != nil { panic(err) } - if err := ioutil.WriteFile(fname, bytes, 0644); err != nil { + if err := os.WriteFile(fname, bytes, 0644); err != nil { panic(err) } } From f6bbc780cc9b8292bd5730ea6476869973565aa3 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 10 Dec 2024 11:17:14 -0800 Subject: [PATCH 19/58] sync --- requirements.txt | 1 + rslp/satlas/__init__.py | 7 +- rslp/satlas/bkt.py | 406 ++++++++++++++++++ rslp/satlas/postprocess.py | 54 ++- rslp/satlas/publish.py | 234 ++++++++++ .../scripts/smooth_point_labels_viterbi.go | 89 +++- 6 files changed, 767 insertions(+), 24 deletions(-) create mode 100644 rslp/satlas/bkt.py create mode 100644 rslp/satlas/publish.py diff --git a/requirements.txt b/requirements.txt index 4ab931dd..8169b9aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ beaker-py>=1.32 fastapi>=0.115 +google-cloud-bigtable>=2.18 interrogate>=1.7 pydantic>=2.8 pytest>=8.2 diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index fc043dc1..f0439050 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -8,8 +8,9 @@ """ from .job_launcher_worker import launch_workers, write_jobs, write_jobs_for_year_months -from .postprocess import postprocess_points +from .postprocess import merge_points, smooth_points from .predict_pipeline import predict_multi, predict_pipeline +from .publish import publish_points workflows = { "predict": predict_pipeline, @@ -17,5 +18,7 @@ "write_jobs": write_jobs, "write_jobs_for_year_months": write_jobs_for_year_months, "launch_workers": launch_workers, - "postprocess_points": postprocess_points, + "merge_points": merge_points, + "smooth_points": smooth_points, + "publish_points": publish_points, } diff --git a/rslp/satlas/bkt.py b/rslp/satlas/bkt.py new file mode 100644 index 00000000..e2784457 --- /dev/null +++ b/rslp/satlas/bkt.py @@ -0,0 +1,406 @@ +"""Manage bucket files on GCS. + +We bucket together small (10-200 KB) files at high zoom levels (e.g. zoom 13) into a +single file at a lower zoom level (e.g. zoom 9) to save on GCS insert fee. + +This is similar to https://github.com/mactrem/com-tiles. + +The .bkt is just a concatenation of the small files. + +We record the byte offsets in a Google Cloud Bigtable database. +""" + +import functools +import io +import multiprocessing.pool +import os +import struct +import time +from collections.abc import Generator +from typing import Any + +import google.cloud.bigtable.row +import google.cloud.bigtable.row_filters +import google.cloud.bigtable.table +import numpy.typing as npt +import skimage.io +from google.cloud import bigtable, storage +from rslearn.utils.mp import star_imap_unordered + +from rslp.log_utils import get_logger + +logger = get_logger(__name__) + + +class BktInserter: + """A helper class that inserts metadata about bkt files into the database. + + The BktInserter is a separate class from BktWriter so that it can be pickled to + support use with multiprocessing. + """ + + def __init__( + self, + indexes: list[tuple[int, int, int, int]], + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ): + """Create a new BktInserter. + + Args: + indexes: the byte offsets of the files within the bkt. It is a list of + (col, row, offset, length) tuples. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt. + zoom: the zoom level of the tiles within the bkt. + """ + self.indexes = indexes + self.bkt_fname = bkt_fname + self.bkt_zoom = bkt_zoom + self.zoom = zoom + + def run(self, bkt_files_table: google.cloud.bigtable.table.Table) -> None: + """Insert the metadata into BigTable. + + Args: + bkt_files_table: the BigTable object + """ + # Row key in the table is just the bkt fname. + # Value is [4 byte bkt_zoom][4 byte zoom][indexes]. + # [indexes] is list of indexes encoded as [4 byte col][4 byte row][4 byte offset][4 byte length]. + buf = io.BytesIO() + buf.write(struct.pack(">II", self.bkt_zoom, self.zoom)) + for col, row, offset, length in self.indexes: + buf.write(struct.pack(">IIII", col, row, offset, length)) + db_row = bkt_files_table.direct_row(self.bkt_fname) + db_row.set_cell(b"d", b"d", buf.getvalue()) + db_row.commit() + + +class BktWriter: + """Writer for bkt files.""" + + def __init__(self) -> None: + """Create a new BktWriter.""" + self.indexes: list[tuple[int, int, int, int]] = [] + self.buf = io.BytesIO() + self.offset = 0 + + def add(self, col: int, row: int, bytes: bytes) -> None: + """Add a file to the bkt. + + Args: + col: the tile column. + row: the tile row. + bytes: the data at this tile. + """ + offset = self.offset + length = len(bytes) + self.indexes.append((col, row, offset, length)) + self.buf.write(bytes) + self.offset += length + + def get_bytes(self) -> bytes: + """Returns the bytes of the whole bkt file.""" + return self.buf.getvalue() + + def get_inserter(self, bkt_fname: str, bkt_zoom: int, zoom: int) -> "BktInserter": + """Creates a BktInserter that manages inserting the byte offsets to BigTable. + + Args: + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + + Returns: + a corresponding BktInserter + """ + return BktInserter(self.indexes, bkt_fname, bkt_zoom, zoom) + + def insert( + self, + bkt_files_table: google.cloud.bigtable.table.Table, + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ) -> None: + """Insert the byte offsets for this bkt to BigTable. + + Args: + bkt_files_table: the BigTable table object. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + """ + self.get_inserter(bkt_fname, bkt_zoom, zoom).run(bkt_files_table) + + +@functools.cache +def get_bucket() -> storage.Bucket: + """Get the GCS bucket where bkt files should be stored.""" + storage_client = storage.Client(project=os.environ["BKT_PROJECT_ID"]) + bucket = storage_client.bucket(os.environ["BKT_BUCKET_NAME"]) + return bucket + + +def download_bkt( + bkt_fname: str, + idx_map: dict[tuple[int, int], tuple[int, int]], + wanted: list[tuple[int, int, Any]], + mode: str, +) -> list[tuple[Any, npt.NDArray | bytes]]: + """Download tiles in a bkt file. + + Args: + bkt_fname: the bkt filename in the bucket to download from. + idx_map: map from tile (col, row) to (offset, length). + wanted: list of tiles to download. It should be a list of (col, row, metadata) + where metadata can be arbitrary data used by the caller that will be + returned with the tile data (which will be emitted in arbitrary order). + Note that if a tile does not exist within the bkt, it will not be returned + at all. + mode: either "image" to decode image and return numpy array, or "raw" to return + the byte string directly. + + Returns: + a list of (metadata, contents) where contents is a numpy array if mode is + "image" or a byte string if mode is "raw". + """ + bucket = get_bucket() + output = [] + + # Helper to postprocess an output based on the specified return mode. + def add_output(metadata: Any, contents: npt.NDArray | bytes) -> None: + if mode == "image": + buf = io.BytesIO(contents) + image = skimage.io.imread(buf) + output.append((metadata, image)) + + elif mode == "raw": + output.append((metadata, contents)) + + else: + raise ValueError(f"invalid mode {mode}") + + wanted = [ + (col, row, metadata) for col, row, metadata in wanted if (col, row) in idx_map + ] + + if len(wanted) == 1: + col, row, metadata = wanted[0] + offset, length = idx_map[(col, row)] + blob = bucket.blob(bkt_fname) + contents = blob.download_as_bytes(start=offset, end=offset + length) + add_output(metadata, contents) + + elif len(wanted) > 1: + blob = bucket.blob(bkt_fname) + bkt_bytes = blob.download_as_bytes() + for col, row, metadata in wanted: + offset, length = idx_map[(col, row)] + contents = bkt_bytes[offset : offset + length] + add_output(metadata, contents) + + return output + + +# Parallel download from various bkt files. +# Jobs is a list of (bkt_fname, col, row, metadata). +# download_from_bkt is a generator that will yield (metadata, bytes) for each provided job. +def download_from_bkt( + bkt_files_table: google.cloud.bigtable.table.Table, + pool: multiprocessing.pool.Pool | None, + jobs: list[tuple[str, int, int, Any]], + mode: str = "raw", +) -> Generator[tuple[Any, npt.NDArray | bytes], None, None]: + """Download tile contents in parallel from several bkt files. + + Args: + bkt_files_table: the BigTable table containing byte offsets. + pool: the multiprocessing pool to use for parallelism, or None to read in + current process. + jobs: list of (bkt_fname, col, row, metadata) to work through. Jobs referencing + the same bkt_fname will be grouped together so we don't read the same bkt + file multiple times. + mode: the return mode (see download_bkt). + + Yields: + the (metadata, contents) tuples across all of the jobs. + """ + # Get indexes associated with each distinct bkt_fname. + by_bkt_fname: dict[str, list[tuple[int, int, Any]]] = {} + for bkt_fname, col, row, metadata in jobs: + if bkt_fname not in by_bkt_fname: + by_bkt_fname[bkt_fname] = [] + by_bkt_fname[bkt_fname].append((col, row, metadata)) + + bkt_jobs: list[dict[str, Any]] = [] + for bkt_fname, wanted in by_bkt_fname.items(): + # Use retry loop since we seem to get error reading from BigTable occasionally. + def bkt_retry_loop() -> google.cloud.bigtable.row.PartialRowData: + for _ in range(8): + try: + db_row = bkt_files_table.read_row( + bkt_fname, + filter_=google.cloud.bigtable.row_filters.CellsColumnLimitFilter( + 1 + ), + ) + return db_row + except Exception as e: + print( + f"got error reading bkt_files_table for {bkt_fname} (trying again): {e}" + ) + time.sleep(1) + raise Exception( + f"repeatedly failed to read bkt_files_table for {bkt_fname}" + ) + + db_row = bkt_retry_loop() + + # Ignore requested files that don't exist. + if not db_row: + continue + # Skip 8-byte header with bkt_zoom/zoom. + encoded_indexes = db_row.cell_value("d", b"d")[8:] + + indexes = {} + for i in range(0, len(encoded_indexes), 16): + col, row, offset, length = struct.unpack( + ">IIII", encoded_indexes[i : i + 16] + ) + indexes[(col, row)] = (offset, length) + bkt_jobs.append( + dict( + bkt_fname=bkt_fname, + idx_map=indexes, + wanted=wanted, + mode=mode, + ) + ) + + if pool is None: + for job in bkt_jobs: + for metadata, image in download_bkt(**job): + yield (metadata, image) + else: + outputs = star_imap_unordered(pool, download_bkt, bkt_jobs) + for output in outputs: + for metadata, image in output: + yield (metadata, image) + + +def upload_bkt(bkt_fname: str, contents: bytes) -> None: + """Upload a bkt file to GCS bucket. + + Args: + bkt_fname: the bkt filename within the bucket. + contents: the data to upload. + """ + bucket = get_bucket() + blob = bucket.blob(bkt_fname) + blob.upload_from_string(contents) + + +# Tuples is list of (bkt_writer, bkt_fname, bkt_zoom, zoom). +def upload_bkts( + bkt_files_table: google.cloud.bigtable.table.Table, + pool: multiprocessing.pool.Pool, + jobs: list[tuple[BktWriter, str, int, int]], +) -> None: + """Upload several bkt files to GCS in parallel. + + Args: + bkt_files_table: the BigTable table to store byte offsets. + pool: a multiprocessing pool for parallelism. + jobs: list of (bkt_writer, bkt_fname, bkt_zoom, zoom) tuples. bkt_writer is the + BktWriter where the bkt contents and metadata are stored. bkt_fname is the + path where the bkt should be written. bkt_zoom in the zoom level of the bkt + file. zoom is the zoom level of tiles within the bkt. + """ + # Upload. We upload first since reader will assume that anything existing in + # BigTable already exists on GCS. + upload_jobs: list[tuple[str, bytes]] = [] + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + upload_jobs.append((bkt_fname, bkt_writer.get_bytes())) + outputs = star_imap_unordered(pool, upload_bkt, upload_jobs) + for _ in outputs: + pass + # Now we insert the metadata. + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=bkt_zoom, + zoom=zoom, + ) + + +def make_bkt(src_dir: str, dst_path: str) -> None: + """Make a bkt file from the specified local source directory. + + The source directory must contain files of the form zoom/col/row.ext (the extension + is ignored). + + A single bkt file is created, so the zoom level of the bkt is always 0. + + Args: + src_dir: the local directory to turn into a single bkt file. + dst_path: the bkt filename in the bkt GCS bucket to write to. It must have a + {zoom} placeholder where the zoom goes. + """ + bucket = get_bucket() + bigtable_client = bigtable.Client(project=os.environ["BKT_BIGTABLE_PROJECT_ID"]) + bigtable_instance = bigtable_client.instance(os.environ["BKT_BIGTABLE_INSTANCE_ID"]) + bkt_files_table = bigtable_instance.table("bkt_files") + + for zoom_str in os.listdir(src_dir): + zoom_dir = os.path.join(src_dir, zoom_str) + if not os.path.isdir(zoom_dir): + continue + zoom = int(zoom_str) + logger.debug( + "make_bkt(%s, %s): start collecting files at zoom level %d", + src_dir, + dst_path, + zoom, + ) + + # Read all files at this zoom level from local path into bkt (in memory). + bkt_writer = BktWriter() + num_files = 0 + for col_str in os.listdir(zoom_dir): + col_dir = os.path.join(zoom_dir, col_str) + col = int(col_str) + for fname in os.listdir(col_dir): + row = int(fname.split(".")[0]) + num_files += 1 + with open(os.path.join(col_dir, fname), "rb") as f: + contents = f.read() + bkt_writer.add(col, row, contents) + logger.debug( + "make_bkt(%s, %s): processed %d files at zoom %d", + src_dir, + dst_path, + num_files, + zoom, + ) + + # Now upload to GCS. + bkt_fname = dst_path.format(zoom=zoom) + logger.debug( + "make_bkt(%s, %s) uploading bkt for zoom level %d to %s", + src_dir, + dst_path, + zoom, + bkt_fname, + ) + blob = bucket.blob(bkt_fname) + blob.upload_from_string(bkt_writer.get_bytes()) + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=0, + zoom=zoom, + ) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 62fd5733..2d673682 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -27,6 +27,13 @@ # exact. NMS_DISTANCE_THRESHOLD = 100 / MAX_METERS_PER_DEGREE +APP_CATEGORY_MAPS = { + Application.MARINE_INFRA: { + "platform": "offshore_platform", + "turbine": "offshore_wind_turbine", + } +} + logger = get_logger(__name__) @@ -90,19 +97,16 @@ def apply_nms( return good_features -def postprocess_points( +def merge_points( application: Application, label: str, predict_path: str, merged_path: str, - smoothed_path: str, workers: int = 32, ) -> None: - """Post-process Satlas point outputs. + """Merge Satlas point outputs. - This merges the outputs across different prediction tasks for this timestamp and - spatial tile. Then it applies Viterbi smoothing that takes into account merged - outputs from previous time ranges, and uploads the results. + This merges the outputs across different prediction tasks for this timestamp. Args: application: the application. @@ -111,11 +115,8 @@ def postprocess_points( the different tasks have been written. merged_path: folder to write merged predictions. The filename will be YYYY-MM.geojson. - smoothed_path: folder to write smoothed predictions. The filename will be - YYYY-MM.geojson. workers: number of worker processes. """ - # Merge the predictions. predict_upath = UPath(predict_path) merged_features = [] merged_patches: dict[str, list[tuple[int, int]]] = {} @@ -124,6 +125,9 @@ def postprocess_points( p = multiprocessing.Pool(workers) outputs = p.imap_unordered(_get_fc, fnames) + # Get category remapping in case one is specified for this application. + category_map = APP_CATEGORY_MAPS.get(application, {}) + for cur_fc in tqdm.tqdm(outputs, total=len(fnames)): # The projection information may be missing if there are no valid patches. if "crs" not in cur_fc["properties"]: @@ -151,6 +155,10 @@ def postprocess_points( dst_geom = src_geom.to_projection(WGS84_PROJECTION) feat["geometry"]["coordinates"] = [dst_geom.shp.x, dst_geom.shp.y] + category = feat["properties"]["category"] + if category in category_map: + feat["properties"]["category"] = category_map[category] + merged_features.append(feat) # Merge the valid patches too, these indicate which portions of the world @@ -162,18 +170,13 @@ def postprocess_points( p.close() - nms_features = apply_nms(merged_features, distance_threshold=NMS_DISTANCE_THRESHOLD) - logger.info( - "NMS filtered from %d -> %d features", len(merged_features), len(nms_features) - ) - merged_upath = UPath(merged_path) merged_fname = merged_upath / f"{label}.geojson" with merged_fname.open("w") as f: json.dump( { "type": "FeatureCollection", - "features": nms_features, + "features": merged_features, "properties": { "valid_patches": merged_patches, }, @@ -181,6 +184,27 @@ def postprocess_points( f, ) + +def smooth_points( + application: Application, + label: str, + merged_path: str, + smoothed_path: str, +) -> None: + """Smooth the Satlas point outputs. + + It applies Viterbi smoothing that takes into account merged outputs from previous + time ranges, and uploads the results. + + Args: + application: the application. + label: YYYY-MM representation of the time range used for this prediction run. + merged_path: folder to write merged predictions. The filename will be + YYYY-MM.geojson. + smoothed_path: folder to write smoothed predictions. The filename will be + YYYY-MM.geojson. + """ + merged_upath = UPath(merged_path) # Download the merged prediction history (ending with the one we just wrote) and # run smoothing. smoothed_upath = UPath(smoothed_path) diff --git a/rslp/satlas/publish.py b/rslp/satlas/publish.py new file mode 100644 index 00000000..ace959f1 --- /dev/null +++ b/rslp/satlas/publish.py @@ -0,0 +1,234 @@ +"""Publish Satlas outputs.""" + +import json +import os +import shutil +import subprocess # nosec +import tempfile +import zipfile +from typing import Any + +import boto3 +import boto3.s3 +from upath import UPath + +from rslp.log_utils import get_logger +from rslp.satlas.bkt import make_bkt + +from .predict_pipeline import Application + +logger = get_logger(__name__) + +# Number of timesteps to re-publish. +# Smoothing for points changes all of the outputs but we only upload outputs for this +# many of the most recent timesteps. +NUM_RECOMPUTE = 6 + +# Name on Cloudflare R2 for each application. +APP_NAME_ON_R2 = { + Application.MARINE_INFRA: "marine", +} + +APP_TIPPECANOE_LAYERS = { + Application.MARINE_INFRA: "marine", +} + +SHP_EXTENSIONS = [ + ".shp", + ".dbf", + ".prj", + ".shx", +] + +BKT_TILE_PATH = "output_mosaic/" + + +def get_cloudflare_r2_bucket() -> Any: + """Returns the Cloudflare R2 bucket where outputs are published.""" + s3 = boto3.resource( + "s3", + endpoint_url=os.environ["SATLAS_R2_ENDPOINT"], + aws_access_key_id=os.environ["SATLAS_R2_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["SATLAS_R2_SECRET_ACCESS_KEY"], + ) + bucket = s3.Bucket(os.environ["SATLAS_R2_BUCKET_NAME"]) + return bucket + + +def make_shapefile_zip(fname: str) -> str: + """Create zip file of the shapefile and its supporting files. + + If filename is "x" (for x.shp and supporting files) then output is "x.shp.zip". + + Args: + fname: fname without .shp extension + + Returns: + the local filename of the resulting zip file. + """ + zip_fname = fname + ".shp.zip" + basename = os.path.basename(fname) + with zipfile.ZipFile(zip_fname, "w") as z: + for ext in SHP_EXTENSIONS: + z.write(fname + ext, arcname=basename + ext) + return zip_fname + + +def update_index(bucket: Any, prefix: str) -> None: + """Update index file on Cloudflare R2. + + The index file just has list of filenames, last modified time, and md5. + + There is one index for each application folder. + + Args: + bucket: the Cloudflare R2 bucket. + prefix: the folder's prefix in the bucket. + """ + index_lines = [] + for obj in bucket.objects.filter(Prefix=prefix): + if obj.key.endswith("/index.txt"): + continue + line = "{},{},{}".format( + obj.key, obj.last_modified, obj.e_tag.split("-")[0].replace('"', "") + ) + index_lines.append(line) + index_lines.append("") + index_data = "\n".join(index_lines) + bucket.put_object( + Body=index_data.encode(), + Key=prefix + "index.txt", + ) + + +def publish_points( + application: Application, + smoothed_path: str, + version: str, + workers: int = 32, +) -> None: + """Publish Satlas point outputs. + + The points are added to two locations: GeoJSONs are added to Cloudflare R2, while + tippecanoe is used to generate vector tiles that are uploaded to GCS for use by the + satlas.allen.ai website. + + Args: + application: the application. + smoothed_path: folder containing smoothed predictions (including + history.geojson file). + version: current model version for use to distinguish different outputs on GCS. + workers: number of worker processes. + """ + smoothed_upath = UPath(smoothed_path) + + # First upload files to R2. + bucket = get_cloudflare_r2_bucket() + with tempfile.TemporaryDirectory() as tmp_dir: + # Upload history. + logger.info("upload history") + local_hist_fname = os.path.join(tmp_dir, "history.geojson") + with (smoothed_upath / "history.geojson").open("rb") as src: + with open(local_hist_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + app_name_on_r2 = APP_NAME_ON_R2[application] + bucket.upload_file(local_hist_fname, f"outputs/{app_name_on_r2}/marine.geojson") + + # Upload the latest outputs too. + available_fnames: list[UPath] = [] + for fname in smoothed_upath.iterdir(): + if fname.name == "history.geojson": + continue + available_fnames.append(fname) + available_fnames.sort(key=lambda fname: fname.name) + for fname in available_fnames[-NUM_RECOMPUTE:]: + logger.info("upload %s", str(fname)) + local_geojson_fname = os.path.join(tmp_dir, "data.geojson") + # local_shp_prefix = os.path.join(tmp_dir, "shp_data") + # local_kml_fname = os.path.join(tmp_dir, "data.kml") + + with fname.open("rb") as src: + with open(local_geojson_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + + """ + subprocess.check_call([ + 'ogr2ogr', + '-F', 'ESRI Shapefile', + '-nlt', 'POINT', + local_shp_prefix + ".shp", + local_geojson_fname, + ]) + make_shapefile_zip(local_shp_prefix) + subprocess.check_call([ + 'ogr2ogr', + '-F', 'KML', + local_kml_fname, + local_geojson_fname, + ]) + """ + + fname_prefix = fname.name.split(".")[0] + + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/{fname_prefix}.geojson", + ) + """ + bucket.upload_file( + local_shp_prefix + ".shp.zip", + f"outputs/{app_name_on_r2}/{fname_prefix}.shp.zip", + ) + bucket.upload_file( + local_kml_fname, + f"outputs/{app_name_on_r2}/{fname_prefix}.kml", + ) + """ + if fname == available_fnames[-1]: + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/latest.geojson", + ) + """ + bucket.upload_file( + local_shp_prefix + ".shp.zip", + f"outputs/{app_name_on_r2}/latest.shp.zip", + ) + bucket.upload_file( + local_kml_fname, + f"outputs/{app_name_on_r2}/latest.kml", + ) + """ + + update_index(bucket, f"outputs/{app_name_on_r2}/") + + # Generate the tippecanoe tiles. + # We set tippecanoe layer via property of each feature. + with tempfile.TemporaryDirectory() as tmp_dir: + tippecanoe_layer = APP_TIPPECANOE_LAYERS[application] + with (smoothed_upath / "history.geojson").open("rb") as f: + fc = json.load(f) + for feat in fc["features"]: + feat["tippecanoe"] = {"layer": tippecanoe_layer} + local_geojson_fname = os.path.join(tmp_dir, "history.geojson") + with open(local_geojson_fname, "w") as f: + json.dump(fc, f) + + local_tile_dir = os.path.join(tmp_dir, "tiles") + logger.info("run tippecanoe on history in local tmp dir %s", local_tile_dir) + subprocess.check_call( + [ + "tippecanoe", + "-z13", + "-r1", + "--cluster-densest-as-needed", + "--no-tile-compression", + "-e", + local_tile_dir, + local_geojson_fname, + ] + ) # nosec + + tile_dst_path = f"{BKT_TILE_PATH}{version}/history/{{zoom}}/0/0.bkt" + logger.info("make bkt at %s", tile_dst_path) + make_bkt(src_dir=local_tile_dir, dst_path=tile_dst_path) diff --git a/rslp/satlas/scripts/smooth_point_labels_viterbi.go b/rslp/satlas/scripts/smooth_point_labels_viterbi.go index 0c4dbb18..94cdd975 100644 --- a/rslp/satlas/scripts/smooth_point_labels_viterbi.go +++ b/rslp/satlas/scripts/smooth_point_labels_viterbi.go @@ -16,6 +16,11 @@ import ( const FUTURE_LABEL = "2030-01" const TILE_SIZE = 2048 +// Don't consider groups with fewer than this many valid timesteps. +// Note that the point doesn't need to be detected in all the timesteps, this is just +// timesteps where we have image coverage. +const MIN_VALID_TIMESTEPS = 8 + type Tile struct { Projection string Column int @@ -23,6 +28,7 @@ type Tile struct { } type Point struct { + Type string `json:"type"` Geometry struct { Type string `json:"type"` Coordinates [2]float64 `json:"coordinates"` @@ -96,6 +102,7 @@ func main() { outFname := flag.String("out", "", "Output filename with LABEL placeholder like out/LABEL.geojson") histFname := flag.String("hist", "", "Merged history output filename") distanceThreshold := flag.Float64("max_dist", 200, "Matching distance threshold in meters") + nmsDistance := flag.Float64("nms_dist", 200.0/111111, "NMS distance in degrees") numThreads := flag.Int("threads", 32, "Number of threads") flag.Parse() @@ -149,20 +156,33 @@ func main() { for groupIdx, group := range groups { projection := *group[0].Properties.Projection center := group.Center() - indices := gridIndexes[projection].Search(common.Rectangle{ - Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, - Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, - }) + + // Lookup candidate new points that could match this group using the grid index. + var indices []int + if gridIndexes[projection] != nil { + indices = gridIndexes[projection].Search(common.Rectangle{ + Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, + Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, + }) + } + var closestIdx int = -1 var closestDistance float64 for _, idx := range indices { if matchedIndices[idx] { continue } - if *group[0].Properties.Category != *curPoints[idx].Properties.Category { - continue - } + // Double check distance threshold since the index may still return + // points that are slightly outside the threshold. + // We used to check category too, but now we use the category of the + // last prediction, and just apply a distance penalty for mismatched + // category, since we noticed that sometimes there are partially + // constructed wind turbines detected as platforms but then later + // detected as turbines once construction is done, and we don't want + // that to mess up the Viterbi smoothing. Put another way, marine + // infrastructure should show up in our map even if we're not exactly + // sure about the category. dx := center[0] - *curPoints[idx].Properties.Column dy := center[1] - *curPoints[idx].Properties.Row distance := math.Sqrt(float64(dx*dx + dy*dy)) @@ -170,6 +190,11 @@ func main() { if distance > *distanceThreshold/MetersPerPixel { continue } + + if *group[0].Properties.Category != *curPoints[idx].Properties.Category { + distance += *distanceThreshold / MetersPerPixel + } + if closestIdx == -1 || distance < closestDistance { closestIdx = idx closestDistance = distance @@ -205,6 +230,51 @@ func main() { } } + // Apply non-maximal suppression over groups. + // We prefer longer groups, or if they are the same length, the group with higher + // last score. + log.Println("applying non-maximal suppression") + nmsIndex := common.NewGridIndex(*nmsDistance * 5) + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + nmsIndex.Insert(groupIdx, common.Point{coordinates[0], coordinates[1]}.Rectangle()) + } + var newGroups []Group + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + results := nmsIndex.Search(common.Point{coordinates[0], coordinates[1]}.RectangleTol(*nmsDistance)) + needsRemoval := false + for _, otherIdx := range results { + if otherIdx == groupIdx { + continue + } + other := groups[otherIdx] + otherLast := other[len(other)-1] + otherCoordinates := otherLast.Geometry.Coordinates + dx := coordinates[0] - otherCoordinates[0] + dy := coordinates[1] - otherCoordinates[1] + distance := math.Sqrt(float64(dx*dx + dy*dy)) + if distance >= *nmsDistance { + continue + } + + // It is within distance threshold, so see if group is worse than other. + if len(group) < len(other) { + needsRemoval = true + } else if len(group) == len(other) && *last.Properties.Score < *otherLast.Properties.Score { + needsRemoval = true + } + } + + if !needsRemoval { + newGroups = append(newGroups, group) + } + } + log.Printf("NMS filtered from %d to %d groups", len(groups), len(newGroups)) + groups = newGroups + // Apply Viterbi algorithm in each group. initialProbs := []float64{0.5, 0.5} transitionProbs := [][]float64{ @@ -318,6 +388,10 @@ func main() { validLabelSet[label] = true } + if len(validLabelSet) < MIN_VALID_TIMESTEPS { + continue + } + // Now make history of observations for Viterbi algorithm. // We only include timesteps where the tile was valid. // We also create a map from observed timesteps to original timestep index. @@ -363,6 +437,7 @@ func main() { for _, rng := range curRngs { last := rng.Group[len(rng.Group)-1] feat := Point{} + feat.Type = "Feature" feat.Geometry = last.Geometry feat.Properties.Category = last.Properties.Category feat.Properties.Score = last.Properties.Score From 7085a2fcd3e0c642aaa62ec0ee4b2a1a303199b0 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 12 Dec 2024 14:47:40 -0800 Subject: [PATCH 20/58] sync --- .../wind_turbine/config.json | 36 ++- data/satlas/marine_infra/config.yaml | 2 +- data/satlas/marine_infra/config_20241002.yaml | 210 +++++++++++++ data/satlas/marine_infra/config_20241030.yaml | 219 +++++++++++++ .../marine_infra/config_20241030_3image.yaml | 210 +++++++++++++ .../marine_infra/config_20241030_infer.yaml | 219 +++++++++++++ data/satlas/marine_infra/config_20241210.json | 111 +++++++ data/satlas/marine_infra/config_20241210.yaml | 219 +++++++++++++ data/satlas/wind_turbine/config.json | 73 +++++ data/satlas/wind_turbine/config.yaml | 238 +++++++++++++++ data/satlas/wind_turbine/config_azure.json | 98 ++++++ requirements.txt | 2 + rslp/common/__init__.py | 3 +- rslp/common/worker.py | 250 +++++++++++---- rslp/satlas/__init__.py | 3 +- rslp/satlas/data_sources.py | 288 +++++++++++++++--- rslp/satlas/job_launcher.py | 12 +- rslp/satlas/job_launcher_worker.py | 143 +++------ rslp/satlas/postprocess.py | 5 +- 19 files changed, 2115 insertions(+), 226 deletions(-) create mode 100644 data/satlas/marine_infra/config_20241002.yaml create mode 100644 data/satlas/marine_infra/config_20241030.yaml create mode 100644 data/satlas/marine_infra/config_20241030_3image.yaml create mode 100644 data/satlas/marine_infra/config_20241030_infer.yaml create mode 100644 data/satlas/marine_infra/config_20241210.json create mode 100644 data/satlas/marine_infra/config_20241210.yaml create mode 100644 data/satlas/wind_turbine/config.json create mode 100644 data/satlas/wind_turbine/config.yaml create mode 100644 data/satlas/wind_turbine/config_azure.json diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json index 047907a7..d68153c7 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json +++ b/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json @@ -21,7 +21,7 @@ "output": { "type": "vector" }, - "sentinel2": { + "sentinel2_a": { "band_sets": [ { "bands": [ @@ -68,7 +68,8 @@ }, "type": "raster" }, - "sentinel2.1": { + "sentinel2_b": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -101,9 +102,24 @@ "zoom_offset": -2 } ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-90d", + "use_rtree_index": false + }, "type": "raster" }, - "sentinel2.2": { + "sentinel2_c": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -136,6 +152,20 @@ "zoom_offset": -2 } ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-180d", + "use_rtree_index": false + }, "type": "raster" } }, diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml index 93d2f2f9..4be197a0 100644 --- a/data/satlas/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -216,4 +216,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_marine_infra -rslp_experiment: data_20241030_satlaspretrainold_patch512_00 +rslp_experiment: data_20241030_run_20241210_00 diff --git a/data/satlas/marine_infra/config_20241002.yaml b/data/satlas/marine_infra/config_20241002.yaml new file mode 100644 index 00000000..d63af8d5 --- /dev/null +++ b/data/satlas/marine_infra/config_20241002.yaml @@ -0,0 +1,210 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241002_run_20241210_00 diff --git a/data/satlas/marine_infra/config_20241030.yaml b/data/satlas/marine_infra/config_20241030.yaml new file mode 100644 index 00000000..9ebc4ab6 --- /dev/null +++ b/data/satlas/marine_infra/config_20241030.yaml @@ -0,0 +1,219 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2_a"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_d"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241030_run_20241210_00 diff --git a/data/satlas/marine_infra/config_20241030_3image.yaml b/data/satlas/marine_infra/config_20241030_3image.yaml new file mode 100644 index 00000000..507016a3 --- /dev/null +++ b/data/satlas/marine_infra/config_20241030_3image.yaml @@ -0,0 +1,210 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2_a"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + output_selector: image + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241030_run_20241210_3image_00 diff --git a/data/satlas/marine_infra/config_20241030_infer.yaml b/data/satlas/marine_infra/config_20241030_infer.yaml new file mode 100644 index 00000000..93d2f2f9 --- /dev/null +++ b/data/satlas/marine_infra/config_20241030_infer.yaml @@ -0,0 +1,219 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 4 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241030_satlaspretrainold_patch512_00 diff --git a/data/satlas/marine_infra/config_20241210.json b/data/satlas/marine_infra/config_20241210.json new file mode 100644 index 00000000..bbb765a5 --- /dev/null +++ b/data/satlas/marine_infra/config_20241210.json @@ -0,0 +1,111 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "format": { + "coordinate_mode": "pixel", + "name": "geojson" + }, + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 4 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + } + } + } +} diff --git a/data/satlas/marine_infra/config_20241210.yaml b/data/satlas/marine_infra/config_20241210.yaml new file mode 100644 index 00000000..90c50357 --- /dev/null +++ b/data/satlas/marine_infra/config_20241210.yaml @@ -0,0 +1,219 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241210_run_20241210_00 diff --git a/data/satlas/wind_turbine/config.json b/data/satlas/wind_turbine/config.json new file mode 100644 index 00000000..da90e09f --- /dev/null +++ b/data/satlas/wind_turbine/config.json @@ -0,0 +1,73 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 6 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + } +} diff --git a/data/satlas/wind_turbine/config.yaml b/data/satlas/wind_turbine/config.yaml new file mode 100644 index 00000000..28a0aead --- /dev/null +++ b/data/satlas/wind_turbine/config.yaml @@ -0,0 +1,238 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241002_satlaspretrainold_patch384_03 diff --git a/data/satlas/wind_turbine/config_azure.json b/data/satlas/wind_turbine/config_azure.json new file mode 100644 index 00000000..68df7aed --- /dev/null +++ b/data/satlas/wind_turbine/config_azure.json @@ -0,0 +1,98 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel1": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query_config": { + "max_matches": 6 + } + }, + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "max_cloud_cover": 50, + "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", + "query_config": { + "max_matches": 6 + }, + "sort_by": "eo:cloud_cover" + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "path_suffix": "gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/tiles" + } + } +} diff --git a/requirements.txt b/requirements.txt index 8169b9aa..f692edbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,12 @@ beaker-py>=1.32 fastapi>=0.115 google-cloud-bigtable>=2.18 +google-cloud-pubsub>=2.18 interrogate>=1.7 pydantic>=2.8 pytest>=8.2 python-dotenv>=1.0 ruff>=0.7 +scikit-image>=0.23 typing-extensions>=4.11 uvicorn>=0.32 diff --git a/rslp/common/__init__.py b/rslp/common/__init__.py index 6cbbd859..d3b0fdc9 100644 --- a/rslp/common/__init__.py +++ b/rslp/common/__init__.py @@ -1,7 +1,8 @@ """Pipelines common across projects.""" -from .worker import worker_pipeline +from .worker import launch_workers, worker_pipeline workflows = { "worker": worker_pipeline, + "launch": launch_workers, } diff --git a/rslp/common/worker.py b/rslp/common/worker.py index dbb36016..bb433842 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -1,13 +1,24 @@ """Worker to process jobs in a list of jobs.""" import json -import random +import threading +import time +import uuid from datetime import datetime, timedelta, timezone -from google.api_core.exceptions import PreconditionFailed -from google.cloud import storage -from upath import UPath - +import tqdm +from beaker import ( + Beaker, + Constraints, + DataMount, + DataSource, + ExperimentSpec, + Priority, + TaskResources, +) +from google.cloud import pubsub_v1, storage + +from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars from rslp.log_utils import get_logger from rslp.main import run_workflow @@ -52,70 +63,181 @@ def _get_pending_jobs( def worker_pipeline( - project: str, - workflow: str, - job_fname: str, - claim_bucket_name: str, - claim_dir: str, + project_id: str, + subscription_id: str, + retries: int = 3, + retry_sleep: int = 60, + idle_timeout: int = 10, ) -> None: - """Start a worker to run the specified jobs. + """Start a worker to run jobs from a Pub/Sub subscription. + + The job dict including rslp project, workflow, and arguments to pass must be + written to the topic. Args: - project: the project that the workflow to run is in. - workflow: the workflow to run. - job_fname: file containing the full list of jobs (arguments to the workflow - function) that need to be run. - claim_bucket_name: the GCS bucket to use for claiming jobs. - claim_dir: the path within claim_bucket_name to use for claiming jobs. + project_id: the Google Cloud project ID. + subscription_id: the Pub/Sub subscription ID. + retries: retry for this many consecutive errors before terminating. A "retry" + may run a different job than the one that originally caused failure. This + ensures workers will complete most of the jobs before they terminate due to + errors. + retry_sleep: sleep for this many seconds between retries. Sleeping helps in + case there is an error due to rate limiting. + idle_timeout: seconds before we terminate if there is no activity. """ - job_upath = UPath(job_fname) - client = storage.Client() - claim_bucket = client.bucket(claim_bucket_name) - - with job_upath.open("r") as f: - jobs: list[list[str]] = json.load(f) - - # Get the currently pending jobs. - # Our strategy will be to sample a job and attempt to claim it. - # And then if the claim fails then we refresh the pending jobs. - # This works for up to ~10000 jobs. - pending = _get_pending_jobs(jobs, claim_bucket, claim_dir) - - while len(pending) > 0: - job_idx = random.choice(pending) - pending.remove(job_idx) - pending_blob = claim_bucket.blob(f"{claim_dir}pending/{job_idx}") - completed_blob = claim_bucket.blob(f"{claim_dir}completed/{job_idx}") - - # Determine the generation of pending_blob so we can create a newer one if - # applicable. If it doesn't exist, we use 0 so that it will throw error if the - # file exists at all (the actual generation should never be 0). - pending_blob_generation = 0 - is_pending = False - if pending_blob.exists(): - pending_blob.reload() - pending_blob_generation = pending_blob.generation - if datetime.now(timezone.utc) - pending_blob.time_created < timedelta( - hours=MAX_JOB_HOURS - ): - is_pending = True - - if is_pending or completed_blob.exists(): - pending = _get_pending_jobs(jobs, claim_bucket, claim_dir) - continue - + subscriber = pubsub_v1.SubscriberClient() + subscription_path = subscriber.subscription_path(project_id, subscription_id) + + # Callback to run the workflow indicated in the message. + def process_message(message: pubsub_v1.subscriber.message.Message) -> None: + logger.debug("worker received message %s", message) + json_data = json.loads(message.data.decode()) + rslp_project = json_data["project"] + rslp_workflow = json_data["workflow"] + workflow_args = json_data["args"] + run_workflow(rslp_project, rslp_workflow, workflow_args) + + # Callback that wraps process_message to keep track of: + # 1. Whether a message is currently being processed. + # 2. The last time that a message finished processing. + # 3. The number of consecutive errors. If there is an error in process_message, it + # will sleep for retry_sleep unless it exceeds retries in which case we exit. + lock = threading.Lock() + is_processing = False + last_message_time = time.time() + consecutive_errors = 0 + + def callback(message: pubsub_v1.subscriber.message.Message) -> None: + nonlocal is_processing, last_message_time, consecutive_errors try: - # Use generation so that it throws error if generation doesn't match. - pending_blob.upload_from_string( - "", if_generation_match=pending_blob_generation + with lock: + is_processing = True + + process_message(message) + message.ack() + + with lock: + consecutive_errors = 0 + except Exception as e: + logger.error( + "encountered error while processing message %s: %s (%d/%d consecutive errors)", + message, + e, + consecutive_errors, + retries, ) - except PreconditionFailed: - # This means another worker claimed the job in between when we confirmed - # the blob doesn't exist already and when we tried to claim it. In this - # case we just try again. - continue + with lock: + consecutive_errors += 1 + time.sleep(retry_sleep) + # Pub/Sub will catch this error and print it so we just re-raise it here. + # But in our monitoring loop below we will check for more errors than + # retries and cancel the subscription if so. + raise + finally: + with lock: + is_processing = False + last_message_time = time.time() + + flow_control = pubsub_v1.types.FlowControl( + max_messages=1, + ) + streaming_pull_future = subscriber.subscribe( + subscription_path, callback=callback, flow_control=flow_control + ) + logger.info("worker listening for messages on %s", subscription_path) + try: + while True: + time.sleep(idle_timeout) + + with lock: + if consecutive_errors > retries: + logger.info( + "worker exiting due to %d consecutive errors", + consecutive_errors, + ) + break + + if is_processing: + logger.debug( + "worker continuing since a message is currently being processed" + ) + continue + + time_since_last_activity = time.time() - last_message_time + if time_since_last_activity < idle_timeout: + logger.debug( + "worker continuing since time since last activity %d is less than idle timeout %d", + time_since_last_activity, + idle_timeout, + ) + continue + + logger.info("worker exiting due to idle timeout") + break + finally: + streaming_pull_future.cancel() + streaming_pull_future.result() + + +def launch_worker(project_id: str, subscription_id: str) -> None: + """Launch a worker job. - logger.info("claimed job %d and running it now", job_idx) - run_workflow(project, workflow, jobs[job_idx]) + Args: + project_id: the Google Cloud project ID. + subscription_id: the Pub/Sub subscription ID. + """ - completed_blob.upload_from_string("") + +def launch_workers( + project_id: str, + subscription_id: str, + num_workers: int, + gpus: int = 0, + shared_memory: str | None = None, + priority: Priority = Priority.low, + cluster: list[str] = ["ai2/augusta-google-1"], +) -> None: + """Start workers for the prediction jobs. + + Args: + project_id: the Google Cloud project ID. + subscription_id: the Pub/Sub subscription ID. + num_workers: number of workers to launch + gpus: number of GPUs to request per worker. + shared_memory: shared memory string like "256GiB". + priority: priority to assign the Beaker jobs. + cluster: clusters to target. + """ + beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) + + with beaker.session(): + for _ in tqdm.tqdm(range(num_workers)): + env_vars = get_base_env_vars(use_weka_prefix=False) + + spec = ExperimentSpec.new( + budget=BUDGET, + description="worker", + beaker_image=IMAGE_NAME, + priority=priority, + command=["python", "-m", "rslp.main"], + arguments=[ + "common", + "worker", + project_id, + subscription_id, + ], + constraints=Constraints( + cluster=cluster, + ), + preemptible=True, + datasets=[ + DataMount( + source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec + mount_path="/etc/credentials/gcp_credentials.json", # nosec + ), + ], + env_vars=env_vars, + resources=TaskResources(gpu_count=gpus, shared_memory=shared_memory), + ) + unique_id = str(uuid.uuid4())[0:8] + beaker.experiment.create(f"worker_{unique_id}", spec) diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index f0439050..d595e960 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -7,7 +7,7 @@ - Tree cover """ -from .job_launcher_worker import launch_workers, write_jobs, write_jobs_for_year_months +from .job_launcher_worker import write_jobs, write_jobs_for_year_months from .postprocess import merge_points, smooth_points from .predict_pipeline import predict_multi, predict_pipeline from .publish import publish_points @@ -17,7 +17,6 @@ "predict_multi": predict_multi, "write_jobs": write_jobs, "write_jobs_for_year_months": write_jobs_for_year_months, - "launch_workers": launch_workers, "merge_points": merge_points, "smooth_points": smooth_points, "publish_points": publish_points, diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py index fb8cdcd8..9285631c 100644 --- a/rslp/satlas/data_sources.py +++ b/rslp/satlas/data_sources.py @@ -3,16 +3,61 @@ from datetime import timedelta from typing import Any +import shapely from rslearn.config import QueryConfig, RasterLayerConfig, SpaceMode from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources.azure_sentinel1 import Sentinel1 +from rslearn.data_sources.azure_sentinel2 import Sentinel2 as AzureSentinel2 from rslearn.data_sources.data_source import DataSource, Item -from rslearn.data_sources.gcp_public_data import Sentinel2, Sentinel2Item +from rslearn.data_sources.gcp_public_data import Sentinel2 as GcpSentinel2 from rslearn.data_sources.utils import match_candidate_items_to_window from rslearn.tile_stores import TileStore from rslearn.utils.geometry import STGeometry from upath import UPath +def _find_monthly_matches( + geometry: STGeometry, item_list: list[Item], period_days: int, max_matches: int +) -> list[list[Item]]: + # Find matches across the periods. + # For each period, we create an STGeometry with modified time range + # matching the period, and obtain matching mosaic. + # We start from the end of the time range because we care more about recent + # periods and so we want to make sure that they align correctly with the + # end. + cur_groups: list[list[Item]] = [] + period_end = geometry.time_range[1] + while period_end > geometry.time_range[0] and len(cur_groups) < max_matches: + period_time_range = ( + period_end - timedelta(days=period_days), + period_end, + ) + period_end -= timedelta(period_days) + period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range) + + # We modify the QueryConfig here since caller should be asking for + # multiple mosaics, but we just want one mosaic per period. + period_groups = match_candidate_items_to_window( + period_geom, + item_list, + QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1), + ) + + # There should be zero on one groups depending on whether there were + # any items that matched. We keep the group if it is there. + if len(period_groups) == 0 or len(period_groups[0]) == 0: + # No matches for this period. + continue + cur_groups.append(period_groups[0]) + + # If there are not enough matching mosaics, then we eliminate all the + # matches since we aren't going to use this window then anyway. + if len(cur_groups) < max_matches: + return [] + + return cur_groups + + class MonthlySentinel2(DataSource): """Sentinel2 data source where each match is a mosaic from a different month. @@ -25,7 +70,7 @@ class MonthlySentinel2(DataSource): def __init__( self, - sentinel2: Sentinel2, + sentinel2: GcpSentinel2, max_cloud_cover: float | None = None, period_days: int = 30, ): @@ -44,7 +89,7 @@ def __init__( @staticmethod def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel2": """Creates a new MonthlySentinel2 instance from a configuration dictionary.""" - sentinel2 = Sentinel2.from_config(config, ds_path) + sentinel2 = GcpSentinel2.from_config(config, ds_path) kwargs = {} d = config.data_source.config_dict for k in ["max_cloud_cover", "period_days"]: @@ -53,13 +98,13 @@ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel2" kwargs[k] = d[k] return MonthlySentinel2(sentinel2, **kwargs) - def deserialize_item(self, serialized_item: Any) -> Sentinel2Item: + def deserialize_item(self, serialized_item: Any) -> Item: """Deserializes an item from JSON-decoded data.""" return self.sentinel2.deserialize_item(serialized_item) def get_items( self, geometries: list[STGeometry], query_config: QueryConfig - ) -> list[list[list[Sentinel2Item]]]: + ) -> list[list[list[Item]]]: """Get a list of items in the data source intersecting the given geometries. Args: @@ -95,47 +140,109 @@ def get_items( if item.cloud_cover <= self.max_cloud_cover ] - # Find matches across the periods. - # For each period, we create an STGeometry with modified time range - # matching the period, and obtain matching mosaic. - # We start from the end of the time range because we care more about recent - # periods and so we want to make sure that they align correctly with the - # end. - cur_groups: list[Item] = [] - period_end = geometry.time_range[1] - while ( - period_end > geometry.time_range[0] - and len(cur_groups) < query_config.max_matches - ): - period_time_range = ( - period_end - timedelta(days=self.period_days), - period_end, - ) - period_end -= timedelta(self.period_days) - period_geom = STGeometry( - geometry.projection, geometry.shp, period_time_range - ) + cur_groups = _find_monthly_matches( + geometry=geometry, + item_list=item_list, + period_days=self.period_days, + max_matches=query_config.max_matches, + ) + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. + + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel2.ingest(tile_store, items, geometries) - # We modify the QueryConfig here since caller should be asking for - # multiple mosaics, but we just want one mosaic per period. - period_groups = match_candidate_items_to_window( - period_geom, - item_list, - QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1), - ) - # There should be zero on one groups depending on whether there were - # any items that matched. We keep the group if it is there. - if len(period_groups) == 0 or len(period_groups[0]) == 0: - # No matches for this period. - continue - cur_groups.append(period_groups[0]) +class MonthlyAzureSentinel2(DataSource): + """Similar to MonthlySentinel2 but for Sentinel-2 L2A on Azure.""" - # If there are not enough matching mosaics, then we eliminate all the - # matches since we aren't going to use this window then anyway. - if len(cur_groups) < query_config.max_matches: - cur_groups = [] + def __init__( + self, + sentinel2: AzureSentinel2, + period_days: int = 30, + ): + """Create a new MonthlyAzureSentinel2. + Args: + sentinel2: the Sentinel2 data source to wrap. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel2 = sentinel2 + self.period_days = period_days + + @staticmethod + def from_config( + config: RasterLayerConfig, ds_path: UPath + ) -> "MonthlyAzureSentinel2": + """Creates a new MonthlyAzureSentinel2 instance from a configuration dictionary.""" + sentinel2 = AzureSentinel2.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlyAzureSentinel2(sentinel2, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel2.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + groups = [] + for geometry in geometries: + # This part is the same as in base Sentinel2 class. + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + result = self.sentinel2.client.search( + collections=[self.sentinel2.COLLECTION_NAME], + intersects=shapely.to_geojson(wgs84_geometry.shp), + datetime=wgs84_geometry.time_range, + query=self.sentinel2.query, + ) + stac_items = [item for item in result.item_collection()] + + if self.sentinel2.sort_by is not None: + stac_items.sort( + key=lambda stac_item: stac_item.properties[self.sentinel2.sort_by], + reverse=not self.sentinel2.sort_ascending, + ) + + candidate_items = [ + self.sentinel2._stac_item_to_item(stac_item) for stac_item in stac_items + ] + + # Now we use _find_monthly_matches. + cur_groups = _find_monthly_matches( + geometry, candidate_items, self.period_days, query_config.max_matches + ) groups.append(cur_groups) return groups @@ -143,7 +250,7 @@ def get_items( def ingest( self, tile_store: TileStore, - items: list[Sentinel2Item], + items: list[Item], geometries: list[list[STGeometry]], ) -> None: """Ingest items into the given tile store. @@ -154,3 +261,98 @@ def ingest( geometries: a list of geometries needed for each item """ self.sentinel2.ingest(tile_store, items, geometries) + + +class MonthlySentinel1(DataSource): + """Similar to MonthlySentinel2 but for Sentinel-1 on Azure.""" + + def __init__( + self, + sentinel1: Sentinel1, + period_days: int = 30, + ): + """Create a new MonthlySentinel1. + + Args: + sentinel1: the Sentinel1 data source to wrap. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel1 = sentinel1 + self.period_days = period_days + + @staticmethod + def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel1": + """Creates a new MonthlySentinel1 instance from a configuration dictionary.""" + sentinel1 = Sentinel1.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlySentinel1(sentinel1, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel1.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + groups = [] + for geometry in geometries: + # This part is the same as in base Sentinel1 class. + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + result = self.sentinel1.client.search( + collections=[self.sentinel1.COLLECTION_NAME], + intersects=shapely.to_geojson(wgs84_geometry.shp), + datetime=wgs84_geometry.time_range, + query=self.sentinel1.query, + ) + stac_items = [item for item in result.item_collection()] + + if self.sentinel1.sort_by is not None: + stac_items.sort( + key=lambda stac_item: stac_item.properties[self.sentinel1.sort_by], + reverse=not self.sentinel1.sort_ascending, + ) + + candidate_items = [ + self.sentinel1._stac_item_to_item(stac_item) for stac_item in stac_items + ] + + # Now we use _find_monthly_matches. + cur_groups = _find_monthly_matches( + geometry, candidate_items, self.period_days, query_config.max_matches + ) + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. + + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel1.ingest(tile_store, items, geometries) diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py index 3e8fd5b8..b1d8b332 100644 --- a/rslp/satlas/job_launcher.py +++ b/rslp/satlas/job_launcher.py @@ -31,10 +31,10 @@ RESOLUTION = 10 # Days to add before a provided date. -DAYS_BEFORE = 120 +DEFAULT_DAYS_BEFORE = 120 # Days to add after a provided date. -DAYS_AFTER = 90 +DEFAULT_DAYS_AFTER = 90 class Task: @@ -283,6 +283,8 @@ def launch_jobs_for_year_month( out_path: str, batch_size: int = 1, count: int | None = None, + days_before: int = DEFAULT_DAYS_BEFORE, + days_after: int = DEFAULT_DAYS_AFTER, ) -> None: """Launch Satlas prediction jobs on Beaker for the given year and month. @@ -293,11 +295,13 @@ def launch_jobs_for_year_month( out_path: the output path with year and month placeholders. batch_size: the batch size. count: only run up to this many tasks. + days_before: how much to pad windows before the year/month. + days_after: how much to pad windows after the year/month. """ ts = datetime(year, month, 1, tzinfo=timezone.utc) time_range = ( - ts - timedelta(days=DAYS_BEFORE), - ts + timedelta(days=DAYS_AFTER), + ts - timedelta(days=days_before), + ts + timedelta(days=days_after), ) cur_out_path = out_path.format(year=year, month=month) print(f"launching jobs with time_range={time_range} and out_path={cur_out_path}") diff --git a/rslp/satlas/job_launcher_worker.py b/rslp/satlas/job_launcher_worker.py index 2ded9d9e..fc39f373 100644 --- a/rslp/satlas/job_launcher_worker.py +++ b/rslp/satlas/job_launcher_worker.py @@ -1,28 +1,16 @@ """Launch Satlas prediction jobs on Beaker.""" import json -import uuid from datetime import datetime, timedelta, timezone import shapely import tqdm -from beaker import ( - Beaker, - Constraints, - DataMount, - DataSource, - EnvVar, - ExperimentSpec, - Priority, - TaskResources, -) +from google.cloud import pubsub_v1 from rasterio.crs import CRS from rslearn.const import WGS84_PROJECTION from rslearn.utils.geometry import PixelBounds, Projection, STGeometry from rslearn.utils.get_utm_ups_crs import get_proj_bounds -from upath import UPath -from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars from rslp.log_utils import get_logger from .predict_pipeline import Application, PredictTaskArgs @@ -33,10 +21,10 @@ RESOLUTION = 10 # Days to add before a provided date. -DAYS_BEFORE = 120 +DEFAULT_DAYS_BEFORE = 120 # Days to add after a provided date. -DAYS_AFTER = 90 +DEFAULT_DAYS_AFTER = 90 class Task: @@ -66,76 +54,6 @@ def __init__( self.out_path = out_path -class WorkerParams: - """Parameters that worker pipeline needs to know.""" - - def __init__(self, job_fname: str, claim_bucket_name: str, claim_dir: str) -> None: - """Create a new WorkerParams. - - Args: - job_fname: the filename containing list of jobs. - claim_bucket_name: the bucket where workers will claim jobs. - claim_dir: the path in the bucket to write claim files. - """ - self.job_fname = job_fname - self.claim_bucket_name = claim_bucket_name - self.claim_dir = claim_dir - - -def launch_worker(worker_params: WorkerParams) -> None: - """Launch a worker job. - - Args: - worker_params: the parameters to pass to the worker. - """ - beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) - - with beaker.session(): - env_vars = get_base_env_vars(use_weka_prefix=False) - env_vars.append( - EnvVar( - name="RSLEARN_LOGLEVEL", - value="DEBUG", - ) - ) - - spec = ExperimentSpec.new( - budget=BUDGET, - description="worker", - beaker_image=IMAGE_NAME, - priority=Priority.low, - command=["python", "-m", "rslp.main"], - arguments=[ - "common", - "worker", - "satlas", - "predict_multi", - worker_params.job_fname, - worker_params.claim_bucket_name, - worker_params.claim_dir, - ], - constraints=Constraints( - cluster=[ - "ai2/jupiter-cirrascale-2", - "ai2/neptune-cirrascale", - "ai2/saturn-cirrascale", - "ai2/augusta-google-1", - ] - ), - preemptible=True, - datasets=[ - DataMount( - source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec - mount_path="/etc/credentials/gcp_credentials.json", # nosec - ), - ], - env_vars=env_vars, - resources=TaskResources(gpu_count=1, shared_memory="256GiB"), - ) - unique_id = str(uuid.uuid4())[0:8] - beaker.experiment.create(f"worker_{unique_id}", spec) - - def get_jobs( application: Application, time_range: tuple[datetime, datetime], @@ -252,11 +170,29 @@ def get_jobs( return jobs +def _write_jobs_to_topic( + jobs: list[list[str]], + project_id: str, + topic_id: str, +) -> None: + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(project_id, topic_id) + for job in tqdm.tqdm(jobs, desc="Writing jobs to Pub/Sub topic"): + json_data = dict( + project="satlas", + workflow="predict_multi", + args=job, + ) + data = json.dumps(json_data).encode() + publisher.publish(topic_path, data).result() + + def write_jobs( application: Application, time_range: tuple[datetime, datetime], out_path: str, - job_fname: str, + project_id: str, + topic_id: str, epsg_code: int | None = None, wgs84_bounds: tuple[float, float, float, float] | None = None, batch_size: int = 1, @@ -267,7 +203,8 @@ def write_jobs( application: which application to run. time_range: the time range to run within. Must have timezone. out_path: the output path. It should be specific to the time range. - job_fname: where to write the list of jobs for workers to read. + project_id: project containing the Pub/Sub topic. + topic_id: the Pub/Sub topic to write the jobs to. epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default run in all UTM zones. wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. @@ -281,16 +218,18 @@ def write_jobs( wgs84_bounds=wgs84_bounds, batch_size=batch_size, ) - with UPath(job_fname).open("w") as f: - json.dump(jobs, f) + _write_jobs_to_topic(jobs, project_id, topic_id) def write_jobs_for_year_months( year_months: list[tuple[int, int]], application: Application, out_path: str, - job_fname: str, + project_id: str, + topic_id: str, batch_size: int = 1, + days_before: int = DEFAULT_DAYS_BEFORE, + days_after: int = DEFAULT_DAYS_AFTER, ) -> None: """Write Satlas prediction jobs for the given year and month. @@ -298,16 +237,19 @@ def write_jobs_for_year_months( year_months: list of year-month pairs. application: the application to run. out_path: the output path with year and month placeholders. - job_fname: where to write the list of jobs for workers to read. + project_id: project containing the Pub/Sub topic. + topic_id: the Pub/Sub topic to write the jobs to. worker_params: the worker parameters. batch_size: the batch size. + days_before: how much to pad windows before the year/month. + days_after: how much to pad windows after the year/month. """ jobs = [] for year, month in year_months: ts = datetime(year, month, 1, tzinfo=timezone.utc) time_range = ( - ts - timedelta(days=DAYS_BEFORE), - ts + timedelta(days=DAYS_AFTER), + ts - timedelta(days=days_before), + ts + timedelta(days=days_after), ) cur_out_path = out_path.format(year=year, month=month) logger.info( @@ -323,17 +265,4 @@ def write_jobs_for_year_months( jobs.extend(cur_jobs) logger.info("got a total of %d jobs across year-months", len(jobs)) - with UPath(job_fname).open("w") as f: - json.dump(jobs, f) - - -def launch_workers(worker_params: WorkerParams, num_workers: int) -> None: - """Start workers for the prediction jobs. - - Args: - worker_params: the parameters for the workers, including job file where the - list of jobs has been written. - num_workers: number of workers to launch - """ - for _ in tqdm.tqdm(range(num_workers)): - launch_worker(worker_params) + _write_jobs_to_topic(jobs, project_id, topic_id) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 2d673682..f7986c0c 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -31,7 +31,10 @@ Application.MARINE_INFRA: { "platform": "offshore_platform", "turbine": "offshore_wind_turbine", - } + }, + Application.WIND_TURBINE: { + "turbine": "wind_turbine", + }, } logger = get_logger(__name__) From cf57d21f8c55acd632d73ed492db44859118d009 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:10:17 -0800 Subject: [PATCH 21/58] sync --- data/satlas/marine_infra/config.yaml | 8 ++--- data/satlas/wind_turbine/config_azure.json | 5 +-- rslp/common/worker.py | 9 +++++- rslp/satlas/README.md | 10 +++++- rslp/satlas/data_sources.py | 37 +++++++++++++++++++++- rslp/satlas/job_launcher_worker.py | 13 ++++++++ 6 files changed, 73 insertions(+), 9 deletions(-) diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml index 4be197a0..90c50357 100644 --- a/data/satlas/marine_infra/config.yaml +++ b/data/satlas/marine_infra/config.yaml @@ -39,7 +39,7 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ inputs: image1: data_type: "raster" @@ -99,7 +99,7 @@ data: input_mapping: detect: targets: "targets" - batch_size: 4 + batch_size: 8 num_workers: 32 default_config: transforms: @@ -202,7 +202,7 @@ trainer: logging_interval: "epoch" - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ output_layer: output selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -216,4 +216,4 @@ trainer: module_selector: ["model", "encoder", 0, "encoder", "model"] unfreeze_at_epoch: 2 rslp_project: satlas_marine_infra -rslp_experiment: data_20241030_run_20241210_00 +rslp_experiment: data_20241210_run_20241210_00 diff --git a/data/satlas/wind_turbine/config_azure.json b/data/satlas/wind_turbine/config_azure.json index 68df7aed..3dbc34bd 100644 --- a/data/satlas/wind_turbine/config_azure.json +++ b/data/satlas/wind_turbine/config_azure.json @@ -32,6 +32,7 @@ } ], "data_source": { + "ingest": false, "name": "rslp.satlas.data_sources.MonthlySentinel1", "query_config": { "max_matches": 6 @@ -74,6 +75,7 @@ ], "data_source": { "harmonize": true, + "ingest": false, "max_cloud_cover": 50, "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", "query_config": { @@ -91,8 +93,7 @@ "compress": "zstd", "predictor": 2, "zstd_level": 1 - }, - "path_suffix": "gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/tiles" + } } } } diff --git a/rslp/common/worker.py b/rslp/common/worker.py index bb433842..7c37c163 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -4,6 +4,7 @@ import threading import time import uuid +from concurrent import futures from datetime import datetime, timedelta, timezone import tqdm @@ -140,9 +141,15 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: flow_control = pubsub_v1.types.FlowControl( max_messages=1, + max_lease_duration=24 * 3600, ) + executor = futures.ThreadPoolExecutor(max_workers=5) + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) streaming_pull_future = subscriber.subscribe( - subscription_path, callback=callback, flow_control=flow_control + subscription_path, + callback=callback, + flow_control=flow_control, + scheduler=scheduler, ) logger.info("worker listening for messages on %s", subscription_path) try: diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md index 2c5c881f..57b773c0 100644 --- a/rslp/satlas/README.md +++ b/rslp/satlas/README.md @@ -2,6 +2,14 @@ Inference: - PYTHONPATH=~/rslearn:. python -m rslp.main satlas launch MARINE_INFRA '["2024-01-01T00:00:00+00:00", "2024-04-01T00:00:00+00:00"]' gs://rslearn-eai/projects/satlas/marine_infra/version-20241030/2024-01/ + python -m rslp.main satlas write_jobs_for_year_months '[[2024, 7]]' MARINE_INFRA 'gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen + +Post-processing: + +## Wind Turbine + +Inference: + + python -m rslp.main satlas write_jobs_for_year_months '[[2024, 1]]' WIND_TURBINE 'gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen --days_before 90 --days_after 181 Post-processing: diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py index 9285631c..e4f43bf8 100644 --- a/rslp/satlas/data_sources.py +++ b/rslp/satlas/data_sources.py @@ -4,13 +4,14 @@ from typing import Any import shapely -from rslearn.config import QueryConfig, RasterLayerConfig, SpaceMode +from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig, SpaceMode from rslearn.const import WGS84_PROJECTION from rslearn.data_sources.azure_sentinel1 import Sentinel1 from rslearn.data_sources.azure_sentinel2 import Sentinel2 as AzureSentinel2 from rslearn.data_sources.data_source import DataSource, Item from rslearn.data_sources.gcp_public_data import Sentinel2 as GcpSentinel2 from rslearn.data_sources.utils import match_candidate_items_to_window +from rslearn.dataset import Window from rslearn.tile_stores import TileStore from rslearn.utils.geometry import STGeometry from upath import UPath @@ -262,6 +263,23 @@ def ingest( """ self.sentinel2.ingest(tile_store, items, geometries) + def materialize( + self, + window: Window, + item_groups: list[list[Item]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + """Materialize data for the window. + + Args: + window: the window to materialize + item_groups: the items from get_items + layer_name: the name of this layer + layer_cfg: the config of this layer + """ + self.sentinel2.materialize(window, item_groups, layer_name, layer_cfg) + class MonthlySentinel1(DataSource): """Similar to MonthlySentinel2 but for Sentinel-1 on Azure.""" @@ -356,3 +374,20 @@ def ingest( geometries: a list of geometries needed for each item """ self.sentinel1.ingest(tile_store, items, geometries) + + def materialize( + self, + window: Window, + item_groups: list[list[Item]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + """Materialize data for the window. + + Args: + window: the window to materialize + item_groups: the items from get_items + layer_name: the name of this layer + layer_cfg: the config of this layer + """ + self.sentinel1.materialize(window, item_groups, layer_name, layer_cfg) diff --git a/rslp/satlas/job_launcher_worker.py b/rslp/satlas/job_launcher_worker.py index fc39f373..ed3aa3a8 100644 --- a/rslp/satlas/job_launcher_worker.py +++ b/rslp/satlas/job_launcher_worker.py @@ -1,6 +1,7 @@ """Launch Satlas prediction jobs on Beaker.""" import json +import random from datetime import datetime, timedelta, timezone import shapely @@ -61,6 +62,7 @@ def get_jobs( epsg_code: int | None = None, wgs84_bounds: tuple[float, float, float, float] | None = None, batch_size: int = 1, + count: int | None = None, ) -> list[list[str]]: """Get batches of tasks for Satlas prediction. @@ -72,6 +74,7 @@ def get_jobs( run in all UTM zones. wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. batch_size: how many tasks to run in each batch. + count: limit to this many tasks. Returns: the list of worker tasks where each worker task @@ -143,6 +146,10 @@ def get_jobs( print(f"Got {len(tasks)} total tasks") + if count is not None and len(tasks) > count: + tasks = random.sample(tasks, count) + logger.info("Randomly sampled %d tasks", len(tasks)) + jobs = [] for i in range(0, len(tasks), batch_size): cur_tasks = tasks[i : i + batch_size] @@ -196,6 +203,7 @@ def write_jobs( epsg_code: int | None = None, wgs84_bounds: tuple[float, float, float, float] | None = None, batch_size: int = 1, + count: int | None = None, ) -> None: """Write jobs for the specified application and time range. @@ -209,6 +217,7 @@ def write_jobs( run in all UTM zones. wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. batch_size: how many tasks to run in each batch. + count: limit to this many tasks. """ jobs = get_jobs( application, @@ -217,6 +226,7 @@ def write_jobs( epsg_code=epsg_code, wgs84_bounds=wgs84_bounds, batch_size=batch_size, + count=count, ) _write_jobs_to_topic(jobs, project_id, topic_id) @@ -230,6 +240,7 @@ def write_jobs_for_year_months( batch_size: int = 1, days_before: int = DEFAULT_DAYS_BEFORE, days_after: int = DEFAULT_DAYS_AFTER, + count: int | None = None, ) -> None: """Write Satlas prediction jobs for the given year and month. @@ -243,6 +254,7 @@ def write_jobs_for_year_months( batch_size: the batch size. days_before: how much to pad windows before the year/month. days_after: how much to pad windows after the year/month. + count: limit each year-month to this many tasks. """ jobs = [] for year, month in year_months: @@ -260,6 +272,7 @@ def write_jobs_for_year_months( time_range=time_range, out_path=cur_out_path, batch_size=batch_size, + count=count, ) logger.info("got %d jobs for %04d-%02d", len(cur_jobs), year, month) jobs.extend(cur_jobs) From c193308ecf628fe4a465330197fcc86e20118bd2 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:13:19 -0800 Subject: [PATCH 22/58] remove unneeded marine infra model configs --- data/satlas/marine_infra/config_20241002.yaml | 210 ----------------- data/satlas/marine_infra/config_20241030.yaml | 219 ------------------ .../marine_infra/config_20241030_3image.yaml | 210 ----------------- .../marine_infra/config_20241030_infer.yaml | 219 ------------------ data/satlas/marine_infra/config_20241210.json | 111 --------- data/satlas/marine_infra/config_20241210.yaml | 219 ------------------ 6 files changed, 1188 deletions(-) delete mode 100644 data/satlas/marine_infra/config_20241002.yaml delete mode 100644 data/satlas/marine_infra/config_20241030.yaml delete mode 100644 data/satlas/marine_infra/config_20241030_3image.yaml delete mode 100644 data/satlas/marine_infra/config_20241030_infer.yaml delete mode 100644 data/satlas/marine_infra/config_20241210.json delete mode 100644 data/satlas/marine_infra/config_20241210.yaml diff --git a/data/satlas/marine_infra/config_20241002.yaml b/data/satlas/marine_infra/config_20241002.yaml deleted file mode 100644 index d63af8d5..00000000 --- a/data/satlas/marine_infra/config_20241002.yaml +++ /dev/null @@ -1,210 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 3 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: FLOAT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslp.satlas.train.MarineInfraTask - init_args: - property_name: "category" - classes: ["unknown", "platform", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] - skip_unknown_categories: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - tags: - split: train - val_config: - patch_size: 512 - tags: - split: val - test_config: - patch_size: 512 - tags: - split: val - predict_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_marine_infra -rslp_experiment: data_20241002_run_20241210_00 diff --git a/data/satlas/marine_infra/config_20241030.yaml b/data/satlas/marine_infra/config_20241030.yaml deleted file mode 100644 index 9ebc4ab6..00000000 --- a/data/satlas/marine_infra/config_20241030.yaml +++ /dev/null @@ -1,219 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 3 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2_a"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2_b"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2_c"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2_d"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: FLOAT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslp.satlas.train.MarineInfraTask - init_args: - property_name: "category" - classes: ["unknown", "platform", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] - skip_unknown_categories: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - tags: - split: train - val_config: - patch_size: 512 - tags: - split: val - test_config: - patch_size: 512 - tags: - split: val - predict_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_marine_infra -rslp_experiment: data_20241030_run_20241210_00 diff --git a/data/satlas/marine_infra/config_20241030_3image.yaml b/data/satlas/marine_infra/config_20241030_3image.yaml deleted file mode 100644 index 507016a3..00000000 --- a/data/satlas/marine_infra/config_20241030_3image.yaml +++ /dev/null @@ -1,210 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 3 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2_a"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2_b"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2_c"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: FLOAT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslp.satlas.train.MarineInfraTask - init_args: - property_name: "category" - classes: ["unknown", "platform", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] - skip_unknown_categories: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - tags: - split: train - val_config: - patch_size: 512 - tags: - split: val - test_config: - patch_size: 512 - tags: - split: val - predict_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - output_selector: image - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_marine_infra -rslp_experiment: data_20241030_run_20241210_3image_00 diff --git a/data/satlas/marine_infra/config_20241030_infer.yaml b/data/satlas/marine_infra/config_20241030_infer.yaml deleted file mode 100644 index 93d2f2f9..00000000 --- a/data/satlas/marine_infra/config_20241030_infer.yaml +++ /dev/null @@ -1,219 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 3 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: FLOAT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslp.satlas.train.MarineInfraTask - init_args: - property_name: "category" - classes: ["unknown", "platform", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] - skip_unknown_categories: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 4 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - tags: - split: train - val_config: - patch_size: 512 - tags: - split: val - test_config: - patch_size: 512 - tags: - split: val - predict_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241030/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_marine_infra -rslp_experiment: data_20241030_satlaspretrainold_patch512_00 diff --git a/data/satlas/marine_infra/config_20241210.json b/data/satlas/marine_infra/config_20241210.json deleted file mode 100644 index bbb765a5..00000000 --- a/data/satlas/marine_infra/config_20241210.json +++ /dev/null @@ -1,111 +0,0 @@ -{ - "layers": { - "label": { - "type": "vector" - }, - "mask": { - "band_sets": [ - { - "bands": [ - "mask" - ], - "dtype": "uint8", - "format": { - "format": "png", - "name": "single_image" - } - } - ], - "type": "raster" - }, - "output": { - "format": { - "coordinate_mode": "pixel", - "name": "geojson" - }, - "type": "vector" - }, - "sentinel2": { - "band_sets": [ - { - "bands": [ - "B02", - "B03", - "B04", - "B08" - ], - "dtype": "uint16", - "format": { - "geotiff_options": { - "compress": "zstd", - "predictor": 2, - "zstd_level": 1 - }, - "name": "geotiff" - } - }, - { - "bands": [ - "B05", - "B06", - "B07", - "B8A", - "B11", - "B12" - ], - "dtype": "uint16", - "format": { - "geotiff_options": { - "compress": "zstd", - "predictor": 2, - "zstd_level": 1 - }, - "name": "geotiff" - }, - "zoom_offset": -1 - }, - { - "bands": [ - "B01", - "B09", - "B10" - ], - "dtype": "uint16", - "format": { - "geotiff_options": { - "compress": "zstd", - "predictor": 2, - "zstd_level": 1 - }, - "name": "geotiff" - }, - "zoom_offset": -2 - } - ], - "data_source": { - "harmonize": true, - "index_cache_dir": "cache/sentinel2", - "max_cloud_cover": 50, - "max_time_delta": "0d", - "modality": "L1C", - "name": "rslp.satlas.data_sources.MonthlySentinel2", - "query_config": { - "max_matches": 4 - }, - "sort_by": "cloud_cover", - "use_rtree_index": false - }, - "type": "raster" - } - }, - "tile_store": { - "class_path": "rslearn.tile_stores.default.DefaultTileStore", - "init_args": { - "geotiff_options": { - "compress": "zstd", - "predictor": 2, - "zstd_level": 1 - } - } - } -} diff --git a/data/satlas/marine_infra/config_20241210.yaml b/data/satlas/marine_infra/config_20241210.yaml deleted file mode 100644 index 90c50357..00000000 --- a/data/satlas/marine_infra/config_20241210.yaml +++ /dev/null @@ -1,219 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 3 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: FLOAT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslp.satlas.train.MarineInfraTask - init_args: - property_name: "category" - classes: ["unknown", "platform", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] - skip_unknown_categories: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - tags: - split: train - val_config: - patch_size: 512 - tags: - split: val - test_config: - patch_size: 512 - tags: - split: val - predict_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - output_selector: image - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_marine_infra -rslp_experiment: data_20241210_run_20241210_00 From cba94844334a43a1debf530cb571f6346b9ac912 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:13:34 -0800 Subject: [PATCH 23/58] clarify that convert_satlas_webmercator_to_rslearn is intended to be one-off --- convert_satlas_webmercator_to_rslearn/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_satlas_webmercator_to_rslearn/README.md b/convert_satlas_webmercator_to_rslearn/README.md index 7258f227..18a9ffb5 100644 --- a/convert_satlas_webmercator_to_rslearn/README.md +++ b/convert_satlas_webmercator_to_rslearn/README.md @@ -4,6 +4,9 @@ About This project is for converting the various Satlas application training datasets (wind turbines, solar farms, marine infrastructure) to rslearn format. +TODO: this project should be moved to one_off_projects in the near-future since the +conversion is complete and this code does not need to be maintained anymore. + Wind Turbines ------------- From 9f94514843f443d1ef6331c12a4bfd6f2ab2db17 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:15:23 -0800 Subject: [PATCH 24/58] add readme for wind turbine configs --- data/satlas/wind_turbine/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 data/satlas/wind_turbine/README.md diff --git a/data/satlas/wind_turbine/README.md b/data/satlas/wind_turbine/README.md new file mode 100644 index 00000000..f2d4700d --- /dev/null +++ b/data/satlas/wind_turbine/README.md @@ -0,0 +1,8 @@ +`config.json` and `config.yaml` contain the active dataset and model configurations for +the wind turbine point detection model. + +It uses six monthly Sentinel-2 L1C mosaics, which can be spread over up to nine months +(in case some months don't have enough cloud-free images to form a mosaic). + +`config_azure.json` is for testing a model that inputs Sentinel-2 L2A + Sentinel-1 +vv+vh sourced from Microsoft Azure. From 0f1a7db2e1a09d1f36fb29aa9fac366490090928 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:22:11 -0800 Subject: [PATCH 25/58] add readme for worker system --- rslp/common/README.md | 27 ++ rslp/satlas/__init__.py | 2 +- rslp/satlas/job_launcher.py | 314 ------------------ .../{job_launcher_worker.py => write_jobs.py} | 0 4 files changed, 28 insertions(+), 315 deletions(-) create mode 100644 rslp/common/README.md delete mode 100644 rslp/satlas/job_launcher.py rename rslp/satlas/{job_launcher_worker.py => write_jobs.py} (100%) diff --git a/rslp/common/README.md b/rslp/common/README.md new file mode 100644 index 00000000..9c0be6aa --- /dev/null +++ b/rslp/common/README.md @@ -0,0 +1,27 @@ +This contains infrastructure intended to be shared across several rslp projects. + + +Worker +------ + +`worker.py` provides a system for launching Beaker jobs that run workers to execute +tasks from a queue. + +Each task specifies an rslp project, pipeline (workflow), and arguments to pass to the +pipeline. The queue is implemented with Google Cloud Pub/Sub. + +First create the topic and subscription via CLI: + + gcloud pubsub topics create --project skylight-proto-1 rslp-job-queue-YOURNAME + gcloud pubsub subscriptions create --project skylight-proto-1 rslp-job-queue-YOURNAME-sub --topic rslp-job-queue-YOURNAME + +You will then need to write code that writes tasks to the topic. +See `satlas/write_jobs.py` for an example of this. + +Then you can launch the worker. To test on one machine: + + python -m rslp.main common worker skylight-proto-1 rslp-job-queue-YOURNAME-sub + +And to launch 100 workers on Beaker: + + python -m rslp.main common launch skylight-proto-1 rslp-job-queue-YOURNAME-sub 100 --gpus 1 --shared_memory 256GiB diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py index d595e960..2bb65635 100644 --- a/rslp/satlas/__init__.py +++ b/rslp/satlas/__init__.py @@ -7,10 +7,10 @@ - Tree cover """ -from .job_launcher_worker import write_jobs, write_jobs_for_year_months from .postprocess import merge_points, smooth_points from .predict_pipeline import predict_multi, predict_pipeline from .publish import publish_points +from .write_jobs import write_jobs, write_jobs_for_year_months workflows = { "predict": predict_pipeline, diff --git a/rslp/satlas/job_launcher.py b/rslp/satlas/job_launcher.py deleted file mode 100644 index b1d8b332..00000000 --- a/rslp/satlas/job_launcher.py +++ /dev/null @@ -1,314 +0,0 @@ -"""Launch Satlas prediction jobs on Beaker.""" - -import json -import multiprocessing -import random -import uuid -from datetime import datetime, timedelta, timezone - -import shapely -import tqdm -from beaker import ( - Beaker, - Constraints, - DataMount, - DataSource, - EnvVar, - ExperimentSpec, - Priority, - TaskResources, -) -from rasterio.crs import CRS -from rslearn.const import WGS84_PROJECTION -from rslearn.utils.geometry import PixelBounds, Projection, STGeometry -from rslearn.utils.get_utm_ups_crs import get_proj_bounds - -from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars - -from .predict_pipeline import Application, PredictTaskArgs, get_output_fname - -TILE_SIZE = 32768 -RESOLUTION = 10 - -# Days to add before a provided date. -DEFAULT_DAYS_BEFORE = 120 - -# Days to add after a provided date. -DEFAULT_DAYS_AFTER = 90 - - -class Task: - """Represents a task that will correspond to one Beaker job.""" - - def __init__( - self, - application: Application, - projection: Projection, - bounds: PixelBounds, - time_range: tuple[datetime, datetime], - out_path: str, - ) -> None: - """Create a new Task. - - Args: - application: the application to run - projection: the projection of the tile - bounds: the bounds of the tile - time_range: the time range to process - out_path: where to write outputs - """ - self.application = application - self.projection = projection - self.bounds = bounds - self.time_range = time_range - self.out_path = out_path - - -def launch_job(batch: list[Task]) -> None: - """Launch job for this task. - - Args: - batch: list of Task objects for which to create a job. - """ - beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) - - # Convert tasks to PredictTask. - # These just set projection/bounds/time range, so the application and output path - # come from the first task. - predict_tasks = [] - for task in batch: - predict_tasks.append( - PredictTaskArgs( - projection_json=task.projection.serialize(), - bounds=task.bounds, - time_range=task.time_range, - ) - ) - - with beaker.session(): - env_vars = get_base_env_vars(use_weka_prefix=False) - env_vars.append( - EnvVar( - name="RSLEARN_LOGLEVEL", - value="DEBUG", - ) - ) - - # Name the job based on the first task. - task = batch[0] - experiment_name = ( - f"satlas_{task.application.value}_{task.projection.crs.to_epsg()}_" - + f"{task.bounds[0]}_{task.bounds[1]}" - ) - - spec = ExperimentSpec.new( - budget=BUDGET, - description=experiment_name, - beaker_image=IMAGE_NAME, - priority=Priority.low, - command=["python", "-m", "rslp.main"], - arguments=[ - "satlas", - "predict_multi", - task.application.value.upper(), - task.out_path, - "/tmp/scratch/", - json.dumps( - [predict_task.serialize() for predict_task in predict_tasks] - ), - ], - constraints=Constraints( - cluster=[ - "ai2/jupiter-cirrascale-2", - "ai2/neptune-cirrascale", - "ai2/saturn-cirrascale", - "ai2/augusta-google-1", - # "ai2/prior-cirrascale", - # "ai2/prior-elanding", - ] - ), - preemptible=True, - datasets=[ - DataMount( - source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec - mount_path="/etc/credentials/gcp_credentials.json", # nosec - ), - ], - env_vars=env_vars, - resources=TaskResources(gpu_count=1, shared_memory="256GiB"), - ) - unique_id = str(uuid.uuid4())[0:8] - beaker.experiment.create(experiment_name + "_" + unique_id, spec) - - -def check_task_done(task: Task) -> tuple[Task, bool]: - """Checks whether this task is done processing already. - - It is determined based on existence of output file for the task. - - Args: - task: the task. - - Returns: - whether the task was completed - """ - out_fname = get_output_fname( - task.application, task.out_path, task.projection, task.bounds - ) - return task, out_fname.exists() - - -def launch_jobs( - application: Application, - time_range: tuple[datetime, datetime], - out_path: str, - epsg_code: int | None = None, - wgs84_bounds: tuple[float, float, float, float] | None = None, - count: int | None = None, - batch_size: int = 1, -) -> None: - """Launch Beaker jobs for Satlas prediction. - - Args: - application: which application to run. - time_range: the time range to run within. Must have timezone. - out_path: the output path. It should be specific to the time range. - epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default - run in all UTM zones. - wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. - count: only run up to this many tasks. - batch_size: how many tasks to run in each Beaker job. - """ - # Generate tasks. - if epsg_code: - utm_zones = [CRS.from_epsg(epsg_code)] - else: - utm_zones = [] - for epsg_code in range(32601, 32661): - utm_zones.append(CRS.from_epsg(epsg_code)) - for epsg_code in range(32701, 32761): - utm_zones.append(CRS.from_epsg(epsg_code)) - - tasks: list[Task] = [] - for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"): - # get_proj_bounds returns bounds in CRS units so we need to convert to pixel - # units. - crs_bbox = STGeometry( - Projection(utm_zone, 1, 1), - shapely.box(*get_proj_bounds(utm_zone)), - None, - ) - projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) - pixel_bbox = crs_bbox.to_projection(projection) - zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) - - user_bounds_in_proj: PixelBounds | None = None - if wgs84_bounds is not None: - dst_geom = STGeometry( - WGS84_PROJECTION, shapely.box(*wgs84_bounds), None - ).to_projection(projection) - user_bounds_in_proj = ( - int(dst_geom.shp.bounds[0]), - int(dst_geom.shp.bounds[1]), - int(dst_geom.shp.bounds[2]), - int(dst_geom.shp.bounds[3]), - ) - - for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): - for row in range( - zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1 - ): - if user_bounds_in_proj is not None: - # Check if this task intersects the bounds specified by the user. - if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: - continue - if col * TILE_SIZE >= user_bounds_in_proj[2]: - continue - if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: - continue - if row * TILE_SIZE >= user_bounds_in_proj[3]: - continue - - tasks.append( - Task( - application=application, - projection=projection, - bounds=( - col * TILE_SIZE, - row * TILE_SIZE, - (col + 1) * TILE_SIZE, - (row + 1) * TILE_SIZE, - ), - time_range=time_range, - out_path=out_path, - ) - ) - - # See which tasks are not done yet. - p = multiprocessing.Pool(32) - outputs = p.imap_unordered(check_task_done, tasks) - - pending_tasks: list[Task] = [] - for task, is_done in tqdm.tqdm( - outputs, desc="Check which tasks are completed", total=len(tasks) - ): - if is_done: - continue - pending_tasks.append(task) - - p.close() - - # Run up to count of them. - if count is not None and len(pending_tasks) > count: - run_tasks = random.sample(pending_tasks, count) - else: - run_tasks = pending_tasks - - print( - f"Got {len(tasks)} total tasks, {len(pending_tasks)} pending, running {len(run_tasks)} of them" - ) - - batches = [] - for i in range(0, len(run_tasks), batch_size): - batches.append(run_tasks[i : i + batch_size]) - - for batch in tqdm.tqdm(batches, desc="Starting Beaker jobs"): - launch_job(batch) - - -def launch_jobs_for_year_month( - year: int, - month: int, - application: Application, - out_path: str, - batch_size: int = 1, - count: int | None = None, - days_before: int = DEFAULT_DAYS_BEFORE, - days_after: int = DEFAULT_DAYS_AFTER, -) -> None: - """Launch Satlas prediction jobs on Beaker for the given year and month. - - Args: - year: the year. - month: the month. - application: the application to run. - out_path: the output path with year and month placeholders. - batch_size: the batch size. - count: only run up to this many tasks. - days_before: how much to pad windows before the year/month. - days_after: how much to pad windows after the year/month. - """ - ts = datetime(year, month, 1, tzinfo=timezone.utc) - time_range = ( - ts - timedelta(days=days_before), - ts + timedelta(days=days_after), - ) - cur_out_path = out_path.format(year=year, month=month) - print(f"launching jobs with time_range={time_range} and out_path={cur_out_path}") - launch_jobs( - application=application, - time_range=time_range, - out_path=cur_out_path, - batch_size=batch_size, - count=count, - ) diff --git a/rslp/satlas/job_launcher_worker.py b/rslp/satlas/write_jobs.py similarity index 100% rename from rslp/satlas/job_launcher_worker.py rename to rslp/satlas/write_jobs.py From 31421ad72eddf7959b648b47574ac05d65f745e7 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 19 Dec 2024 11:27:52 -0800 Subject: [PATCH 26/58] update readme --- rslp/satlas/README.md | 21 +++++++++++++++++++++ rslp/satlas/publish.py | 39 --------------------------------------- 2 files changed, 21 insertions(+), 39 deletions(-) diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md index 57b773c0..2291454c 100644 --- a/rslp/satlas/README.md +++ b/rslp/satlas/README.md @@ -1,15 +1,36 @@ +This contains training, inference, and post-processing pipelines for the models served +at https://satlas.allen.ai/. + ## Marine Infrastructure +Training: + + python -m rslp.rslearn_main model fit --config data/satlas/marine_infra/config.yaml + Inference: python -m rslp.main satlas write_jobs_for_year_months '[[2024, 7]]' MARINE_INFRA 'gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen Post-processing: + python -m rslp.main satlas merge_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/2024-07/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ + python -m rslp.main satlas smooth_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ + python -m rslp.main satlas publish_points MARINE_INFRA gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ 'marine-default-cluster@v4' + ## Wind Turbine +Training: + + python -m rslp.rslearn_main model fit --config data/satlas/wind_turbine/config.yaml + Inference: python -m rslp.main satlas write_jobs_for_year_months '[[2024, 1]]' WIND_TURBINE 'gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen --days_before 90 --days_after 181 Post-processing: + + python -m rslp.main satlas merge_points WIND_TURBINE 2024-01 gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/2024-01/ gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/merged/ + python -m rslp.main satlas smooth_points WIND_TURBINE 2024-01 gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/merged/ gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/smoothed/ + +Publishing for wind turbine is not supported yet since it needs to be combined with the +detected solar farms and published as "renewable energy" GeoJSON. diff --git a/rslp/satlas/publish.py b/rslp/satlas/publish.py index ace959f1..12049f77 100644 --- a/rslp/satlas/publish.py +++ b/rslp/satlas/publish.py @@ -144,61 +144,22 @@ def publish_points( for fname in available_fnames[-NUM_RECOMPUTE:]: logger.info("upload %s", str(fname)) local_geojson_fname = os.path.join(tmp_dir, "data.geojson") - # local_shp_prefix = os.path.join(tmp_dir, "shp_data") - # local_kml_fname = os.path.join(tmp_dir, "data.kml") with fname.open("rb") as src: with open(local_geojson_fname, "wb") as dst: shutil.copyfileobj(src, dst) - """ - subprocess.check_call([ - 'ogr2ogr', - '-F', 'ESRI Shapefile', - '-nlt', 'POINT', - local_shp_prefix + ".shp", - local_geojson_fname, - ]) - make_shapefile_zip(local_shp_prefix) - subprocess.check_call([ - 'ogr2ogr', - '-F', 'KML', - local_kml_fname, - local_geojson_fname, - ]) - """ - fname_prefix = fname.name.split(".")[0] bucket.upload_file( local_geojson_fname, f"outputs/{app_name_on_r2}/{fname_prefix}.geojson", ) - """ - bucket.upload_file( - local_shp_prefix + ".shp.zip", - f"outputs/{app_name_on_r2}/{fname_prefix}.shp.zip", - ) - bucket.upload_file( - local_kml_fname, - f"outputs/{app_name_on_r2}/{fname_prefix}.kml", - ) - """ if fname == available_fnames[-1]: bucket.upload_file( local_geojson_fname, f"outputs/{app_name_on_r2}/latest.geojson", ) - """ - bucket.upload_file( - local_shp_prefix + ".shp.zip", - f"outputs/{app_name_on_r2}/latest.shp.zip", - ) - bucket.upload_file( - local_kml_fname, - f"outputs/{app_name_on_r2}/latest.kml", - ) - """ update_index(bucket, f"outputs/{app_name_on_r2}/") From cd07cfe74858a39c59c502a6513fc8f71cb2a797 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 7 Jan 2025 15:32:35 -0800 Subject: [PATCH 27/58] add azure configs --- data/satlas/wind_turbine/config_azure.yaml | 271 ++++++++++++++++++ .../wind_turbine/config_debug_azure1.yaml | 234 +++++++++++++++ .../wind_turbine/config_debug_azure2.yaml | 238 +++++++++++++++ .../wind_turbine/config_debug_azure3.yaml | 215 ++++++++++++++ .../wind_turbine/config_debug_azure4.yaml | 267 +++++++++++++++++ .../wind_turbine/config_debug_azure5.yaml | 267 +++++++++++++++++ .../wind_turbine/config_debug_azure6.yaml | 266 +++++++++++++++++ rslp/launch_beaker.py | 4 +- rslp/satlas/train.py | 46 +++ 9 files changed, 1807 insertions(+), 1 deletion(-) create mode 100644 data/satlas/wind_turbine/config_azure.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure1.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure2.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure3.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure4.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure5.yaml create mode 100644 data/satlas/wind_turbine/config_debug_azure6.yaml diff --git a/data/satlas/wind_turbine/config_azure.yaml b/data/satlas/wind_turbine/config_azure.yaml new file mode 100644 index 00000000..25eb5c68 --- /dev/null +++ b/data/satlas/wind_turbine/config_azure.yaml @@ -0,0 +1,271 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 11 + output_layers: [1, 3, 5, 7] + image_channels: 11 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] + ignore_prefixes: + - "backbone.backbone.backbone.features.0.0." +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + sar1: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar2: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar3: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar4: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar5: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar6: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_satlaspretrainold_patch384_00 diff --git a/data/satlas/wind_turbine/config_debug_azure1.yaml b/data/satlas/wind_turbine/config_debug_azure1.yaml new file mode 100644 index 00000000..87c51b97 --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure1.yaml @@ -0,0 +1,234 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2_a"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2_a.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241002_check_imagenet_performance diff --git a/data/satlas/wind_turbine/config_debug_azure2.yaml b/data/satlas/wind_turbine/config_debug_azure2.yaml new file mode 100644 index 00000000..1728312f --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure2.yaml @@ -0,0 +1,238 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2_a"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2_a.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241002_recheck_satlaspretrain_performance diff --git a/data/satlas/wind_turbine/config_debug_azure3.yaml b/data/satlas/wind_turbine/config_debug_azure3.yaml new file mode 100644 index 00000000..e07473b6 --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure3.yaml @@ -0,0 +1,215 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_only_use_sentinel2 diff --git a/data/satlas/wind_turbine/config_debug_azure4.yaml b/data/satlas/wind_turbine/config_debug_azure4.yaml new file mode 100644 index 00000000..3465dc14 --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure4.yaml @@ -0,0 +1,267 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 11 + output_layers: [1, 3, 5, 7] + image_channels: 11 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] + ignore_prefixes: + - "backbone.backbone.backbone.features.0.0." +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + sar1: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar2: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar3: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar4: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar5: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar6: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_input_both_but_for_sentinel2_use_old_norm diff --git a/data/satlas/wind_turbine/config_debug_azure5.yaml b/data/satlas/wind_turbine/config_debug_azure5.yaml new file mode 100644 index 00000000..bee15abd --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure5.yaml @@ -0,0 +1,267 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 11 + output_layers: [1, 3, 5, 7] + image_channels: 11 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] + ignore_prefixes: + - "backbone.backbone.backbone.features.0.0." +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + sar1: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar2: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar3: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar4: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar5: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar6: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_no_freezing diff --git a/data/satlas/wind_turbine/config_debug_azure6.yaml b/data/satlas/wind_turbine/config_debug_azure6.yaml new file mode 100644 index 00000000..cd8df782 --- /dev/null +++ b/data/satlas/wind_turbine/config_debug_azure6.yaml @@ -0,0 +1,266 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslp.satlas.train.TeePipe + init_args: + encoders: + - class_path: rslearn.models.satlaspretrain.SatlasPretrain + init_args: + model_identifier: "Sentinel2_SwinB_MI_MS" + - class_path: rslearn.models.satlaspretrain.SatlasPretrain + init_args: + model_identifier: "Sentinel1_SwinB_MI" + channels: [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11]] + image_channels: 11 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [256, 512, 1024, 2048] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + sar1: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar2: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar3: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar4: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar5: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar6: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_dual_encoder diff --git a/rslp/launch_beaker.py b/rslp/launch_beaker.py index 9c505aad..ef10e282 100644 --- a/rslp/launch_beaker.py +++ b/rslp/launch_beaker.py @@ -163,7 +163,9 @@ def launch_job( config_path, "--autoresume=true", ], - constraints=Constraints(cluster=["ai2/jupiter-cirrascale-2"]), + constraints=Constraints( + cluster=["ai2/jupiter-cirrascale-2", "ai2/augusta-google-1"] + ), preemptible=True, datasets=[ DataMount( diff --git a/rslp/satlas/train.py b/rslp/satlas/train.py index 21c7212a..f6e237ec 100644 --- a/rslp/satlas/train.py +++ b/rslp/satlas/train.py @@ -46,3 +46,49 @@ def process_inputs( feat.properties[self.property_name] = CATEGORY_MAPPING[category] return super().process_inputs(raw_inputs, metadata, load_targets) + + +class TeePipe(torch.nn.Module): + """TeePipe passes different channels of the input image to different backbones. + + The features from the different backbones are then concatenated and returned. + """ + + def __init__( + self, + encoders: list[torch.nn.Module], + channels: list[list[int]], + ): + """Create a new TeePipe. + + Args: + encoders: the encoders to apply. + channels: the subset of channels that each encoder should input. For + example, if the input is ABCDEF and first encoder should see ABC while + second should see DEF, then the list should be [[0, 1, 2], [3, 4, 5]]. + """ + self.encoders = encoders + self.channels = channels + + def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]: + """Compute features. + + Args: + inputs: input dicts that must include "image" key containing the images to + process. + """ + # index in feature map -> encoder index -> feature map + all_features: list[list[torch.Tensor]] | None = None + + for encoder, cur_channels in zip(self.encoders, self.channels): + cur_features = encoder( + [{"image": inp["image"][cur_channels, :, :]} for inp in inputs] + ) + if all_features is None: + all_features = [[] for _ in cur_features] + for idx, feat_map in enumerate(cur_features): + all_features[idx].append(feat_map) + + # Final feature map should concatenate at each scale. + assert all_features is not None + return [torch.cat(feat_map_list, dim=1) for feat_map_list in all_features] From f916d5cc977e39ec5fdaf4f08d0a7e443df2707b Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 08:29:38 -0800 Subject: [PATCH 28/58] upgrade solar farm ingestion --- .../lib/__init__.py | 50 ++++------- .../solar_farm/convert_siv_labels.py | 25 ++++-- data/satlas/solar_farm/config.json | 88 +++++++++++++++++++ .../wind_turbine/config_debug_azure6.yaml | 4 +- rslp/satlas/train.py | 31 +++++-- 5 files changed, 150 insertions(+), 48 deletions(-) create mode 100644 data/satlas/solar_farm/config.json diff --git a/convert_satlas_webmercator_to_rslearn/lib/__init__.py b/convert_satlas_webmercator_to_rslearn/lib/__init__.py index 17e18a83..85cac7ba 100644 --- a/convert_satlas_webmercator_to_rslearn/lib/__init__.py +++ b/convert_satlas_webmercator_to_rslearn/lib/__init__.py @@ -14,7 +14,11 @@ from rasterio.crs import CRS from rslearn.const import WGS84_PROJECTION from rslearn.dataset import Window -from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.geometry import Projection, STGeometry +from rslearn.utils.get_utm_ups_crs import get_utm_ups_crs +from rslearn.utils.feature import Feature +from rslearn.utils.vector_format import GeojsonVectorFormat +from rslearn.utils.raster_format import SingleImageRasterFormat from upath import UPath src_crs = CRS.from_epsg(3857) @@ -30,7 +34,7 @@ def convert_window( time_range: tuple[datetime, datetime], dst_pixel_size: float = 10, window_name: str | None = None, -): +) -> Window: """Create an rslearn window from a multisat window with the specified properties. Args: @@ -81,7 +85,7 @@ def convert_window( int(dst_polygon.bounds[2]), int(dst_polygon.bounds[3]), ] - window_root = root_dir / "windows" / group / window_name + window_root = Window.get_window_root(root_dir, group, window_name) window = Window( path=window_root, group=group, @@ -93,7 +97,7 @@ def convert_window( window.save() # (2) Write the turbine positions. - features = [] + features: list[Feature] = [] for shp, properties in labels: # Similar to with bounds, subtract the WebMercator pixel offset between # multisat and rslearn. @@ -101,26 +105,12 @@ def convert_window( src_geom = STGeometry(src_projection, shp, None) dst_geom = src_geom.to_projection(dst_projection) - features.append( - { - "type": "Feature", - "geometry": json.loads(shapely.to_geojson(dst_geom.shp)), - "properties": properties, - } - ) - layer_dir = window_root / "layers" / "label" - label_fname = layer_dir / "data.geojson" - layer_dir.mkdir(parents=True, exist_ok=True) - with label_fname.open("w") as f: - json.dump( - { - "type": "FeatureCollection", - "features": features, - "properties": dst_projection.serialize(), - }, - f, - ) - (layer_dir / "completed").touch() + features.append(Feature(dst_geom, properties)) + + layer_name = "label" + layer_dir = window.get_layer_dir(layer_name) + GeojsonVectorFormat().encode_vector(layer_dir, dst_projection, features) + window.mark_layer_completed(layer_name) # (3) Write mask corresponding to old window projected onto new window. mask = np.zeros((bounds[3] - bounds[1], bounds[2] - bounds[0]), dtype=np.uint8) @@ -130,13 +120,9 @@ def convert_window( polygon_cols = [coord[0] - bounds[0] for coord in dst_polygon.exterior.coords] rr, cc = skimage.draw.polygon(polygon_rows, polygon_cols, shape=mask.shape) mask[rr, cc] = 255 - layer_dir = window_root / "layers" / "mask" - mask_fname = layer_dir / "mask" / "image.png" - mask_fname.parent.mkdir(parents=True, exist_ok=True) - with mask_fname.open("wb") as f: - Image.fromarray(mask).save(f) - with (mask_fname.parent / "bounds.json").open("w") as f: - json.dump(bounds, f) - (layer_dir / "completed").touch() + layer_name = "mask" + layer_dir = window.get_layer_dir(layer_name) + SingleImageRasterFormat().encode_raster(layer_dir, dst_projection, bounds, mask[None, :, :]) + window.mark_layer_completed(layer_name) return window diff --git a/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py b/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py index b449bccc..4bd480f2 100644 --- a/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py +++ b/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py @@ -11,6 +11,8 @@ import shapely from PIL import Image +from rslearn.utils.vector_format import GeojsonVectorFormat +from rslearn.utils.raster_format import SingleImageRasterFormat from ..lib import convert_window db_path = "/home/ubuntu/siv_renewable/data/siv.sqlite3" @@ -36,7 +38,7 @@ ts = ts.replace(tzinfo=timezone.utc) time_range = ( ts - timedelta(days=120), - ts + timedelta(days=30), + ts + timedelta(days=60), ) db.execute( @@ -64,15 +66,20 @@ ) # Create raster version of the label. + layer_name = "label" + layer_dir = window.get_layer_dir(layer_name) + features = GeojsonVectorFormat().decode_vector(layer_dir, bounds) shapes = [] - with window.file_api.open("layers/label/data.geojson", "r") as f: - for feat in json.load(f)["features"]: - geometry = feat["geometry"] - assert geometry["type"] == "Polygon" - geometry["coordinates"] = ( - np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] - ).tolist() - shapes.append((geometry, 255)) + for feat in features: + geometry = feat.geometry + + + geometry = feat["geometry"] + assert geometry["type"] == "Polygon" + geometry["coordinates"] = ( + np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] + ).tolist() + shapes.append((geometry, 255)) if shapes: mask = rasterio.features.rasterize( shapes, diff --git a/data/satlas/solar_farm/config.json b/data/satlas/solar_farm/config.json new file mode 100644 index 00000000..cd0ba082 --- /dev/null +++ b/data/satlas/solar_farm/config.json @@ -0,0 +1,88 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 4 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + } +} diff --git a/data/satlas/wind_turbine/config_debug_azure6.yaml b/data/satlas/wind_turbine/config_debug_azure6.yaml index cd8df782..6be7a533 100644 --- a/data/satlas/wind_turbine/config_debug_azure6.yaml +++ b/data/satlas/wind_turbine/config_debug_azure6.yaml @@ -17,7 +17,7 @@ model: - class_path: rslearn.models.satlaspretrain.SatlasPretrain init_args: model_identifier: "Sentinel1_SwinB_MI" - channels: [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11]] + channels: [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10]] image_channels: 11 - class_path: rslearn.models.fpn.Fpn init_args: @@ -147,7 +147,7 @@ data: input_mapping: detect: targets: "targets" - batch_size: 8 + batch_size: 4 num_workers: 32 default_config: transforms: diff --git a/rslp/satlas/train.py b/rslp/satlas/train.py index f6e237ec..77229455 100644 --- a/rslp/satlas/train.py +++ b/rslp/satlas/train.py @@ -67,9 +67,31 @@ def __init__( example, if the input is ABCDEF and first encoder should see ABC while second should see DEF, then the list should be [[0, 1, 2], [3, 4, 5]]. """ - self.encoders = encoders + super().__init__() + self.encoders = torch.nn.ModuleList(encoders) self.channels = channels + def get_backbone_channels(self) -> list: + """Returns the output channels of this model when used as a backbone. + + Returns: + the output channels of the backbone as a list of (downsample_factor, depth) + tuples. + """ + # We assume that each encoder outputs features at matching resolutions. + out_channels = self.encoders[0].get_backbone_channels() + + for encoder in self.encoders[1:]: + cur_channels = encoder.get_backbone_channels() + for idx, (downsample_factor, depth) in enumerate(cur_channels): + if out_channels[idx][0] != downsample_factor: + raise ValueError( + "encoders have mis-matching resolutions of output feature maps" + ) + out_channels[idx][1] += depth + + return out_channels + def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]: """Compute features. @@ -78,17 +100,16 @@ def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]: process. """ # index in feature map -> encoder index -> feature map - all_features: list[list[torch.Tensor]] | None = None + all_features: list[list[torch.Tensor]] = [ + [] for _ in self.get_backbone_channels() + ] for encoder, cur_channels in zip(self.encoders, self.channels): cur_features = encoder( [{"image": inp["image"][cur_channels, :, :]} for inp in inputs] ) - if all_features is None: - all_features = [[] for _ in cur_features] for idx, feat_map in enumerate(cur_features): all_features[idx].append(feat_map) # Final feature map should concatenate at each scale. - assert all_features is not None return [torch.cat(feat_map_list, dim=1) for feat_map_list in all_features] From 463fa945b5be73f26543e780df2646d942f77887 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 08:30:43 -0800 Subject: [PATCH 29/58] remove debug configs --- .../wind_turbine/config_debug_azure1.yaml | 234 --------------- .../wind_turbine/config_debug_azure2.yaml | 238 ---------------- .../wind_turbine/config_debug_azure3.yaml | 215 -------------- .../wind_turbine/config_debug_azure4.yaml | 267 ------------------ .../wind_turbine/config_debug_azure5.yaml | 267 ------------------ .../wind_turbine/config_debug_azure6.yaml | 266 ----------------- 6 files changed, 1487 deletions(-) delete mode 100644 data/satlas/wind_turbine/config_debug_azure1.yaml delete mode 100644 data/satlas/wind_turbine/config_debug_azure2.yaml delete mode 100644 data/satlas/wind_turbine/config_debug_azure3.yaml delete mode 100644 data/satlas/wind_turbine/config_debug_azure4.yaml delete mode 100644 data/satlas/wind_turbine/config_debug_azure5.yaml delete mode 100644 data/satlas/wind_turbine/config_debug_azure6.yaml diff --git a/data/satlas/wind_turbine/config_debug_azure1.yaml b/data/satlas/wind_turbine/config_debug_azure1.yaml deleted file mode 100644 index 87c51b97..00000000 --- a/data/satlas/wind_turbine/config_debug_azure1.yaml +++ /dev/null @@ -1,234 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2_a"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2_a.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2_b"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2_b.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2_c"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2_c.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241002_check_imagenet_performance diff --git a/data/satlas/wind_turbine/config_debug_azure2.yaml b/data/satlas/wind_turbine/config_debug_azure2.yaml deleted file mode 100644 index 1728312f..00000000 --- a/data/satlas/wind_turbine/config_debug_azure2.yaml +++ /dev/null @@ -1,238 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2_a"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2_a.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2_b"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2_b.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2_c"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2_c.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - bands: [0, 1, 2] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 8160 - valid_range: [0, 1] - bands: [3, 4, 5, 6, 7, 8] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241002_recheck_satlaspretrain_performance diff --git a/data/satlas/wind_turbine/config_debug_azure3.yaml b/data/satlas/wind_turbine/config_debug_azure3.yaml deleted file mode 100644 index e07473b6..00000000 --- a/data/satlas/wind_turbine/config_debug_azure3.yaml +++ /dev/null @@ -1,215 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 9 - output_layers: [1, 3, 5, 7] - image_channels: 9 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2.4"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2.5"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - image_bands: [2, 1, 0] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - image2: [] - image3: [] - image4: [] - image5: [] - image6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max - - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze - init_args: - module_selector: ["model", "encoder", 0, "encoder", "model"] - unfreeze_at_epoch: 2 -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241212_only_use_sentinel2 diff --git a/data/satlas/wind_turbine/config_debug_azure4.yaml b/data/satlas/wind_turbine/config_debug_azure4.yaml deleted file mode 100644 index 3465dc14..00000000 --- a/data/satlas/wind_turbine/config_debug_azure4.yaml +++ /dev/null @@ -1,267 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 11 - output_layers: [1, 3, 5, 7] - image_channels: 11 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] - ignore_prefixes: - - "backbone.backbone.backbone.features.0.0." -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2.4"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2.5"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - sar1: - data_type: "raster" - layers: ["sentinel1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar2: - data_type: "raster" - layers: ["sentinel1.1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar3: - data_type: "raster" - layers: ["sentinel1.2"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar4: - data_type: "raster" - layers: ["sentinel1.3"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar5: - data_type: "raster" - layers: ["sentinel1.4"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar6: - data_type: "raster" - layers: ["sentinel1.5"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - image_bands: [2, 1, 0] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241212_input_both_but_for_sentinel2_use_old_norm diff --git a/data/satlas/wind_turbine/config_debug_azure5.yaml b/data/satlas/wind_turbine/config_debug_azure5.yaml deleted file mode 100644 index bee15abd..00000000 --- a/data/satlas/wind_turbine/config_debug_azure5.yaml +++ /dev/null @@ -1,267 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslearn.models.swin.Swin - init_args: - pretrained: true - input_channels: 11 - output_layers: [1, 3, 5, 7] - image_channels: 11 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [128, 256, 512, 1024] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 - restore_config: - restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth - remap_prefixes: - - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] - ignore_prefixes: - - "backbone.backbone.backbone.features.0.0." -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2.4"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2.5"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - sar1: - data_type: "raster" - layers: ["sentinel1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar2: - data_type: "raster" - layers: ["sentinel1.1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar3: - data_type: "raster" - layers: ["sentinel1.2"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar4: - data_type: "raster" - layers: ["sentinel1.3"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar5: - data_type: "raster" - layers: ["sentinel1.4"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar6: - data_type: "raster" - layers: ["sentinel1.5"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - image_bands: [2, 1, 0] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 8 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241212_no_freezing diff --git a/data/satlas/wind_turbine/config_debug_azure6.yaml b/data/satlas/wind_turbine/config_debug_azure6.yaml deleted file mode 100644 index 6be7a533..00000000 --- a/data/satlas/wind_turbine/config_debug_azure6.yaml +++ /dev/null @@ -1,266 +0,0 @@ -model: - class_path: rslearn.train.lightning_module.RslearnLightningModule - init_args: - model: - class_path: rslearn.models.multitask.MultiTaskModel - init_args: - encoder: - - class_path: rslearn.models.simple_time_series.SimpleTimeSeries - init_args: - encoder: - class_path: rslp.satlas.train.TeePipe - init_args: - encoders: - - class_path: rslearn.models.satlaspretrain.SatlasPretrain - init_args: - model_identifier: "Sentinel2_SwinB_MI_MS" - - class_path: rslearn.models.satlaspretrain.SatlasPretrain - init_args: - model_identifier: "Sentinel1_SwinB_MI" - channels: [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10]] - image_channels: 11 - - class_path: rslearn.models.fpn.Fpn - init_args: - in_channels: [256, 512, 1024, 2048] - out_channels: 128 - decoders: - detect: - - class_path: rslearn.models.faster_rcnn.FasterRCNN - init_args: - downsample_factors: [4, 8, 16, 32] - num_channels: 128 - num_classes: 2 - anchor_sizes: [[32], [64], [128], [256]] - lr: 0.00002 - plateau: true - plateau_factor: 0.2 - plateau_patience: 2 - plateau_min_lr: 0 - plateau_cooldown: 10 -data: - class_path: rslearn.train.data_module.RslearnDataModule - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - inputs: - image1: - data_type: "raster" - layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image2: - data_type: "raster" - layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image3: - data_type: "raster" - layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image4: - data_type: "raster" - layers: ["sentinel2.3"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image5: - data_type: "raster" - layers: ["sentinel2.4"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - image6: - data_type: "raster" - layers: ["sentinel2.5"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] - passthrough: true - dtype: FLOAT32 - sar1: - data_type: "raster" - layers: ["sentinel1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar2: - data_type: "raster" - layers: ["sentinel1.1"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar3: - data_type: "raster" - layers: ["sentinel1.2"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar4: - data_type: "raster" - layers: ["sentinel1.3"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar5: - data_type: "raster" - layers: ["sentinel1.4"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - sar6: - data_type: "raster" - layers: ["sentinel1.5"] - bands: ["vv", "vh"] - passthrough: true - dtype: FLOAT32 - mask: - data_type: "raster" - layers: ["mask"] - bands: ["mask"] - passthrough: true - dtype: INT32 - is_target: true - targets: - data_type: "vector" - layers: ["label"] - is_target: true - task: - class_path: rslearn.train.tasks.multi_task.MultiTask - init_args: - tasks: - detect: - class_path: rslearn.train.tasks.detection.DetectionTask - init_args: - property_name: "category" - classes: ["unknown", "turbine"] - box_size: 15 - remap_values: [[0, 1], [0, 255]] - image_bands: [2, 1, 0] - exclude_by_center: true - enable_map_metric: true - enable_f1_metric: true - f1_metric_kwargs: - cmp_mode: "distance" - cmp_threshold: 15 - flatten_classes: true - input_mapping: - detect: - targets: "targets" - batch_size: 4 - num_workers: 32 - default_config: - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 384 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image - - class_path: rslp.transforms.mask.Mask - - class_path: rslearn.train.transforms.flip.Flip - init_args: - image_selectors: ["image"] - box_selectors: ["target/detect"] - groups: ["label", "naip"] - tags: - split: train - val_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - test_config: - patch_size: 384 - groups: ["label", "naip"] - tags: - split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 - transforms: - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 10000 - valid_range: [0, 1] - selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - - class_path: rslearn.train.transforms.concatenate.Concatenate - init_args: - selections: - image1: [] - sar1: [] - image2: [] - sar2: [] - image3: [] - sar3: [] - image4: [] - sar4: [] - image5: [] - sar5: [] - image6: [] - sar6: [] - output_selector: image -trainer: - max_epochs: 500 - callbacks: - - class_path: lightning.pytorch.callbacks.LearningRateMonitor - init_args: - logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ - output_layer: output - selector: ["detect"] - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - save_top_k: 1 - save_last: true - monitor: val_detect/mAP - mode: max -rslp_project: satlas_wind_turbine -rslp_experiment: data_20241212_dual_encoder From 976def921ed63d11dc1653f81c52189afe75cd10 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 09:28:40 -0800 Subject: [PATCH 30/58] add documentation and test (wip) --- rslp/satlas/README.md | 92 ++++++++++++++++++- rslp/satlas/predict_pipeline.py | 42 +-------- .../satlas/test_predict_pipeline.py | 65 +++++++++++++ 3 files changed, 156 insertions(+), 43 deletions(-) create mode 100644 tests/integration_slow/satlas/test_predict_pipeline.py diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md index 2291454c..81bdeb97 100644 --- a/rslp/satlas/README.md +++ b/rslp/satlas/README.md @@ -3,22 +3,106 @@ at https://satlas.allen.ai/. ## Marine Infrastructure -Training: +The marine infrastructure model detects off-shore infrastructure in two categories: +off-shore wind turbines and off-shore platforms. The latter category essentially +includes any manmade object in the ocean that is stationary and would not normally be +considered an artificial island. + +The model inputs four Sentinel-2 images, where each image should be a mosaic that uses +Sentinel-2 scenes from a distinct month. The dataset configuration uses +`MonthlySentinel2` in `rslp/satlas/data_sources.py` to achieve this, and only uses +Sentinel-2 scenes with at most 50% cloud cover. If a given month does not have enough +matching images under the cloud threshold, then images from earlier months may be used. + +The model is meant to be run on a quarterly basis, using images up to 4 months +before the start of the quarter (giving 7 possible months to pick 4 mosaics from). If +all of the months are cloudy, then it is okay to skip that inference and come back to +it in a later season that may be less cloudy. At the same time, we don't want to limit +to just 4 months because a region may never have 4 consecutive months with cloud-free +images available. + +### Training + +The model is trained using the rslearn dataset in `gs://rslearn-eai`. See the model +configuration file for more details. python -m rslp.rslearn_main model fit --config data/satlas/marine_infra/config.yaml -Inference: +### Inference - python -m rslp.main satlas write_jobs_for_year_months '[[2024, 7]]' MARINE_INFRA 'gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen +For inference, the world is split up into ~10K tiles in each of the 60 UTM zones, +yielding 600K inference tasks in total. For each task, an inference worker will: -Post-processing: +1. Create an rslearn dataset with a single window corresponding to the UTM zone/tile. +2. Execute the data ingestion pipeline to populate that window with Sentinel-2 images. +3. Apply the model on the window to create an output layer in the rslearn dataset. +4. Copy the contents of that output layer to a location on GCS. + +The task queue system is implemented using `rslp.common.worker`, see +`rslp/common/README.md` for details. Essentially, we first write tasks to a Google +Cloud Pub/Sub topic, and then launch workers that will read from the topic. + +Then, we start by writing the tasks: + + python -m rslp.main satlas write_jobs_for_year_months '[[2024, 7]]' MARINE_INFRA 'gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen --days_before 120 --days_after 90 + +Here: + +- `[[2024, 7]]` is a list of year-month pairs that we want to run the model on. +- MARINE_INFRA is the application we want to apply. This is an enum for "marine_infra" + and it will automatically use the dataset configuration at + `data/satlas/marine_infra/config.json` and the model configuration at + `data/satlas/marine_infra/config.yaml`. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/` + is the path where outputs should be written. Outputs will be named like + `EPSG:32601_65536_-524288.geojson` where `EPSG:32601` is the UTM projection, and + 65536 and -524288 are the starting column and row (respectively) of the tile. The + path should have a year and month placeholder. +- `skylight-proto-1` is the project of the Pub/Sub topic. +- `rslp-job-queue-favyen` is the name of the Pub/Sub topic. +- The inference tasks should create a window spanning 120 days before the specified + timestamp (to use images before the quarter when necessary) and 90 days after the + timestamp (corresponding to the duration of the quarter). + +Then start the workers. See `rslp/common/README.md` for details. In this example, +`rslp-job-queue-favyen-sub` should be a subscription for the topic to which the tasks +were written. Here we start 100 workers. + + python -m rslp.main common launch skylight-proto-1 rslp-job-queue-favyen-sub 100 --gpus 1 --shared_memory 256GiB + +### Post-processing. + +Post-processing for point tasks occurs locally (does not require starting jobs in parallel). + +First, merge the points computed across all of the different tasks: python -m rslp.main satlas merge_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/2024-07/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ + +Here: + +- MARINE_INFRA is the application we want to apply. +- 2024-07 is the timestep label. All timestep labels are YYYY-MM for the Satlas + systems. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/2024-07/` is the + folder containing inference outputs that we want to merge. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/` is the + folder to write merged outputs. The output filename will be + `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/2024-07.geojson`. + +Second, smooth the points across timesteps. This runs a Viterbi smoothing operation. + python -m rslp.main satlas smooth_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ + +Finally, publish the outputs to Cloudflare R2. + python -m rslp.main satlas publish_points MARINE_INFRA gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ 'marine-default-cluster@v4' ## Wind Turbine +The wind turbine model detects wind turbines on land. It is meant to be run on a +semi-annual basis, and inputs six Sentinel-2 images. As with the marine infrastructure +model, each is a mosaic using images within a 30-day period. + Training: python -m rslp.rslearn_main model fit --config data/satlas/wind_turbine/config.yaml diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index fe2f5e6e..aa4000c1 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -209,11 +209,12 @@ def predict_pipeline( run_model_predict(model_config_fname, ds_path) if APP_IS_RASTER[application]: - """src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" + src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" with src_fname.open("rb") as src: with out_fname.open("wb") as dst: - shutil.copyfileobj(src, dst)""" + shutil.copyfileobj(src, dst) + # TODO: implement valid patches and such. raise NotImplementedError else: @@ -250,43 +251,6 @@ def predict_pipeline( "features": [], } - """ - # Add a list specifying which patches are valid vs invalid to the GeoJSON. - # Valid means that none of the input layers are completely zero at the patch. - # This is so that when we smooth the predictions over time, we can distinguish - # a point not being detected because it wasn't there vs not being detected just - # because there was no image available there. - check_images = window_path.glob("layers/*/B02_B03_B04_B08/geotiff.tif") - valid_patches = set() - for check_image in check_images: - path_parts = check_image.path.split("/") - if path_parts[-3] in VALIDITY_EXCLUDE_LAYERS: - continue - - with check_image.open("rb") as f: - with rasterio.open(f) as raster: - valid_mask = raster.read().max(axis=0) > 0 - - for tile_col in range(bounds[0] // PATCH_SIZE, bounds[2] // PATCH_SIZE): - for tile_row in range(bounds[1] // PATCH_SIZE, bounds[3] // PATCH_SIZE): - cur_patch_id = (tile_col, tile_row) - cur_offset = (tile_col * PATCH_SIZE, tile_row * PATCH_SIZE) - - if cur_patch_id in valid_patches: - continue - - # Read from the window that contains this patch. - window = tile_to_window[cur_patch_id] - - - patch_valid = np.zeros((VALIDITY_PATCH_SIZE, VALIDITY_PATCH_SIZE)) - copy_spatial_array(valid_mask, patch_valid, bounds[0:2], cur_offset) - if valid_mask.max() is False: - continue - - valid_patches.add(cur_patch_id) - """ - if "properties" not in fc: fc["properties"] = {} fc["properties"]["valid_patches"] = { diff --git a/tests/integration_slow/satlas/test_predict_pipeline.py b/tests/integration_slow/satlas/test_predict_pipeline.py new file mode 100644 index 00000000..126568d1 --- /dev/null +++ b/tests/integration_slow/satlas/test_predict_pipeline.py @@ -0,0 +1,65 @@ +"""Test the Satlas prediction pipeline.""" + +import json +import pathlib +from datetime import datetime, timezone + +import shapely +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import STGeometry +from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection + +from rslp.satlas.predict_pipeline import Application, get_output_fname, predict_pipeline + + +def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: + # Test the prediction pipeline runs correctly for a point detection task. + # Specifically, we apply the marine infrastructure model on a window covering a + # small portion of the Hornsea 2 Offshore Wind Farm, and verify that the resulting + # detections include at least one turbine. + + # These are the coordinates of the known wind turbine. + src_geom = STGeometry(WGS84_PROJECTION, shapely.Point(1.859, 53.91), None) + + # We get the corresponding UTM window. + projection = get_utm_ups_projection(src_geom.shp.x, src_geom.shp.y, 10, -10) + dst_geom = src_geom.to_projection(projection) + bounds = ( + int(dst_geom.shp.x) - 256, + int(dst_geom.shp.y) - 256, + int(dst_geom.shp.x) + 256, + int(dst_geom.shp.y) + 256, + ) + + # The wind farm existed since 2019 so this time range will work. + time_range = ( + datetime(2024, 1, 1, tzinfo=timezone.utc), + datetime(2024, 8, 1, tzinfo=timezone.utc), + ) + + # Output path will contain outputs, while scratch path is used as a working + # directory. Specifically, the scratch path will be populated with an rslearn + # dataset containing one window matching the projection/bounds that we provide. + out_path = tmp_path / "out" + scratch_path = tmp_path / "scratch" + + # Apply the pipeline. It will ingest data and apply the model. + predict_pipeline( + application=Application.MARINE_INFRA, + projection_json=json.dumps(projection.serialize()), + bounds=bounds, + time_range=time_range, + out_path=str(out_path), + scratch_path=str(scratch_path), + ) + + # Now we verify that the output includes at least one turbine. + out_fname = get_output_fname( + Application.MARINE_INFRA, str(out_path), projection, bounds + ) + with out_fname.open() as f: + fc = json.load(f) + turbine_features = [ + feat for feat in fc["features"] if feat["properties"]["category"] == "turbine" + ] + assert len(turbine_features) > 0 From 0bd47f5ef4dd73eb64613cba551a7927f3a2f0d4 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 11:45:26 -0800 Subject: [PATCH 31/58] fix test --- rslp/satlas/predict_pipeline.py | 6 +++++- .../satlas/test_predict_pipeline.py | 20 +++++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index aa4000c1..e319afeb 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -76,6 +76,7 @@ def predict_pipeline( time_range: tuple[datetime, datetime], out_path: str, scratch_path: str, + use_rtree_index: bool = True, ) -> None: """Compute outputs of a Satlas model on this tile. @@ -90,6 +91,9 @@ def predict_pipeline( out_path: directory to write the outputs. It will either be a GeoTIFF or GeoJSON, named based on the bounds. scratch_path: where to store the dataset. + use_rtree_index: whether to prepare using rtree index. This is recommended when + applying the model globally but can be disabled for small regions to avoid + the time to create the index. """ dataset_config_fname = DATASET_CONFIG_FNAME.format(application=application.value) model_config_fname = MODEL_CONFIG_FNAME.format(application=application.value) @@ -120,7 +124,7 @@ def predict_pipeline( continue layer_source_cfg["index_cache_dir"] = str(index_cache_dir) layer_source_cfg["rtree_cache_dir"] = str(UPath(out_path) / "index") - layer_source_cfg["use_rtree_index"] = True + layer_source_cfg["use_rtree_index"] = use_rtree_index layer_source_cfg["rtree_time_range"] = [ time_range[0].isoformat(), time_range[1].isoformat(), diff --git a/tests/integration_slow/satlas/test_predict_pipeline.py b/tests/integration_slow/satlas/test_predict_pipeline.py index 126568d1..ef2a84c3 100644 --- a/tests/integration_slow/satlas/test_predict_pipeline.py +++ b/tests/integration_slow/satlas/test_predict_pipeline.py @@ -9,7 +9,12 @@ from rslearn.utils.geometry import STGeometry from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection -from rslp.satlas.predict_pipeline import Application, get_output_fname, predict_pipeline +from rslp.satlas.predict_pipeline import ( + PATCH_SIZE, + Application, + get_output_fname, + predict_pipeline, +) def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: @@ -22,14 +27,14 @@ def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: src_geom = STGeometry(WGS84_PROJECTION, shapely.Point(1.859, 53.91), None) # We get the corresponding UTM window. + # The window's bounds must all be multiples of PATCH_SIZE. projection = get_utm_ups_projection(src_geom.shp.x, src_geom.shp.y, 10, -10) dst_geom = src_geom.to_projection(projection) - bounds = ( - int(dst_geom.shp.x) - 256, - int(dst_geom.shp.y) - 256, - int(dst_geom.shp.x) + 256, - int(dst_geom.shp.y) + 256, + start = ( + int(dst_geom.shp.x) // PATCH_SIZE * PATCH_SIZE, + int(dst_geom.shp.y) // PATCH_SIZE * PATCH_SIZE, ) + bounds = (start[0], start[1], start[0] + PATCH_SIZE, start[1] + PATCH_SIZE) # The wind farm existed since 2019 so this time range will work. time_range = ( @@ -42,8 +47,10 @@ def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: # dataset containing one window matching the projection/bounds that we provide. out_path = tmp_path / "out" scratch_path = tmp_path / "scratch" + out_path.mkdir() # Apply the pipeline. It will ingest data and apply the model. + # We disable rtree index so that it doesn't need an hour to create it. predict_pipeline( application=Application.MARINE_INFRA, projection_json=json.dumps(projection.serialize()), @@ -51,6 +58,7 @@ def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: time_range=time_range, out_path=str(out_path), scratch_path=str(scratch_path), + use_rtree_index=False, ) # Now we verify that the output includes at least one turbine. From f8ee3ec1260334775d535e75e4054c98e8bd21c8 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 11:59:23 -0800 Subject: [PATCH 32/58] move convert_satlas_webmercator_to_rslean to one_off_porjects --- .../convert_satlas_webmercator_to_rslearn}/README.md | 0 .../convert_satlas_webmercator_to_rslearn}/__init__.py | 0 .../convert_satlas_webmercator_to_rslearn}/assign_split.py | 0 .../convert_satlas_webmercator_to_rslearn}/config.json | 0 .../convert_satlas_webmercator_to_rslearn}/lib/__init__.py | 0 .../marine_infra/__init__.py | 0 .../marine_infra/convert_siv_labels.py | 0 .../sentinel2_vessel/config.json | 0 .../sentinel2_vessel/convert_siv_labels.py | 0 .../sentinel2_vessel/delete_bad_images.py | 0 .../sentinel2_vessel/reformat_multisat.py | 0 .../set_single_image_metadata.py | 0 .../solar_farm/convert_siv_labels.py | 0 .../wind_turbine/__init__.py | 0 .../wind_turbine/assign_old_splits.py | 0 .../wind_turbine/compare_webmercator_utm_together.py | 0 .../wind_turbine/config.json | 0 .../wind_turbine/config_flip.yaml | 0 .../wind_turbine/config_flip_oldsplit.yaml | 0 .../wind_turbine/config_predict.json | 0 .../wind_turbine/convert_naip_labels.py | 0 .../wind_turbine/convert_siv_labels.py | 0 .../wind_turbine/create_webmercator_rslearn_dataset.py | 0 .../wind_turbine/quick_vis_script.py | 0 .../wind_turbine/webmercator_config.json | 0 .../wind_turbine/webmercator_config.yaml | 0 26 files changed, 0 insertions(+), 0 deletions(-) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/README.md (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/__init__.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/assign_split.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/config.json (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/lib/__init__.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/marine_infra/__init__.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/marine_infra/convert_siv_labels.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/sentinel2_vessel/config.json (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/sentinel2_vessel/convert_siv_labels.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/sentinel2_vessel/delete_bad_images.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/sentinel2_vessel/reformat_multisat.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/set_single_image_metadata.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/solar_farm/convert_siv_labels.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/__init__.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/assign_old_splits.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/compare_webmercator_utm_together.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/config.json (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/config_flip.yaml (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/config_flip_oldsplit.yaml (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/config_predict.json (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/convert_naip_labels.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/convert_siv_labels.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/create_webmercator_rslearn_dataset.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/quick_vis_script.py (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/webmercator_config.json (100%) rename {convert_satlas_webmercator_to_rslearn => one_off_projects/convert_satlas_webmercator_to_rslearn}/wind_turbine/webmercator_config.yaml (100%) diff --git a/convert_satlas_webmercator_to_rslearn/README.md b/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md similarity index 100% rename from convert_satlas_webmercator_to_rslearn/README.md rename to one_off_projects/convert_satlas_webmercator_to_rslearn/README.md diff --git a/convert_satlas_webmercator_to_rslearn/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/assign_split.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/assign_split.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/assign_split.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/assign_split.py diff --git a/convert_satlas_webmercator_to_rslearn/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/config.json diff --git a/convert_satlas_webmercator_to_rslearn/lib/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/lib/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py diff --git a/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py diff --git a/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml From 693f9a0d81901a484194a625145a7f29044bd644 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 9 Jan 2025 12:00:02 -0800 Subject: [PATCH 33/58] update readme --- .../convert_satlas_webmercator_to_rslearn/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md b/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md index 18a9ffb5..c7df1dc0 100644 --- a/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md @@ -4,8 +4,8 @@ About This project is for converting the various Satlas application training datasets (wind turbines, solar farms, marine infrastructure) to rslearn format. -TODO: this project should be moved to one_off_projects in the near-future since the -conversion is complete and this code does not need to be maintained anymore. +The conversion has completed so it is now in one_off_projects and the code does not +need to be maintained anymore. Wind Turbines From 4d200c83768d2d4b4236bc141fec9ce3bd419b33 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 13 Jan 2025 14:59:54 -0800 Subject: [PATCH 34/58] fix --- .../lib/__init__.py | 2 +- .../solar_farm/convert_siv_labels.py | 20 +++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py index 85cac7ba..067aa5e3 100644 --- a/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py @@ -121,7 +121,7 @@ def convert_window( rr, cc = skimage.draw.polygon(polygon_rows, polygon_cols, shape=mask.shape) mask[rr, cc] = 255 layer_name = "mask" - layer_dir = window.get_layer_dir(layer_name) + layer_dir = window.get_raster_dir(layer_name, ["mask"]) SingleImageRasterFormat().encode_raster(layer_dir, dst_projection, bounds, mask[None, :, :]) window.mark_layer_completed(layer_name) diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py index 4bd480f2..4e0b17f3 100644 --- a/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py @@ -9,7 +9,7 @@ import numpy as np import rasterio.features import shapely -from PIL import Image +from upath import UPath from rslearn.utils.vector_format import GeojsonVectorFormat from rslearn.utils.raster_format import SingleImageRasterFormat @@ -57,7 +57,7 @@ labels.append((polygon, properties)) window = convert_window( - root_dir=out_dir, + root_dir=UPath(out_dir), group=group, zoom=15, bounds=bounds, @@ -66,15 +66,13 @@ ) # Create raster version of the label. - layer_name = "label" - layer_dir = window.get_layer_dir(layer_name) + layer_dir = window.get_layer_dir("label") features = GeojsonVectorFormat().decode_vector(layer_dir, bounds) + shapes = [] for feat in features: - geometry = feat.geometry - - - geometry = feat["geometry"] + assert feat.geometry.projection == window.projection + geometry = json.loads(shapely.to_geojson(feat.geometry.shp)) assert geometry["type"] == "Polygon" geometry["coordinates"] = ( np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] @@ -94,5 +92,7 @@ (window.bounds[3] - window.bounds[1], window.bounds[2] - window.bounds[0]), dtype=np.uint8, ) - with window.file_api.get_folder("layers/label_raster").open("image.png", "wb") as f: - Image.fromarray(mask).save(f, format="PNG") + layer_name = "label_raster" + raster_dir = window.get_raster_dir(layer_name, ["label"]) + SingleImageRasterFormat().encode_raster(raster_dir, window.projection, window.bounds, mask[None, :, :]) + window.mark_layer_completed(layer_name) From 31576ab3f3ecbed5aaec0246e14bfd28f4ee6f4f Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 14 Jan 2025 11:49:54 -0800 Subject: [PATCH 35/58] fix --- .../solar_farm/convert_siv_labels.py | 2 +- rslp/common/README.md | 2 +- rslp/common/worker.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py index 4e0b17f3..5c423fc2 100644 --- a/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py @@ -77,7 +77,7 @@ geometry["coordinates"] = ( np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] ).tolist() - shapes.append((geometry, 255)) + shapes.append((geometry, 1)) if shapes: mask = rasterio.features.rasterize( shapes, diff --git a/rslp/common/README.md b/rslp/common/README.md index 9c0be6aa..3de1c95c 100644 --- a/rslp/common/README.md +++ b/rslp/common/README.md @@ -24,4 +24,4 @@ Then you can launch the worker. To test on one machine: And to launch 100 workers on Beaker: - python -m rslp.main common launch skylight-proto-1 rslp-job-queue-YOURNAME-sub 100 --gpus 1 --shared_memory 256GiB + python -m rslp.main common launch BEAKER_IMAGE_NAME skylight-proto-1 rslp-job-queue-YOURNAME-sub 100 --gpus 1 --shared_memory 256GiB diff --git a/rslp/common/worker.py b/rslp/common/worker.py index 7c37c163..e6e13eaf 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -19,7 +19,8 @@ ) from google.cloud import pubsub_v1, storage -from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE, IMAGE_NAME, get_base_env_vars +from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE +from rslp.launcher_lib import get_base_env_vars from rslp.log_utils import get_logger from rslp.main import run_workflow @@ -196,6 +197,7 @@ def launch_worker(project_id: str, subscription_id: str) -> None: def launch_workers( + image_name: str, project_id: str, subscription_id: str, num_workers: int, @@ -207,6 +209,7 @@ def launch_workers( """Start workers for the prediction jobs. Args: + image_name: the Beaker image name to use for the jobs. project_id: the Google Cloud project ID. subscription_id: the Pub/Sub subscription ID. num_workers: number of workers to launch @@ -224,7 +227,7 @@ def launch_workers( spec = ExperimentSpec.new( budget=BUDGET, description="worker", - beaker_image=IMAGE_NAME, + beaker_image=image_name, priority=priority, command=["python", "-m", "rslp.main"], arguments=[ From 298de6a399726d146cb7b63fc627ca886c1e1a5a Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 17 Jan 2025 13:23:33 -0800 Subject: [PATCH 36/58] only start satlas jobs that weren't already completed --- rslp/satlas/write_jobs.py | 121 ++++++++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 39 deletions(-) diff --git a/rslp/satlas/write_jobs.py b/rslp/satlas/write_jobs.py index ed3aa3a8..245565a1 100644 --- a/rslp/satlas/write_jobs.py +++ b/rslp/satlas/write_jobs.py @@ -2,6 +2,7 @@ import json import random +from collections.abc import Generator from datetime import datetime, timedelta, timezone import shapely @@ -11,10 +12,11 @@ from rslearn.const import WGS84_PROJECTION from rslearn.utils.geometry import PixelBounds, Projection, STGeometry from rslearn.utils.get_utm_ups_crs import get_proj_bounds +from upath import UPath from rslp.log_utils import get_logger -from .predict_pipeline import Application, PredictTaskArgs +from .predict_pipeline import Application, PredictTaskArgs, get_output_fname logger = get_logger(__name__) @@ -54,6 +56,43 @@ def __init__( self.time_range = time_range self.out_path = out_path + def get_output_fname(self) -> UPath: + """Get the output filename that will be used for this task.""" + # The filename format is defined by get_output_fname in predict_pipeline.py. + return get_output_fname( + self.application, self.out_path, self.projection, self.bounds + ) + + +def enumerate_tiles_in_zone(utm_zone: CRS) -> Generator[tuple[int, int], None, None]: + """List all of the tiles in the zone where outputs should be computed. + + The tiles are all TILE_SIZE x TILE_SIZE so only the column/row of the tile along + that grid are returned. + + Args: + utm_zone: the CRS which must correspond to a UTM EPSG. + + Returns: + generator of (column, row) of the tiles that are needed. + """ + # We use get_proj_bounds to get the bounds of the UTM zone in CRS units. + # We then convert to pixel units in order to determine the tiles that are needed. + crs_bbox = STGeometry( + Projection(utm_zone, 1, 1), + shapely.box(*get_proj_bounds(utm_zone)), + None, + ) + projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) + pixel_bbox = crs_bbox.to_projection(projection) + + # Convert the resulting shape to integer bbox. + zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) + + for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): + for row in range(zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1): + yield (col, row) + def get_jobs( application: Application, @@ -66,6 +105,8 @@ def get_jobs( ) -> list[list[str]]: """Get batches of tasks for Satlas prediction. + Tasks where outputs have already been computed are excluded. + Args: application: which application to run. time_range: the time range to run within. Must have timezone. @@ -91,17 +132,10 @@ def get_jobs( tasks: list[Task] = [] for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"): - # get_proj_bounds returns bounds in CRS units so we need to convert to pixel - # units. - crs_bbox = STGeometry( - Projection(utm_zone, 1, 1), - shapely.box(*get_proj_bounds(utm_zone)), - None, - ) projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) - pixel_bbox = crs_bbox.to_projection(projection) - zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) + # If the user provided WGS84 bounds, then we convert it to pixel coordinates so + # we can check each tile easily. user_bounds_in_proj: PixelBounds | None = None if wgs84_bounds is not None: dst_geom = STGeometry( @@ -114,42 +148,51 @@ def get_jobs( int(dst_geom.shp.bounds[3]), ) - for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): - for row in range( - zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1 - ): - if user_bounds_in_proj is not None: - # Check if this task intersects the bounds specified by the user. - if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: - continue - if col * TILE_SIZE >= user_bounds_in_proj[2]: - continue - if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: - continue - if row * TILE_SIZE >= user_bounds_in_proj[3]: - continue - - tasks.append( - Task( - application=application, - projection=projection, - bounds=( - col * TILE_SIZE, - row * TILE_SIZE, - (col + 1) * TILE_SIZE, - (row + 1) * TILE_SIZE, - ), - time_range=time_range, - out_path=out_path, - ) + for col, row in enumerate_tiles_in_zone(utm_zone): + if user_bounds_in_proj is not None: + # Check if this task intersects the bounds specified by the user. + if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: + continue + if col * TILE_SIZE >= user_bounds_in_proj[2]: + continue + if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: + continue + if row * TILE_SIZE >= user_bounds_in_proj[3]: + continue + + tasks.append( + Task( + application=application, + projection=projection, + bounds=( + col * TILE_SIZE, + row * TILE_SIZE, + (col + 1) * TILE_SIZE, + (row + 1) * TILE_SIZE, + ), + time_range=time_range, + out_path=out_path, ) + ) + + logger.info("Got %d total tasks", len(tasks)) - print(f"Got {len(tasks)} total tasks") + # Remove tasks where outputs are already computed. + existing_output_fnames = {out_fname.name for out_fname in UPath(out_path).iterdir()} + tasks = [ + task + for task in tasks + if task.get_output_fname().name not in existing_output_fnames + ] + logger.info("Got %d tasks that are uncompleted", len(tasks)) + # Sample tasks down to user-provided count (max # tasks to run), if provided. if count is not None and len(tasks) > count: tasks = random.sample(tasks, count) logger.info("Randomly sampled %d tasks", len(tasks)) + # Convert tasks to jobs for use with rslp.common.worker. + # This is what will be written to the Pub/Sub topic. jobs = [] for i in range(0, len(tasks), batch_size): cur_tasks = tasks[i : i + batch_size] From 3c9083c39166fdf53637b8b899ad2242879fe04e Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 30 Jan 2025 23:58:01 -0500 Subject: [PATCH 37/58] enable satlas prediction pipeline to run on jupiter (using /data disk) --- data/satlas/wind_turbine/config_azure.json | 11 +++ rslp/common/worker.py | 80 +++++++++++++--------- rslp/satlas/postprocess.py | 16 ++++- rslp/satlas/predict_pipeline.py | 44 +++++++----- 4 files changed, 100 insertions(+), 51 deletions(-) diff --git a/data/satlas/wind_turbine/config_azure.json b/data/satlas/wind_turbine/config_azure.json index 3dbc34bd..419dcbe3 100644 --- a/data/satlas/wind_turbine/config_azure.json +++ b/data/satlas/wind_turbine/config_azure.json @@ -34,6 +34,17 @@ "data_source": { "ingest": false, "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + } + }, "query_config": { "max_matches": 6 } diff --git a/rslp/common/worker.py b/rslp/common/worker.py index e6e13eaf..49c1792f 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -1,11 +1,17 @@ """Worker to process jobs in a list of jobs.""" import json +import os +import shutil +import signal +import sys +import tempfile import threading import time import uuid +from collections.abc import Callable from concurrent import futures -from datetime import datetime, timedelta, timezone +from typing import Any import tqdm from beaker import ( @@ -17,7 +23,7 @@ Priority, TaskResources, ) -from google.cloud import pubsub_v1, storage +from google.cloud import pubsub_v1 from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE from rslp.launcher_lib import get_base_env_vars @@ -30,38 +36,27 @@ # about a pending claim that hasn't completed yet. MAX_JOB_HOURS = 4 +# Scratch directory that jobs can use and it will be managed by this module. +SCRATCH_DIRECTORY = "/tmp/scratch" -def _get_pending_jobs( - jobs: list[list[str]], claim_bucket: storage.Bucket, claim_dir: str -) -> list[int]: - """Get the indices of jobs that haven't been claimed yet. +DATA_DISK = "/data/rslearn_projects" + + +def get_cleanup_signal_handler(tmp_dir: str) -> Callable[[int, Any], None]: + """Make a signal handler that cleans up the specified directory before exiting. + + This should be passed as the handler to signal.signal. Args: - jobs: the full list of jobs. - claim_bucket: bucket where files indicating completed jobs are written. - claim_dir: path within bucket. + tmp_dir: the directory to delete when the signal is received. """ - claimed = set() - # Pending claims are only valid for a few hours. - for blob in claim_bucket.list_blobs(prefix=f"{claim_dir}pending/"): - if datetime.now(timezone.utc) - blob.time_created > timedelta( - hours=MAX_JOB_HOURS - ): - # This is a stale pending claim (the job may have completed, but if so we - # will see its completed blob below). - continue - claimed.add(int(blob.name.split("/")[-1])) - # While completed files indicate that the job is done permanently. - for blob in claim_bucket.list_blobs(prefix=f"{claim_dir}completed/"): - claimed.add(int(blob.name.split("/")[-1])) - - pending = [] - for idx in range(len(jobs)): - if idx in claimed: - continue - pending.append(idx) - - return pending + + def cleanup_signal_handler(signo: int, stack_frame: Any) -> None: + logger.error(f"cleanup_signal_handler: caught signal {signo}") + shutil.rmtree(tmp_dir) + sys.exit(1) + + return cleanup_signal_handler def worker_pipeline( @@ -70,6 +65,7 @@ def worker_pipeline( retries: int = 3, retry_sleep: int = 60, idle_timeout: int = 10, + manage_scratch_dir_on_data_disk: bool = False, ) -> None: """Start a worker to run jobs from a Pub/Sub subscription. @@ -86,7 +82,21 @@ def worker_pipeline( retry_sleep: sleep for this many seconds between retries. Sleeping helps in case there is an error due to rate limiting. idle_timeout: seconds before we terminate if there is no activity. + manage_scratch_dir_on_data_disk: whether to create SCRATCH_DIRECTORY on the + /data disk and manage it to ensure it is deleted in case of SIGTERM. """ + if manage_scratch_dir_on_data_disk: + # Some tasks use SCRATCH_DIRECTORY, and if management is enabled, it means we + # should put the SCRATCH_DIRECTORY on the /data/ disk (via symlink), and that + # we must ensure it is deleted in case SIGTERM is received (i.e. if the Beaker + # job is cancelled or pre-empted. + os.makedirs(DATA_DISK, exist_ok=True) + tmp_dir_on_data_disk = tempfile.TemporaryDirectory(dir=DATA_DISK) + os.symlink(tmp_dir_on_data_disk.name, SCRATCH_DIRECTORY) + signal.signal( + signal.SIGTERM, get_cleanup_signal_handler(tmp_dir_on_data_disk.name) + ) + subscriber = pubsub_v1.SubscriberClient() subscription_path = subscriber.subscription_path(project_id, subscription_id) @@ -144,7 +154,7 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: max_messages=1, max_lease_duration=24 * 3600, ) - executor = futures.ThreadPoolExecutor(max_workers=5) + executor = futures.ThreadPoolExecutor(max_workers=1) scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) streaming_pull_future = subscriber.subscribe( subscription_path, @@ -205,6 +215,7 @@ def launch_workers( shared_memory: str | None = None, priority: Priority = Priority.low, cluster: list[str] = ["ai2/augusta-google-1"], + manage_scratch_dir_on_data_disk: bool = False, ) -> None: """Start workers for the prediction jobs. @@ -217,6 +228,7 @@ def launch_workers( shared_memory: shared memory string like "256GiB". priority: priority to assign the Beaker jobs. cluster: clusters to target. + manage_scratch_dir_on_data_disk: see worker_pipeline. """ beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) @@ -235,6 +247,8 @@ def launch_workers( "worker", project_id, subscription_id, + "--manage_scratch_dir_on_data_disk", + str(manage_scratch_dir_on_data_disk), ], constraints=Constraints( cluster=cluster, @@ -245,6 +259,10 @@ def launch_workers( source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec mount_path="/etc/credentials/gcp_credentials.json", # nosec ), + DataMount( + source=DataSource(host_path="/data"), + mount_path="/data", + ), ], env_vars=env_vars, resources=TaskResources(gpu_count=gpus, shared_memory=shared_memory), diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index f7986c0c..172404e5 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -40,9 +40,19 @@ logger = get_logger(__name__) -def _get_fc(fname: UPath) -> dict[str, Any]: +def _get_fc(fname: UPath) -> tuple[UPath, dict[str, Any]]: + """Read the FeatureCollection from the specified file. + + This is intended to be used as a handler for multiprocessing. + + Args: + fname: the filename to read. + + Returns: + a tuple (fname, fc) of the filename and the decoded FeatureCollection JSON. + """ with fname.open() as f: - return json.load(f) + return fname, json.load(f) def apply_nms( @@ -131,7 +141,7 @@ def merge_points( # Get category remapping in case one is specified for this application. category_map = APP_CATEGORY_MAPS.get(application, {}) - for cur_fc in tqdm.tqdm(outputs, total=len(fnames)): + for fname, cur_fc in tqdm.tqdm(outputs, total=len(fnames)): # The projection information may be missing if there are no valid patches. if "crs" not in cur_fc["properties"]: # Just do some sanity checks, there should be no features and no valid diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index 9340d343..244ca6ef 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -3,6 +3,7 @@ import json import os import shutil +import tempfile from datetime import datetime from enum import Enum from typing import Any @@ -202,15 +203,26 @@ def predict_pipeline( # Populate the windows. logger.info("materialize dataset") - apply_windows_args = ApplyWindowsArgs(group=group, workers=1) materialize_pipeline_args = MaterializePipelineArgs( disabled_layers=[], - prepare_args=PrepareArgs(apply_windows_args=apply_windows_args), + # Use initial job for prepare since it involves locally caching the tile index + # and other steps that should only be performed once. + prepare_args=PrepareArgs( + apply_windows_args=ApplyWindowsArgs( + group=group, workers=32, use_initial_job=True + ) + ), ingest_args=IngestArgs( - ignore_errors=False, apply_windows_args=apply_windows_args + ignore_errors=False, + apply_windows_args=ApplyWindowsArgs( + group=group, workers=32, use_initial_job=False + ), ), materialize_args=MaterializeArgs( - ignore_errors=False, apply_windows_args=apply_windows_args + ignore_errors=False, + apply_windows_args=ApplyWindowsArgs( + group=group, workers=32, use_initial_job=False + ), ), ) materialize_dataset(ds_path, materialize_pipeline_args=materialize_pipeline_args) @@ -341,17 +353,15 @@ def predict_multi( scratch_path: local directory to use for scratch space. tasks: list of tasks to execute. """ - if os.path.exists(scratch_path): - shutil.rmtree(scratch_path) - + os.makedirs(scratch_path, exist_ok=True) for task in tasks: - predict_pipeline( - application=application, - projection_json=json.dumps(task.projection_json), - bounds=task.bounds, - time_range=task.time_range, - out_path=out_path, - scratch_path=scratch_path, - ) - if os.path.exists(scratch_path): - shutil.rmtree(scratch_path) + with tempfile.TemporaryDirectory(dir=scratch_path) as tmp_dir: + logger.info(f"running task {task} in temporary directory {tmp_dir}") + predict_pipeline( + application=application, + projection_json=json.dumps(task.projection_json), + bounds=task.bounds, + time_range=task.time_range, + out_path=out_path, + scratch_path=os.path.join(tmp_dir, "scratch"), + ) From 1efd54410c844c4c3ce41e67ecf63195959e5ee3 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 30 Jan 2025 23:58:41 -0500 Subject: [PATCH 38/58] add solar farm config --- data/satlas/solar_farm/config.yaml | 207 +++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 data/satlas/solar_farm/config.yaml diff --git a/data/satlas/solar_farm/config.yaml b/data/satlas/solar_farm/config.yaml new file mode 100644 index 00000000..cbc6c71c --- /dev/null +++ b/data/satlas/solar_farm/config.yaml @@ -0,0 +1,207 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + decoders: + segment: + - class_path: rslearn.models.unet.UNetDecoder + init_args: + in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]] + out_channels: 2 + conv_layers_per_resolution: 2 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 2 + metric_kwargs: + average: "micro" + remap_values: [[0, 1], [0, 255]] + input_mapping: + segment: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 256 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image", "target/segment/classes", "target/segment/valid"] + tags: + split: train + val_config: + patch_size: 256 + tags: + split: val + test_config: + patch_size: 256 + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_segment/accuracy + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_solar_farm +rslp_experiment: data_20250108_satlaspretrain_patch256_00 From 9a3d7576b17a796e79d96f5a3118eba48459b9d4 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 6 Feb 2025 08:30:20 -0800 Subject: [PATCH 39/58] add documentation about viterbi smoothing step --- rslp/satlas/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md index 81bdeb97..5fa38b78 100644 --- a/rslp/satlas/README.md +++ b/rslp/satlas/README.md @@ -90,6 +90,8 @@ Here: `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/2024-07.geojson`. Second, smooth the points across timesteps. This runs a Viterbi smoothing operation. +Note that the Viterbi smoothing is implemented in a separate Go application at +`rslp/satlas/scripts/smooth_point_labels_viterbi.go`. python -m rslp.main satlas smooth_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ From ec95814eabcd0e730694d69d668f5731075bbd33 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 6 Feb 2025 08:43:28 -0800 Subject: [PATCH 40/58] add documentation --- rslp/satlas/README.md | 4 +++- rslp/satlas/postprocess.py | 32 ++++++++++++++++++++++++++++++-- rslp/satlas/predict_pipeline.py | 1 + 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md index 5fa38b78..02766d68 100644 --- a/rslp/satlas/README.md +++ b/rslp/satlas/README.md @@ -95,7 +95,9 @@ Note that the Viterbi smoothing is implemented in a separate Go application at python -m rslp.main satlas smooth_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ -Finally, publish the outputs to Cloudflare R2. +Finally, publish the outputs to Cloudflare R2. This requires +[tippecanoe](https://github.com/mapbox/tippecanoe) since it is used to generate the +vector tiles. python -m rslp.main satlas publish_points MARINE_INFRA gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ 'marine-default-cluster@v4' diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 172404e5..170f46e3 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -20,13 +20,18 @@ from .predict_pipeline import Application # Approximate maximum meters in one degree latitude/longitude. +# Above/below the equator there will be fewer meters. +# This is used to compute the NMS_DISTANCE_THRESHOLD below. MAX_METERS_PER_DEGREE = 111111 # Threshold on Euclidean distance between lat/lon for NMS. # We just do Euclidean distance for speed/simplicity since NMS doesn't need to be super -# exact. +# exact (instead of spherical distance). NMS_DISTANCE_THRESHOLD = 100 / MAX_METERS_PER_DEGREE +# Individual Satlas applications use different category names than the global ones that +# we want to serve. We should adjust this but for now this map helps to rename the +# categories. APP_CATEGORY_MAPS = { Application.MARINE_INFRA: { "platform": "offshore_platform", @@ -61,6 +66,12 @@ def apply_nms( ) -> list[dict[str, Any]]: """Apply non-maximum suppression over the points. + Although we run NMS inside the object detector, we need to run a global NMS again + because there two levels where we are dividing into patches -- at the global level, + where we start different prediction tasks for every 32768x32768 patch, and again + within the tasks, where we process each 2048x2048 sub-patch. So there can be + redundant detections across these boundaries. + Args: features: the list of JSON Feature objects. distance_threshold: the distance threshold to match points. @@ -77,6 +88,9 @@ def apply_nms( box = (coordinates[0], coordinates[1], coordinates[0], coordinates[1]) grid_index.insert(box, idx) + # Now we iterate over the features and use the index to identify other features + # that are nearby. If the other feature has a higher score then we delete the + # feature. good_features = [] for idx, feat in enumerate(features): coordinates = feat["geometry"]["coordinates"] @@ -141,8 +155,13 @@ def merge_points( # Get category remapping in case one is specified for this application. category_map = APP_CATEGORY_MAPS.get(application, {}) + # Iterate over each of the files produced by a prediction task. + # We merge both the predicted points along with the valid patches (patches + # processed by the task that had available input images). for fname, cur_fc in tqdm.tqdm(outputs, total=len(fnames)): # The projection information may be missing if there are no valid patches. + # In that case we can skip the file since it has neither valid patches that we + # need to track nor any predicted points. if "crs" not in cur_fc["properties"]: # Just do some sanity checks, there should be no features and no valid # patches. @@ -243,9 +262,11 @@ def smooth_points( shutil.copyfileobj(src, dst) labels.append(label) - # Sort by YYYY-MM. + # Sort by YYYY-MM, since the smoothing function expects us to provide all of + # the labels in temporal order. labels.sort() + # Smoothing is handled by a Go script. subprocess.check_call( [ "rslp/satlas/scripts/smooth_point_labels_viterbi", @@ -260,6 +281,7 @@ def smooth_points( ], ) # nosec + # Now we can upload the smoothed per-timestep files. for label in labels: src_path = tmp_smoothed_dir / f"{label}.geojson" dst_path = smoothed_upath / f"{label}.geojson" @@ -267,6 +289,12 @@ def smooth_points( with dst_path.open("wb") as dst: shutil.copyfileobj(src, dst) + # The smoothing also produces a history GeoJSON containing all of the points + # annotated with start/end properties indicating the first and last timesteps + # when the point was detected. (In this case, points detected over time are + # merged into a single GeoJSON feature.) So we upload that too. + # This history file is the one that used to create vector tiles for the web + # application. dst_path = smoothed_upath / "history.geojson" with tmp_hist_fname.open("rb") as src: with dst_path.open("wb") as dst: diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index 244ca6ef..1f8e2b56 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -236,6 +236,7 @@ def predict_pipeline( else: run_model_predict(model_config_fname, ds_path) + # Merge and upload the outputs. if APP_IS_RASTER[application]: src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" From ead3e225a6f5b19f2a2fd47b156d0c44ab8328ec Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 6 Feb 2025 14:39:53 -0800 Subject: [PATCH 41/58] remove unused launch_worker --- rslp/common/worker.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/rslp/common/worker.py b/rslp/common/worker.py index 49c1792f..008e95a9 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -197,15 +197,6 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: streaming_pull_future.result() -def launch_worker(project_id: str, subscription_id: str) -> None: - """Launch a worker job. - - Args: - project_id: the Google Cloud project ID. - subscription_id: the Pub/Sub subscription ID. - """ - - def launch_workers( image_name: str, project_id: str, From a71f3f69b052303599d2b521e4c85995572ff926 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 09:40:52 -0800 Subject: [PATCH 42/58] add test for bkt --- rslp/satlas/bkt.py | 324 +++++++++++++++++++-------- tests/integration/satlas/__init__.py | 0 tests/integration/satlas/test_bkt.py | 48 ++++ 3 files changed, 277 insertions(+), 95 deletions(-) create mode 100644 tests/integration/satlas/__init__.py create mode 100644 tests/integration/satlas/test_bkt.py diff --git a/rslp/satlas/bkt.py b/rslp/satlas/bkt.py index e2784457..2de5c6d5 100644 --- a/rslp/satlas/bkt.py +++ b/rslp/satlas/bkt.py @@ -5,9 +5,20 @@ This is similar to https://github.com/mactrem/com-tiles. -The .bkt is just a concatenation of the small files. +The .bkt is just a concatenation of the small files, which here we call items. This is +specialized for storing tiles, so each item falls on a grid at a particular zoom level +and is associated with a column and row. The bkt itself is on a grid at a zoom level +equal to or lower than that of its items, i.e., a coarser grid. -We record the byte offsets in a Google Cloud Bigtable database. +The bkt should contain every item (on the finer grid) that is contained within its tile +(on the coarser grid); if an item is missing, that should mean that there is just no +data there. + +Then there will actually be a set of bkt files at the coarser zoom level to cover the +entire region. + +We record the columns, rows, and byte offsets of items in a Google Cloud Bigtable +database. Readers can first query Bigtable, then make a range read request to GCS. """ import functools @@ -17,6 +28,8 @@ import struct import time from collections.abc import Generator +from dataclasses import dataclass +from enum import StrEnum from typing import Any import google.cloud.bigtable.row @@ -31,31 +44,64 @@ logger = get_logger(__name__) +# Number of retries for reading from BigTable, since sometimes there are transient +# errors. +BIGTABLE_RETRIES = 8 + + +@dataclass +class BktItemMetadata: + """Metadata about an item (small file) stored within the bkt.""" + + # The column and row of the item on the item (fine-grained) grid. + col: int + row: int + + # The byte offset and length of the item within the concatenated bkt file. + offset: int + length: int + + def pack(self) -> bytes: + """Pack the metadata into bytes.""" + return struct.pack(">IIII", self.col, self.row, self.offset, self.length) + + @staticmethod + def unpack(b: bytes) -> "BktItemMetadata": + """Unpack a BktItemMetadata from a 16-byte string.""" + col, row, offset, length = struct.unpack(">IIII", b) + return BktItemMetadata(col, row, offset, length) + class BktInserter: - """A helper class that inserts metadata about bkt files into the database. + """A helper class that stores metadata about bkt files so it can be inserted later. The BktInserter is a separate class from BktWriter so that it can be pickled to support use with multiprocessing. + + Normal usage is to create BktWriter in parallel, call get_inserter to get the + BktInserter objects, then collect them in the main thread and finally call + BktInserter.run. This way the worker threads do not need to make additional + connections to Bigtable, which really becomes a problem. """ def __init__( self, - indexes: list[tuple[int, int, int, int]], + item_metadatas: list[BktItemMetadata], bkt_fname: str, bkt_zoom: int, zoom: int, ): """Create a new BktInserter. + It stores information that will be needed to insert into Bigtable. + Args: - indexes: the byte offsets of the files within the bkt. It is a list of - (col, row, offset, length) tuples. + item_metadatas: metadata about items within the bkt. bkt_fname: the filename where the bkt will be written. bkt_zoom: the zoom level of the bkt. zoom: the zoom level of the tiles within the bkt. """ - self.indexes = indexes + self.item_metadatas = item_metadatas self.bkt_fname = bkt_fname self.bkt_zoom = bkt_zoom self.zoom = zoom @@ -71,19 +117,46 @@ def run(self, bkt_files_table: google.cloud.bigtable.table.Table) -> None: # [indexes] is list of indexes encoded as [4 byte col][4 byte row][4 byte offset][4 byte length]. buf = io.BytesIO() buf.write(struct.pack(">II", self.bkt_zoom, self.zoom)) - for col, row, offset, length in self.indexes: - buf.write(struct.pack(">IIII", col, row, offset, length)) + for item_metadata in self.item_metadatas: + buf.write(item_metadata.pack()) db_row = bkt_files_table.direct_row(self.bkt_fname) db_row.set_cell(b"d", b"d", buf.getvalue()) db_row.commit() class BktWriter: - """Writer for bkt files.""" + """Writer for bkt files. + + Call add to add one item at a time. Then call get_bytes and write the data to GCS. + Finally call insert (or use get_inserter and then pass to main thread and call run + on the BktInserter object). + + Callers must write to GCS before Bigtable so that when clients read from Bigtable + they can expect the files to already be written. + + upload_bkts can help with inserting multiple BktWriters, e.g.: + + bkt_writers = {} + for col, row in item_tiles: + # bkt_factor is 2^(item zoom - bkt zoom), i.e. the difference in scale + # between the item grid and the bkt grid. + bkt_tile = (col//bkt_factor, row//bkt_factor) + if bkt_tile not in bkt_writers: + bkt_writers[bkt_tile] = bkt.BktWriter() + contents = ... + bkt_writers[bkt_tile].add(col, row, contents) + + bkt_uploads = [] + for bkt_tile, bkt_writer in bkt_writers.items(): + out_fname = '.../{}/{}/{}.bkt'.format(bkt_zoom, bkt_tile[0], out_tile[1]) + bkt_uploads.append((bkt_writer, out_fname, args.out_zoom, args.zoom)) + # p is a multiprocessing.Pool. + bkt.upload_bkts(bkt_files_table, p, bkt_uploads) + """ def __init__(self) -> None: """Create a new BktWriter.""" - self.indexes: list[tuple[int, int, int, int]] = [] + self.item_metadatas: list[BktItemMetadata] = [] self.buf = io.BytesIO() self.offset = 0 @@ -97,12 +170,15 @@ def add(self, col: int, row: int, bytes: bytes) -> None: """ offset = self.offset length = len(bytes) - self.indexes.append((col, row, offset, length)) + self.item_metadatas.append(BktItemMetadata(col, row, offset, length)) self.buf.write(bytes) self.offset += length def get_bytes(self) -> bytes: - """Returns the bytes of the whole bkt file.""" + """Returns the bytes of the whole bkt file. + + This is what should be uploaded to GCS. + """ return self.buf.getvalue() def get_inserter(self, bkt_fname: str, bkt_zoom: int, zoom: int) -> "BktInserter": @@ -116,7 +192,7 @@ def get_inserter(self, bkt_fname: str, bkt_zoom: int, zoom: int) -> "BktInserter Returns: a corresponding BktInserter """ - return BktInserter(self.indexes, bkt_fname, bkt_zoom, zoom) + return BktInserter(self.item_metadatas, bkt_fname, bkt_zoom, zoom) def insert( self, @@ -138,154 +214,213 @@ def insert( @functools.cache def get_bucket() -> storage.Bucket: - """Get the GCS bucket where bkt files should be stored.""" + """Get the GCS bucket where bkt files should be stored. + + This comes from the environment variables: + - BKT_PROJECT_ID: GCP project + - BKT_BUCKET_NAME: the GCS bucket within that project. + """ storage_client = storage.Client(project=os.environ["BKT_PROJECT_ID"]) bucket = storage_client.bucket(os.environ["BKT_BUCKET_NAME"]) return bucket -def download_bkt( +@functools.cache +def get_bigtable() -> google.cloud.bigtable.table.Table: + """Get the BigTable table storing bkt metadata.""" + bigtable_client = bigtable.Client(project=os.environ["BKT_BIGTABLE_PROJECT_ID"]) + bigtable_instance = bigtable_client.instance(os.environ["BKT_BIGTABLE_INSTANCE_ID"]) + bkt_files_table = bigtable_instance.table("bkt_files") + return bkt_files_table + + +class DecodeMode(StrEnum): + """Mode indicating how items should be decoded when downloading in parallel. + + This is used in functions like download_bkts so that the worker processes can + handle decoding rather than the caller needing to decode in the main thread. + """ + + # Decode it from image bytes to numpy array. + IMAGE = "image" + + # Yield the bytes directly. + RAW = "raw" + + +@dataclass +class BktDownloadRequest: + """A request to download an item in a bkt file to pass to download_bkts.""" + + # Name of bkt file for this job. + bkt_fname: str + + # Column and row on item (fine-grained) grid to read. + col: int + row: int + + # Arbitrary metadata for use by caller. + # It will be returned with the decoded item data. + metadata: Any = None + + +def _download_bkt( bkt_fname: str, - idx_map: dict[tuple[int, int], tuple[int, int]], - wanted: list[tuple[int, int, Any]], - mode: str, + item_metadatas: list[BktItemMetadata], + wanted: list[BktDownloadRequest], + decode_mode: DecodeMode, ) -> list[tuple[Any, npt.NDArray | bytes]]: """Download tiles in a bkt file. Args: bkt_fname: the bkt filename in the bucket to download from. - idx_map: map from tile (col, row) to (offset, length). - wanted: list of tiles to download. It should be a list of (col, row, metadata) - where metadata can be arbitrary data used by the caller that will be - returned with the tile data (which will be emitted in arbitrary order). - Note that if a tile does not exist within the bkt, it will not be returned - at all. - mode: either "image" to decode image and return numpy array, or "raw" to return - the byte string directly. + item_metadatas: the item metadatas for this bkt file. + wanted: list of BktDownloadRequest that specify the tiles to download. Note + that if a tile does not exist within the bkt, it will not be returned at + all. + decode_mode: how the items should be decoded. Returns: - a list of (metadata, contents) where contents is a numpy array if mode is - "image" or a byte string if mode is "raw". + a list of (metadata, contents) where contents is a numpy array with + DecodeMode.IMAGE or a byte string with DecodeMode.RAW. """ bucket = get_bucket() output = [] # Helper to postprocess an output based on the specified return mode. def add_output(metadata: Any, contents: npt.NDArray | bytes) -> None: - if mode == "image": + if decode_mode == DecodeMode.IMAGE: buf = io.BytesIO(contents) image = skimage.io.imread(buf) output.append((metadata, image)) - elif mode == "raw": + elif decode_mode == DecodeMode.RAW: output.append((metadata, contents)) else: - raise ValueError(f"invalid mode {mode}") + raise ValueError(f"invalid decode mode {decode_mode}") + + # Convert item_metadatas to a map from (col, row) -> (offset, length). + idx_map = {(m.col, m.row): (m.offset, m.length) for m in item_metadatas} - wanted = [ - (col, row, metadata) for col, row, metadata in wanted if (col, row) in idx_map - ] + # Filter for just the requested tiles that actually exist. + # The caller should assume that tiles that are not returned simply don't have data. + wanted = [request for request in wanted if (request.col, request.row) in idx_map] if len(wanted) == 1: - col, row, metadata = wanted[0] - offset, length = idx_map[(col, row)] + # If there is just one requested item within this bkt, we can do a range read + # to read only that item. + request = wanted[0] + offset, length = idx_map[(request.col, request.row)] blob = bucket.blob(bkt_fname) contents = blob.download_as_bytes(start=offset, end=offset + length) - add_output(metadata, contents) + add_output(request.metadata, contents) elif len(wanted) > 1: + # Otherwise, we read the entire bkt file and then extract the segments + # corresponding to the requested items. blob = bucket.blob(bkt_fname) bkt_bytes = blob.download_as_bytes() - for col, row, metadata in wanted: - offset, length = idx_map[(col, row)] + for request in wanted: + offset, length = idx_map[(request.col, request.row)] contents = bkt_bytes[offset : offset + length] - add_output(metadata, contents) + add_output(request.metadata, contents) + # We return a list of (metadata, contents) from this bkt file. + # In download_from_bkt, it will combine these tuples across all of the bkt files. return output -# Parallel download from various bkt files. -# Jobs is a list of (bkt_fname, col, row, metadata). -# download_from_bkt is a generator that will yield (metadata, bytes) for each provided job. +def _bkt_retry_loop( + bkt_files_table: google.cloud.bigtable.table.Table, bkt_fname: str +) -> google.cloud.bigtable.row.PartialRowData: + """Retry loop to read the bkt_fname metadata from BigTable. + + This is used because sometimes there are transient errors reading. + """ + + def attempt_read() -> google.cloud.bigtable.row.PartialRowData: + return bkt_files_table.read_row( + bkt_fname, + filter_=google.cloud.bigtable.row_filters.CellsColumnLimitFilter(1), + ) + + for _ in range(BIGTABLE_RETRIES): + try: + return attempt_read() + except Exception as e: + logger.warning( + f"got error reading bkt_files_table for {bkt_fname} (trying again): {e}" + ) + time.sleep(1) + + # One last read, if it fails then we let the exception go. + return attempt_read() + + def download_from_bkt( bkt_files_table: google.cloud.bigtable.table.Table, - pool: multiprocessing.pool.Pool | None, - jobs: list[tuple[str, int, int, Any]], - mode: str = "raw", + download_requests: list[BktDownloadRequest], + pool: multiprocessing.pool.Pool | None = None, + decode_mode: DecodeMode = DecodeMode.RAW, ) -> Generator[tuple[Any, npt.NDArray | bytes], None, None]: """Download tile contents in parallel from several bkt files. Args: bkt_files_table: the BigTable table containing byte offsets. + download_requests: list of BktDownloadRequest indicating the bkt filenames to + read from along with the item tiles to read. Download requests from the + same bkt_fname will be grouped together so we don't read the same bkt + file multiple times. pool: the multiprocessing pool to use for parallelism, or None to read in current process. - jobs: list of (bkt_fname, col, row, metadata) to work through. Jobs referencing - the same bkt_fname will be grouped together so we don't read the same bkt - file multiple times. - mode: the return mode (see download_bkt). + decode_mode: how the items should be decoded. Yields: - the (metadata, contents) tuples across all of the jobs. + the (metadata, contents) tuples across all of the jobs. Only items that exist + in the bkt files will be returned; if non-existing items are requested, + they would be skipped and caller should assume they have no data. """ - # Get indexes associated with each distinct bkt_fname. - by_bkt_fname: dict[str, list[tuple[int, int, Any]]] = {} - for bkt_fname, col, row, metadata in jobs: - if bkt_fname not in by_bkt_fname: - by_bkt_fname[bkt_fname] = [] - by_bkt_fname[bkt_fname].append((col, row, metadata)) - + # Get tiles to read grouped by each distinct bkt_fname. + requests_by_bkt_fname: dict[str, list[BktDownloadRequest]] = {} + for request in download_requests: + if request.bkt_fname not in requests_by_bkt_fname: + requests_by_bkt_fname[request.bkt_fname] = [] + requests_by_bkt_fname[request.bkt_fname].append(request) + + # Read from BigTable to identify the offset and length of each requested + # (col, row) item. + # We use this to populate a list of jobs (arguments to pass to _download_bkt + # helper). bkt_jobs: list[dict[str, Any]] = [] - for bkt_fname, wanted in by_bkt_fname.items(): - # Use retry loop since we seem to get error reading from BigTable occasionally. - def bkt_retry_loop() -> google.cloud.bigtable.row.PartialRowData: - for _ in range(8): - try: - db_row = bkt_files_table.read_row( - bkt_fname, - filter_=google.cloud.bigtable.row_filters.CellsColumnLimitFilter( - 1 - ), - ) - return db_row - except Exception as e: - print( - f"got error reading bkt_files_table for {bkt_fname} (trying again): {e}" - ) - time.sleep(1) - raise Exception( - f"repeatedly failed to read bkt_files_table for {bkt_fname}" - ) - - db_row = bkt_retry_loop() + for bkt_fname, requests in requests_by_bkt_fname.items(): + db_row = _bkt_retry_loop(bkt_files_table, bkt_fname) # Ignore requested files that don't exist. if not db_row: continue + # Skip 8-byte header with bkt_zoom/zoom. encoded_indexes = db_row.cell_value("d", b"d")[8:] - indexes = {} + item_metadatas = [] for i in range(0, len(encoded_indexes), 16): - col, row, offset, length = struct.unpack( - ">IIII", encoded_indexes[i : i + 16] - ) - indexes[(col, row)] = (offset, length) + item_metadatas.append(BktItemMetadata.unpack(encoded_indexes[i : i + 16])) bkt_jobs.append( dict( bkt_fname=bkt_fname, - idx_map=indexes, - wanted=wanted, - mode=mode, + item_metadatas=item_metadatas, + wanted=requests, + decode_mode=decode_mode, ) ) if pool is None: for job in bkt_jobs: - for metadata, image in download_bkt(**job): + for metadata, image in _download_bkt(**job): yield (metadata, image) else: - outputs = star_imap_unordered(pool, download_bkt, bkt_jobs) + outputs = star_imap_unordered(pool, _download_bkt, bkt_jobs) for output in outputs: for metadata, image in output: yield (metadata, image) @@ -294,6 +429,8 @@ def bkt_retry_loop() -> google.cloud.bigtable.row.PartialRowData: def upload_bkt(bkt_fname: str, contents: bytes) -> None: """Upload a bkt file to GCS bucket. + This is primarily intended to be used as a helper function for multiprocessing. + Args: bkt_fname: the bkt filename within the bucket. contents: the data to upload. @@ -303,7 +440,6 @@ def upload_bkt(bkt_fname: str, contents: bytes) -> None: blob.upload_from_string(contents) -# Tuples is list of (bkt_writer, bkt_fname, bkt_zoom, zoom). def upload_bkts( bkt_files_table: google.cloud.bigtable.table.Table, pool: multiprocessing.pool.Pool, @@ -351,9 +487,7 @@ def make_bkt(src_dir: str, dst_path: str) -> None: {zoom} placeholder where the zoom goes. """ bucket = get_bucket() - bigtable_client = bigtable.Client(project=os.environ["BKT_BIGTABLE_PROJECT_ID"]) - bigtable_instance = bigtable_client.instance(os.environ["BKT_BIGTABLE_INSTANCE_ID"]) - bkt_files_table = bigtable_instance.table("bkt_files") + bkt_files_table = get_bigtable() for zoom_str in os.listdir(src_dir): zoom_dir = os.path.join(src_dir, zoom_str) diff --git a/tests/integration/satlas/__init__.py b/tests/integration/satlas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/satlas/test_bkt.py b/tests/integration/satlas/test_bkt.py new file mode 100644 index 00000000..c790ade5 --- /dev/null +++ b/tests/integration/satlas/test_bkt.py @@ -0,0 +1,48 @@ +"""Test bkt file operations.""" + +import pathlib + +from rslp.satlas.bkt import ( + BktDownloadRequest, + download_from_bkt, + get_bigtable, + make_bkt, +) + + +def test_make_and_download_bkt(tmp_path: pathlib.Path) -> None: + """Test making and downloading a bkt file. + + We create two files and make a bkt from them. + + Then we try to use download_from_bkt to download both of those files and verify + that it is returned correctly. + """ + + # make_bkt expects a directory structure zoom/col/row. + zoom = "1" + data_col0_row0 = b"bkt1" + data_col0_row1 = b"bkt2" + bkt_fname = "tests/test_make_and_download_bkt" + (tmp_path / zoom / "0").mkdir(parents=True) + with (tmp_path / zoom / "0" / "0").open("wb") as f: + f.write(data_col0_row0) + with (tmp_path / zoom / "0" / "1").open("wb") as f: + f.write(data_col0_row1) + make_bkt(str(tmp_path), bkt_fname) + + # Now try to download both of those files. + # We use the metadata to store the expected contents. + # We also read an extra tile that shouldn't exist. + download_requests = [ + BktDownloadRequest(bkt_fname, 0, 0, metadata=data_col0_row0), + BktDownloadRequest(bkt_fname, 0, 1, metadata=data_col0_row1), + # This one should not exist. + BktDownloadRequest(bkt_fname, 1, 1), + ] + bkt_files_table = get_bigtable() + # Call download_from_bkt and populate into list so we can check the length. + results = list(download_from_bkt(bkt_files_table, download_requests)) + assert len(results) == 2 + for expected, actual in results: + assert expected == actual From 1eb9d78b6b1467150a5f587a511ad64e95a8d3fe Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 10:20:53 -0800 Subject: [PATCH 43/58] add test for rslp/common/worker.py --- .github/workflows/build_test.yaml | 7 ++++++ ai2_docs/README.md | 17 +++---------- rslp/common/__init__.py | 2 ++ rslp/common/worker.py | 41 +++++++++++++++++++++++++++++++ rslp/satlas/bkt.py | 4 +-- rslp/satlas/write_jobs.py | 23 +++-------------- 6 files changed, 59 insertions(+), 35 deletions(-) diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index ca7894ac..f7cdccc7 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -120,6 +120,13 @@ jobs: -e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcp-credentials.json \ -e RSLP_BUCKET=rslearn-eai \ -e RSLP_PREFIX=gs://rslearn-eai \ + -e BKT_PROJECT_ID=ai2-prior-satlas \ + -e BKT_BUCKET_NAME=satlas-explorer-data \ + -e BKT_BIGTABLE_PROJECT_ID=ai2-prior-satlas \ + -e BKT_BIGTABLE_INSTANCE_ID=satlas \ + -e TEST_PUBSUB_PROJECT=earthsystem-dev-c3po \ + -e TEST_PUBSUB_TOPIC=rslearn_projects_test_topic \ + -e TEST_PUBSUB_SUBSCRIPTION=rslearn_projects_test_subscription \ test pytest tests/integration/ --ignore tests/integration_slow/ -vv diff --git a/ai2_docs/README.md b/ai2_docs/README.md index 563fe89f..2d69cf92 100644 --- a/ai2_docs/README.md +++ b/ai2_docs/README.md @@ -33,6 +33,10 @@ rslearn datasets, model checkpoints, etc. The easiest way is to create a `.env` RSLP_PREFIX=gs://rslearn-eai RSLP_WEKA_PREFIX=weka://dfive-default/rslearn-eai + BKT_PROJECT_ID=ai2-prior-satlas + BKT_BUCKET_NAME=satlas-explorer-data + BKT_BIGTABLE_PROJECT_ID=ai2-prior-satlas + BKT_BIGTABLE_INSTANCE_ID=satlas You will also need to setup GCP credentials that have access to this bucket. @@ -42,19 +46,6 @@ launcher like `rslp.launch_beaker`, then it isn't needed since the credentials a already configured as secrets on the platform, but you would need to setup your Beaker or other platform credentials to be able to launch the jobs. -TODO: update GCP/W&B to use service accounts. - -Currently, until https://github.com/allenai/rslearn/issues/33 is resolved, model config -files use S3-compatable API to access GCS rather than GCS directly. Therefore, you need -to set up environment variables to provide the appropriate credentials: - - S3_ACCESS_KEY_ID=GOOG... - S3_SECRET_ACCESS_KEY=... - -You can create these credentials at -https://console.cloud.google.com/storage/settings;tab=interoperability?hl=en&project=skylight-proto-1 -under "Access keys for your user account". - Usage ----- diff --git a/rslp/common/__init__.py b/rslp/common/__init__.py index d3b0fdc9..f94dc097 100644 --- a/rslp/common/__init__.py +++ b/rslp/common/__init__.py @@ -1,8 +1,10 @@ """Pipelines common across projects.""" from .worker import launch_workers, worker_pipeline +from .write_file import write_file workflows = { "worker": worker_pipeline, "launch": launch_workers, + "write_file": write_file, } diff --git a/rslp/common/worker.py b/rslp/common/worker.py index 008e95a9..05623246 100644 --- a/rslp/common/worker.py +++ b/rslp/common/worker.py @@ -39,6 +39,12 @@ # Scratch directory that jobs can use and it will be managed by this module. SCRATCH_DIRECTORY = "/tmp/scratch" +# Directory to store SCRATCH_DIRECTORY (via symlink) in case +# manage_scratch_dir_on_data_disk is used. +# This is because some Beaker machines have much bigger /data disk than what's +# available for ephemeral storage within the Docker container, so we need to use that +# for disk-intensive tasks to avoid running out of disk space. But we also need to make +# sure we delete everything we wrote, so worker.py manages the folder. DATA_DISK = "/data/rslearn_projects" @@ -150,8 +156,11 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: is_processing = False last_message_time = time.time() + # We limit to a single message at a time and a single worker since tasks should use + # all of the available CPU/GPU resources. flow_control = pubsub_v1.types.FlowControl( max_messages=1, + # Tasks may take several hours so we allow extending the lease for up to a day. max_lease_duration=24 * 3600, ) executor = futures.ThreadPoolExecutor(max_workers=1) @@ -163,6 +172,9 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: scheduler=scheduler, ) logger.info("worker listening for messages on %s", subscription_path) + # We use the loop below to make the worker exit if there are no more messages (by + # guessing based on the provided idle timeout). We also need to exit if there have + # been too many consecutive errors. try: while True: time.sleep(idle_timeout) @@ -193,6 +205,7 @@ def callback(message: pubsub_v1.subscriber.message.Message) -> None: logger.info("worker exiting due to idle timeout") break finally: + # Exit the worker process. streaming_pull_future.cancel() streaming_pull_future.result() @@ -260,3 +273,31 @@ def launch_workers( ) unique_id = str(uuid.uuid4())[0:8] beaker.experiment.create(f"worker_{unique_id}", spec) + + +def write_jobs( + project_id: str, + topic_id: str, + rslp_project: str, + rslp_workflow: str, + args_list: list[list[str]], +) -> None: + """Write tasks to the PubSub topic. + + Args: + project_id: the project ID for the PubSub topic. + topic_id: the topic ID for the PubSub topic. + rslp_project: the rslp project to run. + rslp_workflow: the workflow in the project to run. + args_list: list of arguments fo reach task. + """ + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(project_id, topic_id) + for args in tqdm.tqdm(args_list, desc="Writing jobs to Pub/Sub topic"): + json_data = dict( + project=rslp_project, + workflow=rslp_workflow, + args=args, + ) + data = json.dumps(json_data).encode() + publisher.publish(topic_path, data).result() diff --git a/rslp/satlas/bkt.py b/rslp/satlas/bkt.py index 2de5c6d5..b33f8977 100644 --- a/rslp/satlas/bkt.py +++ b/rslp/satlas/bkt.py @@ -29,7 +29,7 @@ import time from collections.abc import Generator from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from typing import Any import google.cloud.bigtable.row @@ -234,7 +234,7 @@ def get_bigtable() -> google.cloud.bigtable.table.Table: return bkt_files_table -class DecodeMode(StrEnum): +class DecodeMode(str, Enum): """Mode indicating how items should be decoded when downloading in parallel. This is used in functions like download_bkts so that the worker processes can diff --git a/rslp/satlas/write_jobs.py b/rslp/satlas/write_jobs.py index 245565a1..6ea8bd69 100644 --- a/rslp/satlas/write_jobs.py +++ b/rslp/satlas/write_jobs.py @@ -7,13 +7,13 @@ import shapely import tqdm -from google.cloud import pubsub_v1 from rasterio.crs import CRS from rslearn.const import WGS84_PROJECTION from rslearn.utils.geometry import PixelBounds, Projection, STGeometry from rslearn.utils.get_utm_ups_crs import get_proj_bounds from upath import UPath +import rslp.common.worker from rslp.log_utils import get_logger from .predict_pipeline import Application, PredictTaskArgs, get_output_fname @@ -220,23 +220,6 @@ def get_jobs( return jobs -def _write_jobs_to_topic( - jobs: list[list[str]], - project_id: str, - topic_id: str, -) -> None: - publisher = pubsub_v1.PublisherClient() - topic_path = publisher.topic_path(project_id, topic_id) - for job in tqdm.tqdm(jobs, desc="Writing jobs to Pub/Sub topic"): - json_data = dict( - project="satlas", - workflow="predict_multi", - args=job, - ) - data = json.dumps(json_data).encode() - publisher.publish(topic_path, data).result() - - def write_jobs( application: Application, time_range: tuple[datetime, datetime], @@ -271,7 +254,7 @@ def write_jobs( batch_size=batch_size, count=count, ) - _write_jobs_to_topic(jobs, project_id, topic_id) + rslp.common.worker.write_jobs(project_id, topic_id, "satlas", "predict_multi", jobs) def write_jobs_for_year_months( @@ -321,4 +304,4 @@ def write_jobs_for_year_months( jobs.extend(cur_jobs) logger.info("got a total of %d jobs across year-months", len(jobs)) - _write_jobs_to_topic(jobs, project_id, topic_id) + rslp.common.worker.write_jobs(project_id, topic_id, "satlas", "predict_multi", jobs) From c4af2a52fedb283701c43b3f9b4b2e765e1b66a5 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 10:31:26 -0800 Subject: [PATCH 44/58] add doc string --- rslp/satlas/data_sources.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py index e4f43bf8..a20ab3d2 100644 --- a/rslp/satlas/data_sources.py +++ b/rslp/satlas/data_sources.py @@ -20,12 +20,35 @@ def _find_monthly_matches( geometry: STGeometry, item_list: list[Item], period_days: int, max_matches: int ) -> list[list[Item]]: - # Find matches across the periods. - # For each period, we create an STGeometry with modified time range - # matching the period, and obtain matching mosaic. - # We start from the end of the time range because we care more about recent - # periods and so we want to make sure that they align correctly with the - # end. + """Match items to the geometry with one mosaic per period. + + We divide the time range of the geometry into shorter periods. Within each period, + we use the items corresponding to that period to create a mosaic. The returned item + groups include one group per period, starting from the most recent periods, up to + the provided max_matches. + + This is used e.g. when a model should process three mosaics, where each mosaic + should come from a different month. This gives more diversity of images, since + simply searching for the least cloudy images could result in selecting all of the + images from the same month. + + max_matches may be smaller than the total number of periods in the given time + range. In this case, we prefer to use mosaics of the most recent periods. However, + sometimes there may be no items in a period; in that case, the older periods are + used as a fallback. + + Args: + geometry: the window geometry to match items to. + item_list: the list of items. + period_days: the length of one period in days. + max_matches: the number of per-period mosaics to create. + + Returns: + the matched item groups, where each group contains items that yield a + per-period mosaic. + """ + # For each period, we create an STGeometry with modified time range matching that + # period, and use it with match_candidate_items_to_window to get a mosaic. cur_groups: list[list[Item]] = [] period_end = geometry.time_range[1] while period_end > geometry.time_range[0] and len(cur_groups) < max_matches: From df45d4ca03ff5f9b83d3b0109f3f493ecaefd045 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 13:30:29 -0800 Subject: [PATCH 45/58] add tests --- rslp/common/write_file.py | 16 +++ tests/integration/common/__init__.py | 0 tests/integration/common/test_worker.py | 50 ++++++++ tests/integration/satlas/test_data_sources.py | 115 ++++++++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 rslp/common/write_file.py create mode 100644 tests/integration/common/__init__.py create mode 100644 tests/integration/common/test_worker.py create mode 100644 tests/integration/satlas/test_data_sources.py diff --git a/rslp/common/write_file.py b/rslp/common/write_file.py new file mode 100644 index 00000000..47330673 --- /dev/null +++ b/rslp/common/write_file.py @@ -0,0 +1,16 @@ +"""An example workflow that just writes a file. + +This is used for testing. It needs to be under rslp/ so that it can be used as a +workflow. +""" + + +def write_file(fname: str, contents: str) -> None: + """Write the contents to the file. + + Args: + fname: the filename to write. + contents: the data to write to the file. + """ + with open(fname, "w") as f: + f.write(contents) diff --git a/tests/integration/common/__init__.py b/tests/integration/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/common/test_worker.py b/tests/integration/common/test_worker.py new file mode 100644 index 00000000..912cb1f4 --- /dev/null +++ b/tests/integration/common/test_worker.py @@ -0,0 +1,50 @@ +"""Test common.worker functionality.""" + +import os +import pathlib +import time + +from rslp.common.worker import worker_pipeline, write_jobs + + +def test_idle_timeout() -> None: + """Verify that worker exits within about the idle timeout.""" + idle_timeout = 2 + start_time = time.time() + worker_pipeline( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + subscription_id=os.environ["TEST_PUBSUB_SUBSCRIPTION"], + idle_timeout=idle_timeout, + ) + end_time = time.time() + elapsed = end_time - start_time + # Use 3 here in case worker decides to sleep twice (and then some extra time + # elapses). + assert elapsed >= idle_timeout and elapsed <= 3 * idle_timeout + + +def test_single_task(tmp_path: pathlib.Path) -> None: + """Check that worker can do one task.""" + # Use a task that writes to this filename. + dst_fname = tmp_path / "test_file" + # Write the task to the test topic. + job_args = [ + str(dst_fname), + # The contents to write. + "abc", + ] + write_jobs( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + topic_id=os.environ["TEST_PUBSUB_TOPIC"], + rslp_project="common", + rslp_workflow="write_file", + args_list=[job_args], + ) + # Run the worker. + worker_pipeline( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + subscription_id=os.environ["TEST_PUBSUB_SUBSCRIPTION"], + idle_timeout=1, + ) + # Verify that the file was created. + assert dst_fname.exists() diff --git a/tests/integration/satlas/test_data_sources.py b/tests/integration/satlas/test_data_sources.py new file mode 100644 index 00000000..9210d26a --- /dev/null +++ b/tests/integration/satlas/test_data_sources.py @@ -0,0 +1,115 @@ +"""Test Satlas data_sources.py.""" + +import pathlib +from datetime import datetime, timedelta, timezone + +import shapely +from rslearn.config import ( + LayerType, + QueryConfig, + RasterLayerConfig, + SpaceMode, +) +from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources import DataSource +from rslearn.data_sources.azure_sentinel1 import Sentinel1 +from rslearn.data_sources.azure_sentinel2 import Sentinel2 as AzureSentinel2 +from rslearn.data_sources.gcp_public_data import Sentinel2 as Sentinel2 +from rslearn.utils.geometry import STGeometry +from upath import UPath + +from rslp.satlas.data_sources import ( + MonthlyAzureSentinel2, + MonthlySentinel1, + MonthlySentinel2, +) + +PERIOD_DAYS = 30 + + +class TestGetItems: + """Test the get_items method in per-period data source.""" + + def apply_test(self, data_source: DataSource) -> None: + """Test that the data source successfully returns per-period mosaics. + + We apply it on a bbox of Seattle for three-month period with period equal to one + month. + + Args: + data_source: the data source to test. + """ + # Create a 0.002x0.002 degree bbox near Seattle for three-month time range. + seattle_point = (-122.33, 47.61) + shp = shapely.box( + seattle_point[0] - 0.001, + seattle_point[1] - 0.001, + seattle_point[0] + 0.001, + seattle_point[1] + 0.001, + ) + time_range = ( + datetime(2024, 4, 1, tzinfo=timezone.utc), + datetime(2024, 7, 1, tzinfo=timezone.utc), + ) + geometry = STGeometry(WGS84_PROJECTION, shp, time_range) + + # Look for 2 per-month mosaics. + # The first month in the three-month time range should not yield a mosaic since + # it is not needed (it would only be used if the more recent months do not have + # any scenes, but that is not the case here). + query_config = QueryConfig( + space_mode=SpaceMode.MOSAIC, + max_matches=2, + ) + groups = data_source.get_items([geometry], query_config)[0] + + # We expect to get two groups, and each one should be in a different period. + # The groups should be ordered from most recent to least recent. + # There should not be any group for the first period in the three-month time + # range since we expect there to be a mosaic available for the more recent + # periods and max_matches=2. + expected_time_ranges = [ + (time_range[1] - timedelta(days=PERIOD_DAYS), time_range[1]), + ( + time_range[1] - timedelta(days=PERIOD_DAYS * 2), + time_range[1] - timedelta(days=PERIOD_DAYS), + ), + ] + assert len(groups) == len(expected_time_ranges) + for expected_time_range, group in zip(expected_time_ranges, groups): + assert len(group) > 0 + for item in group: + item_ts = item.geometry.time_range[0] + assert expected_time_range[0] <= item_ts <= expected_time_range[1] + + def test_sentinel1(self) -> None: + """Run apply_test with MonthlySentinel1.""" + sentinel1 = MonthlySentinel1( + sentinel1=Sentinel1( + RasterLayerConfig(LayerType.RASTER, []), + ), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel1) + + def test_sentinel2(self, tmp_path: pathlib.Path) -> None: + """Run apply_test with MonthlySentinel2.""" + sentinel2 = MonthlySentinel2( + sentinel2=Sentinel2( + RasterLayerConfig(LayerType.RASTER, []), + index_cache_dir=UPath(tmp_path), + use_rtree_index=False, + ), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel2) + + def test_azure_sentinel2(self) -> None: + """Run apply_test with MonthlyAzureSentinel2.""" + sentinel2 = MonthlyAzureSentinel2( + sentinel2=AzureSentinel2( + RasterLayerConfig(LayerType.RASTER, []), + ), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel2) From 731cd0c584607b4768cc4989e11284bd2f1460ff Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 14:44:20 -0800 Subject: [PATCH 46/58] add test for apply_nms --- rslp/satlas/postprocess.py | 6 ++-- tests/unit/satlas/__init__.py | 0 tests/unit/satlas/test_postprocess.py | 45 +++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 tests/unit/satlas/__init__.py create mode 100644 tests/unit/satlas/test_postprocess.py diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 170f46e3..d4e52b93 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -73,7 +73,9 @@ def apply_nms( redundant detections across these boundaries. Args: - features: the list of JSON Feature objects. + features: the list of JSON Feature objects. The features must be Point with + longitude/latitude coordinates, and must have a "score" property indicating + the confidence. distance_threshold: the distance threshold to match points. Returns: @@ -106,7 +108,7 @@ def apply_nms( other_feat = features[other_idx] if idx == other_idx: continue - if feat["properties"]["score"] < other_feat["properties"]["score"]: + if feat["properties"]["score"] > other_feat["properties"]["score"]: continue other_coordinates = other_feat["geometry"]["coordinates"] distance = math.sqrt( diff --git a/tests/unit/satlas/__init__.py b/tests/unit/satlas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py new file mode 100644 index 00000000..ccf3c8bc --- /dev/null +++ b/tests/unit/satlas/test_postprocess.py @@ -0,0 +1,45 @@ +"""Test Satlas post-processing.""" + +from typing import Any + +from rslp.satlas.postprocess import apply_nms + + +class TestApplyNMS: + """Test the apply_nms function.""" + + DISTANCE_THRESHOLD = 1 + + def make_feature(self, lon: float, lat: float, score: float) -> dict[str, Any]: + """Helper function to create a GeoJSON Feature dict.""" + return { + "type": "Feature", + "geometry": { + "type": "Point", + "coordinates": [lon, lat], + }, + "properties": { + "score": score, + }, + } + + def test_keep_far_away(self) -> None: + """Ensure that points sufficiently far away from each other are retained.""" + features = [ + self.make_feature(0, 0, 1), + self.make_feature(self.DISTANCE_THRESHOLD * 2, 0, 1), + ] + result = apply_nms(features, distance_threshold=self.DISTANCE_THRESHOLD) + assert len(result) == 2 + + def test_remove_two_of_three(self) -> None: + """With three close together points, remove the two lower confidence ones.""" + features = [ + self.make_feature(0, 0, 0.5), + self.make_feature(0.1, 0.1, 0.6), # best one + self.make_feature(0.2, 0.2, 0.4), + ] + result = apply_nms(features, distance_threshold=self.DISTANCE_THRESHOLD) + assert len(result) == 1 + feature = result[0] + assert feature["geometry"]["coordinates"] == [0.1, 0.1] From 98d0925e55dae5a33cb17bcb4631573f0c2cd84a Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 15:19:43 -0800 Subject: [PATCH 47/58] add merge_points test --- rslp/satlas/postprocess.py | 3 + tests/unit/satlas/test_postprocess.py | 104 +++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index d4e52b93..4cad6a91 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -161,6 +161,8 @@ def merge_points( # We merge both the predicted points along with the valid patches (patches # processed by the task that had available input images). for fname, cur_fc in tqdm.tqdm(outputs, total=len(fnames)): + logger.debug("merging points from %s", fname) + # The projection information may be missing if there are no valid patches. # In that case we can skip the file since it has neither valid patches that we # need to track nor any predicted points. @@ -205,6 +207,7 @@ def merge_points( p.close() merged_upath = UPath(merged_path) + merged_upath.mkdir(parents=True, exist_ok=True) merged_fname = merged_upath / f"{label}.geojson" with merged_fname.open("w") as f: json.dump( diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py index ccf3c8bc..62b4580c 100644 --- a/tests/unit/satlas/test_postprocess.py +++ b/tests/unit/satlas/test_postprocess.py @@ -1,8 +1,14 @@ """Test Satlas post-processing.""" +import json +import pathlib from typing import Any -from rslp.satlas.postprocess import apply_nms +from rasterio.crs import CRS +from rslearn.utils.geometry import Projection + +from rslp.satlas.postprocess import apply_nms, merge_points +from rslp.satlas.predict_pipeline import Application class TestApplyNMS: @@ -43,3 +49,99 @@ def test_remove_two_of_three(self) -> None: assert len(result) == 1 feature = result[0] assert feature["geometry"]["coordinates"] == [0.1, 0.1] + + +class TestMergePoints: + """Test the merge_points function.""" + + def make_task_output( + self, + fname: pathlib.Path, + projection: Projection, + coords: list[tuple[float, float]], + valid_patches: list[tuple[int, int]], + ) -> None: + """Make a JSON matching those produced by the Satlas predict_pipeline. + + Args: + fname: the filename to write to. + projection: the projection of the prediction task. The task writes the + GeoJSON in pixel coordinates under that projection. + coords: list of point (col, row) coordinates to include. + valid_patches: list of (col, row) patches to include. These are in tiles of + PATCH_SIZE (see rslp.satlas.predict_pipeline). + """ + # Convert features to GeoJSON. + features = [] + for col, row in coords: + features.append( + { + "type": "Feature", + "properties": { + "score": 1, + "category": "placeholder", + }, + "geometry": { + "type": "Point", + "coordinates": [col, row], + }, + } + ) + + # Make the FeatureCollection. + fc = { + "type": "FeatureCollection", + "features": features, + "properties": projection.serialize(), + } + # Add the valid patches. It is a dict from CRS to tile list. + fc["properties"]["valid_patches"] = { + str(projection.crs): valid_patches, + } + + fname.parent.mkdir(parents=True, exist_ok=True) + with open(fname, "w") as f: + json.dump(fc, f) + + def two_crs_merged(self, tmp_path: pathlib.Path) -> None: + """Verify that when merging across two CRS it is successful.""" + proj32601 = Projection(CRS.from_epsg(32601), 10, -10) + proj32602 = Projection(CRS.from_epsg(32602), 10, -10) + + predict_path = tmp_path / "predict" + merged_path = tmp_path / "merged" + + # File 1 is in 32601 and contains one feature. + self.make_task_output(predict_path / "1.geojson", proj32601, [(0, 0)], [(0, 0)]) + # File 2 is also in 32601 and contains one different feature and patch. + self.make_task_output( + predict_path / "2.geojson", proj32601, [(2048, 2048)], [(1, 1)] + ) + # File 3 is in 32602 and contains a third feature. + self.make_task_output(predict_path / "3.geojson", proj32602, [(0, 0)], [(0, 0)]) + + # Run the merging. + merge_points( + Application.MARINE_INFRA, + # Use arbitrary YYYY-MM label. + "1234-56", + str(predict_path), + str(merged_path), + ) + + # Verify the output. + merged_fname = merged_path / "1234-56.geojson" + with merged_fname.open() as f: + fc = json.load(f) + # And the valid patches should be merged, with one in 32601 and two in 32602. + valid_patches = fc["properties"]["valid_patches"] + assert len(valid_patches) == 2 + patches32601 = valid_patches[str(proj32601)] + patches32601.sort() + assert patches32601 == [[0, 0], [1, 1]] + patches32602 = valid_patches[str(proj32602)] + patches32602.sort() + assert patches32602 == [[0, 0]] + + features = fc["features"] + assert len(features) == 3 From d7810bf40667ccc3a829ccbe115f58e5afc8ee26 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 15:20:41 -0800 Subject: [PATCH 48/58] remove unused nms stuff --- rslp/satlas/postprocess.py | 78 --------------------------- tests/unit/satlas/test_postprocess.py | 43 +-------------- 2 files changed, 1 insertion(+), 120 deletions(-) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 4cad6a91..850b2d5f 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -1,7 +1,6 @@ """Postprocessing outputs from Satlas models.""" import json -import math import multiprocessing import shutil import subprocess # nosec @@ -12,23 +11,12 @@ import tqdm from rslearn.const import WGS84_PROJECTION from rslearn.utils.geometry import Projection, STGeometry -from rslearn.utils.grid_index import GridIndex from upath import UPath from rslp.log_utils import get_logger from .predict_pipeline import Application -# Approximate maximum meters in one degree latitude/longitude. -# Above/below the equator there will be fewer meters. -# This is used to compute the NMS_DISTANCE_THRESHOLD below. -MAX_METERS_PER_DEGREE = 111111 - -# Threshold on Euclidean distance between lat/lon for NMS. -# We just do Euclidean distance for speed/simplicity since NMS doesn't need to be super -# exact (instead of spherical distance). -NMS_DISTANCE_THRESHOLD = 100 / MAX_METERS_PER_DEGREE - # Individual Satlas applications use different category names than the global ones that # we want to serve. We should adjust this but for now this map helps to rename the # categories. @@ -60,72 +48,6 @@ def _get_fc(fname: UPath) -> tuple[UPath, dict[str, Any]]: return fname, json.load(f) -def apply_nms( - features: list[dict[str, Any]], - distance_threshold: float, -) -> list[dict[str, Any]]: - """Apply non-maximum suppression over the points. - - Although we run NMS inside the object detector, we need to run a global NMS again - because there two levels where we are dividing into patches -- at the global level, - where we start different prediction tasks for every 32768x32768 patch, and again - within the tasks, where we process each 2048x2048 sub-patch. So there can be - redundant detections across these boundaries. - - Args: - features: the list of JSON Feature objects. The features must be Point with - longitude/latitude coordinates, and must have a "score" property indicating - the confidence. - distance_threshold: the distance threshold to match points. - - Returns: - new Features with NMS applied. - """ - # A few multiples of the distance threshold is generally a good grid size. - grid_index = GridIndex(distance_threshold * 10) - - # Insert features into the index. - for idx, feat in enumerate(features): - coordinates = feat["geometry"]["coordinates"] - box = (coordinates[0], coordinates[1], coordinates[0], coordinates[1]) - grid_index.insert(box, idx) - - # Now we iterate over the features and use the index to identify other features - # that are nearby. If the other feature has a higher score then we delete the - # feature. - good_features = [] - for idx, feat in enumerate(features): - coordinates = feat["geometry"]["coordinates"] - # Create search box with distance threshold padding. - box = ( - coordinates[0] - distance_threshold, - coordinates[1] - distance_threshold, - coordinates[0] + distance_threshold, - coordinates[1] + distance_threshold, - ) - is_feat_okay = True - for other_idx in grid_index.query(box): - other_feat = features[other_idx] - if idx == other_idx: - continue - if feat["properties"]["score"] > other_feat["properties"]["score"]: - continue - other_coordinates = other_feat["geometry"]["coordinates"] - distance = math.sqrt( - (coordinates[0] - other_coordinates[0]) ** 2 - + (coordinates[1] - other_coordinates[1]) ** 2 - ) - if distance > distance_threshold: - continue - is_feat_okay = False - break - - if is_feat_okay: - good_features.append(feat) - - return good_features - - def merge_points( application: Application, label: str, diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py index 62b4580c..78cbc897 100644 --- a/tests/unit/satlas/test_postprocess.py +++ b/tests/unit/satlas/test_postprocess.py @@ -2,55 +2,14 @@ import json import pathlib -from typing import Any from rasterio.crs import CRS from rslearn.utils.geometry import Projection -from rslp.satlas.postprocess import apply_nms, merge_points +from rslp.satlas.postprocess import merge_points from rslp.satlas.predict_pipeline import Application -class TestApplyNMS: - """Test the apply_nms function.""" - - DISTANCE_THRESHOLD = 1 - - def make_feature(self, lon: float, lat: float, score: float) -> dict[str, Any]: - """Helper function to create a GeoJSON Feature dict.""" - return { - "type": "Feature", - "geometry": { - "type": "Point", - "coordinates": [lon, lat], - }, - "properties": { - "score": score, - }, - } - - def test_keep_far_away(self) -> None: - """Ensure that points sufficiently far away from each other are retained.""" - features = [ - self.make_feature(0, 0, 1), - self.make_feature(self.DISTANCE_THRESHOLD * 2, 0, 1), - ] - result = apply_nms(features, distance_threshold=self.DISTANCE_THRESHOLD) - assert len(result) == 2 - - def test_remove_two_of_three(self) -> None: - """With three close together points, remove the two lower confidence ones.""" - features = [ - self.make_feature(0, 0, 0.5), - self.make_feature(0.1, 0.1, 0.6), # best one - self.make_feature(0.2, 0.2, 0.4), - ] - result = apply_nms(features, distance_threshold=self.DISTANCE_THRESHOLD) - assert len(result) == 1 - feature = result[0] - assert feature["geometry"]["coordinates"] == [0.1, 0.1] - - class TestMergePoints: """Test the merge_points function.""" From fed1b519b28d54f3920c0e20f4df71ecaaf62bac Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 15:29:33 -0800 Subject: [PATCH 49/58] fix test --- tests/unit/satlas/test_postprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py index 78cbc897..05e20c2e 100644 --- a/tests/unit/satlas/test_postprocess.py +++ b/tests/unit/satlas/test_postprocess.py @@ -62,7 +62,7 @@ def make_task_output( with open(fname, "w") as f: json.dump(fc, f) - def two_crs_merged(self, tmp_path: pathlib.Path) -> None: + def test_two_crs_merged(self, tmp_path: pathlib.Path) -> None: """Verify that when merging across two CRS it is successful.""" proj32601 = Projection(CRS.from_epsg(32601), 10, -10) proj32602 = Projection(CRS.from_epsg(32602), 10, -10) @@ -95,10 +95,10 @@ def two_crs_merged(self, tmp_path: pathlib.Path) -> None: # And the valid patches should be merged, with one in 32601 and two in 32602. valid_patches = fc["properties"]["valid_patches"] assert len(valid_patches) == 2 - patches32601 = valid_patches[str(proj32601)] + patches32601 = valid_patches[str(proj32601.crs)] patches32601.sort() assert patches32601 == [[0, 0], [1, 1]] - patches32602 = valid_patches[str(proj32602)] + patches32602 = valid_patches[str(proj32602.crs)] patches32602.sort() assert patches32602 == [[0, 0]] From eeb63e486869104b688d0cbd85478072a8b8d825 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Fri, 7 Feb 2025 22:30:02 -0500 Subject: [PATCH 50/58] add tests for smoothing --- Dockerfile | 8 + rslp/satlas/postprocess.py | 1 + tests/unit/satlas/test_postprocess.py | 270 +++++++++++++++++++++++++- 3 files changed, 276 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 15a57267..4238b27d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,4 +27,12 @@ ENV PYTHONPATH="${PYTHONPATH}:." COPY . /opt/rslearn_projects/ # install rslp package RUN pip install --no-cache-dir /opt/rslearn_projects + +# Build Satlas smooth_point_labels_viterbi.go program. +WORKDIR /opt/rslearn_projects/rslp/satlas/scripts +RUN wget https://go.dev/dl/go1.22.12.linux-amd64.tar.gz +RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.12.linux-amd64.tar.gz +ENV PATH="${PATH}:/usr/local/go/bin" +RUN go build smooth_point_labels_viterbi.go + WORKDIR /opt/rslearn_projects diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py index 850b2d5f..f0f6696d 100644 --- a/rslp/satlas/postprocess.py +++ b/rslp/satlas/postprocess.py @@ -209,6 +209,7 @@ def smooth_points( ) # nosec # Now we can upload the smoothed per-timestep files. + smoothed_upath.mkdir(parents=True, exist_ok=True) for label in labels: src_path = tmp_smoothed_dir / f"{label}.geojson" dst_path = smoothed_upath / f"{label}.geojson" diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py index 05e20c2e..b7dfe94a 100644 --- a/tests/unit/satlas/test_postprocess.py +++ b/tests/unit/satlas/test_postprocess.py @@ -3,11 +3,13 @@ import json import pathlib +import shapely from rasterio.crs import CRS -from rslearn.utils.geometry import Projection +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import Projection, STGeometry -from rslp.satlas.postprocess import merge_points -from rslp.satlas.predict_pipeline import Application +from rslp.satlas.postprocess import merge_points, smooth_points +from rslp.satlas.predict_pipeline import PATCH_SIZE, Application class TestMergePoints: @@ -104,3 +106,265 @@ def test_two_crs_merged(self, tmp_path: pathlib.Path) -> None: features = fc["features"] assert len(features) == 3 + # Make sure crs property is set correctly. + feature_projections = [feat["properties"]["projection"] for feat in features] + feature_projections.sort() + assert feature_projections == [ + str(proj32601.crs), + str(proj32601.crs), + str(proj32602.crs), + ] + + +class TestSmoothPoints: + """Test the smooth_points function.""" + + # Minimum number of timesteps for a point to be considered. + # This is used in smooth_point_labels_viterbi.go to discard regions that are only + # covered by satellite imagery for a couple months over multi-year time range. + MIN_VALID_TIMESTEPS = 8 + + def make_merge_output( + self, + fname: pathlib.Path, + geometries: list[STGeometry], + additional_valid_patches: list[tuple[Projection, int, int]] = [], + scores: list[float] | None = None, + ) -> None: + """Make a GeoJSON matching those produced by merge_points. + + This is the input to smooth_points. + + Args: + fname: the filename to write to. + geometries: list of STGeometry to include. The valid patches will be set + automatically to include all of these geometries. + additional_valid_patches: additional valid patches (besides those covered + by the geometries). + scores: optional list of scores of each geometry. If set, it should be the + same length as geometries. + """ + # Convert features to GeoJSON. + features = [] + valid_patches: dict[str, set[tuple[int, int]]] = {} + for geometry_idx, geometry in enumerate(geometries): + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + crs_str = str(geometry.projection.crs) + score = scores[geometry_idx] if scores else 1 + features.append( + { + "type": "Feature", + "properties": { + "category": "placeholder", + "score": score, + "projection": crs_str, + "col": int(geometry.shp.x), + "row": int(geometry.shp.y), + }, + "geometry": { + "type": "Point", + "coordinates": [wgs84_geometry.shp.x, wgs84_geometry.shp.y], + }, + } + ) + if crs_str not in valid_patches: + valid_patches[crs_str] = set() + valid_patches[crs_str].add( + (int(geometry.shp.x) // PATCH_SIZE, int(geometry.shp.y) // PATCH_SIZE) + ) + + for projection, col, row in additional_valid_patches: + crs_str = str(projection.crs) + if crs_str not in valid_patches: + valid_patches[crs_str] = set() + valid_patches[crs_str].add((col, row)) + + # Make and write the FeatureCollection. + fc = { + "type": "FeatureCollection", + "features": features, + "properties": { + "valid_patches": { + crs_str: list(patch_set) + for crs_str, patch_set in valid_patches.items() + }, + }, + } + fname.parent.mkdir(parents=True, exist_ok=True) + with open(fname, "w") as f: + json.dump(fc, f) + + def test_smooth_nms(self, tmp_path: pathlib.Path) -> None: + """Verify that smoothing deletes redundant points across projections.""" + # (-174, 0) is on the border between EPSG:32601 and EPSG:32602. + # So we add it in both projections and make sure smooth_points deletes it. + wgs84_geom = STGeometry(WGS84_PROJECTION, shapely.Point(-174, 0), None) + proj32601 = Projection(CRS.from_epsg(32601), 10, -10) + proj32602 = Projection(CRS.from_epsg(32602), 10, -10) + geom32601 = wgs84_geom.to_projection(proj32601) + geom32602 = wgs84_geom.to_projection(proj32602) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # We need MIN_VALID_TIMESTEPS files for any points to be produced. + labels = [f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS)] + for label in labels: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geom32601, geom32602], + scores=[0.5, 0.6], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_discard_single_timestep_positive(self, tmp_path: pathlib.Path) -> None: + """Ignore a point of it is only detected in one timestep.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry1 = STGeometry(projection, shapely.Point(0, 0), None) + geometry2 = STGeometry( + projection, shapely.Point(PATCH_SIZE // 2, PATCH_SIZE // 2), None + ) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # We need MIN_VALID_TIMESTEPS files where the patch is valid, otherwise it will + # be ignored. + # We use two points here since timesteps without any points yield no output. + # It also makes it easier to ensure that patch is marked valid. + labels = [f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS)] + for timestep, label in enumerate(labels): + if timestep == 4: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry1, geometry2], + ) + else: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry2], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_invalid_timesteps_ignored(self, tmp_path: pathlib.Path) -> None: + """Ensure a point keeps being predicted if its patch is not observed.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry = STGeometry(projection, shapely.Point(0, 0), None) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # In the first timesteps, we detect the point. + labels = [ + f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS * 2) + ] + for label in labels[0 : self.MIN_VALID_TIMESTEPS]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry], + ) + # In the remaining timesteps, we leave the patch invalid. + for label in labels[self.MIN_VALID_TIMESTEPS :]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_point_removed(self, tmp_path: pathlib.Path) -> None: + """Verify that a point will be removed with enough negative observations.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry1 = STGeometry(projection, shapely.Point(0, 0), None) + # As with test_discard_single_timestep_positive, we use a second point so it is + # easier to ensure the patch is valid and so that outputs are produced at all + # timesteps. Here, we will observe geometry2 at every timestep but cut + # geometry1 off halfway through. + geometry2 = STGeometry( + projection, shapely.Point(PATCH_SIZE // 2, PATCH_SIZE // 2), None + ) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # We will have 30 positive observations and then 30 negative ones. + # This ensures enough because there is very low transition probability from + # positive to negative (since it is unlikely that e.g. a wind turbine would be + # torn down). + num_observations = 30 + + # Create the input files. + # In the first timesteps, we detect both points. + labels = [f"0000-0{timestep}" for timestep in range(num_observations * 2)] + for label in labels[0:num_observations]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry1, geometry2], + ) + # In the remaining timesteps, we only detect geometry2. + for label in labels[num_observations:]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry2], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # At the first timestep, we should see both points. + smoothed_fname = smoothed_path / f"{labels[0]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 2 + + # At the last timestep, we should only see geometry2. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 From df3341dbbf53d7bb28bc037c871c749aecf325d6 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 10 Feb 2025 13:03:43 -0800 Subject: [PATCH 51/58] refactor prediction pipeline --- rslp/satlas/predict_pipeline.py | 182 ++++++++++++++++++-------------- 1 file changed, 100 insertions(+), 82 deletions(-) diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py index 1f8e2b56..6d874736 100644 --- a/rslp/satlas/predict_pipeline.py +++ b/rslp/satlas/predict_pipeline.py @@ -32,6 +32,30 @@ # Layers not to use when seeing which patches are valid. VALIDITY_EXCLUDE_LAYERS = ["mask", "output", "label"] +PREDICTION_GROUP = "predict" +SATLAS_MATERIALIZE_PIPELINE_ARGS = MaterializePipelineArgs( + disabled_layers=[], + # Use initial job for prepare since it involves locally caching the tile index + # and other steps that should only be performed once. + prepare_args=PrepareArgs( + apply_windows_args=ApplyWindowsArgs( + group=PREDICTION_GROUP, workers=32, use_initial_job=True + ) + ), + ingest_args=IngestArgs( + ignore_errors=False, + apply_windows_args=ApplyWindowsArgs( + group=PREDICTION_GROUP, workers=32, use_initial_job=False + ), + ), + materialize_args=MaterializeArgs( + ignore_errors=False, + apply_windows_args=ApplyWindowsArgs( + group=PREDICTION_GROUP, workers=32, use_initial_job=False + ), + ), +) + logger = get_logger(__name__) @@ -77,6 +101,76 @@ def get_output_fname( return out_fname +def merge_and_upload_points( + projection: Projection, + windows: list[Window], + out_fname: UPath, +) -> None: + """Helper function to merge and upload point data after prediction is complete. + + Args: + projection: the UTM projection that we are working in. + windows: the windows that were used for prediction. + out_fname: the filename to write the merged result. + """ + # Merge the features across the windows. + # Here we also add valid patches attribute indicating which windows (patches) + # were non-zero. This is used to distinguish a point not being detected because + # it wasn't there vs not being detected just because there was no image + # available there. + fc = None + valid_patches = [] + for window in windows: + window_output_fname = window.path / "layers" / "output" / "data.geojson" + + if not window_output_fname.exists(): + continue + + with window_output_fname.open() as f: + cur_fc = json.load(f) + + if fc is None: + fc = cur_fc + else: + fc["features"].extend(cur_fc["features"]) + + valid_patches.append( + (window.bounds[0] // PATCH_SIZE, window.bounds[1] // PATCH_SIZE) + ) + + if fc is None: + # So there was no image here. + # We still want to write an empty GeoJSON so the job is marked completed. + fc = { + "type": "FeatureCollection", + "features": [], + } + + if "properties" not in fc: + fc["properties"] = {} + fc["properties"]["valid_patches"] = { + str(projection.crs): list(valid_patches), + } + + # The object detector predicts bounding boxes but we want to make all features + # just points. + for feat in fc["features"]: + assert feat["geometry"]["type"] == "Polygon" + coords = feat["geometry"]["coordinates"][0] + xs = [coord[0] for coord in coords] + ys = [coord[1] for coord in coords] + feat["geometry"] = { + "type": "Point", + "coordinates": [ + (min(xs) + max(xs)) / 2, + (min(ys) + max(ys)) / 2, + ], + } + + with out_fname.open("w") as f: + json.dump(fc, f) + + def predict_pipeline( application: Application, projection_json: str, @@ -155,7 +249,6 @@ def predict_pipeline( # Note that bounds must be multiple of patch size. for value in bounds: assert value % PATCH_SIZE == 0 - group = "predict" tile_to_window = {} for tile_col in range(bounds[0] // PATCH_SIZE, bounds[2] // PATCH_SIZE): for tile_row in range(bounds[1] // PATCH_SIZE, bounds[3] // PATCH_SIZE): @@ -166,10 +259,10 @@ def predict_pipeline( (tile_col + 1) * PATCH_SIZE, (tile_row + 1) * PATCH_SIZE, ) - window_path = ds_path / "windows" / group / window_name + window_path = ds_path / "windows" / PREDICTION_GROUP / window_name window = Window( path=window_path, - group=group, + group=PREDICTION_GROUP, name=window_name, projection=projection, bounds=window_bounds, @@ -203,33 +296,13 @@ def predict_pipeline( # Populate the windows. logger.info("materialize dataset") - materialize_pipeline_args = MaterializePipelineArgs( - disabled_layers=[], - # Use initial job for prepare since it involves locally caching the tile index - # and other steps that should only be performed once. - prepare_args=PrepareArgs( - apply_windows_args=ApplyWindowsArgs( - group=group, workers=32, use_initial_job=True - ) - ), - ingest_args=IngestArgs( - ignore_errors=False, - apply_windows_args=ApplyWindowsArgs( - group=group, workers=32, use_initial_job=False - ), - ), - materialize_args=MaterializeArgs( - ignore_errors=False, - apply_windows_args=ApplyWindowsArgs( - group=group, workers=32, use_initial_job=False - ), - ), + materialize_dataset( + ds_path, materialize_pipeline_args=SATLAS_MATERIALIZE_PIPELINE_ARGS ) - materialize_dataset(ds_path, materialize_pipeline_args=materialize_pipeline_args) # Run the model, only if at least one window has some data. completed_fnames = ds_path.glob( - f"windows/{group}/*/layers/{image_layer_names[0]}/completed" + f"windows/{PREDICTION_GROUP}/*/layers/{image_layer_names[0]}/completed" ) if len(list(completed_fnames)) == 0: logger.info("skipping prediction since no windows seem to have data") @@ -247,62 +320,7 @@ def predict_pipeline( raise NotImplementedError else: - # Merge the features across the windows. - # Here we also add valid patches attribute indicating which windows (patches) - # were non-zero. This is used to distinguish a point not being detected because - # it wasn't there vs not being detected just because there was no image - # available there. - fc = None - valid_patches = [] - for window in tile_to_window.values(): - window_output_fname = window.path / "layers" / "output" / "data.geojson" - - if not window_output_fname.exists(): - continue - - with window_output_fname.open() as f: - cur_fc = json.load(f) - - if fc is None: - fc = cur_fc - else: - fc["features"].extend(cur_fc["features"]) - - valid_patches.append( - (window.bounds[0] // PATCH_SIZE, window.bounds[1] // PATCH_SIZE) - ) - - if fc is None: - # So there was no image here. - # We still want to write an empty GeoJSON so the job is marked completed. - fc = { - "type": "FeatureCollection", - "features": [], - } - - if "properties" not in fc: - fc["properties"] = {} - fc["properties"]["valid_patches"] = { - str(projection.crs): list(valid_patches), - } - - # The object detector predicts bounding boxes but we want to make all features - # just points. - for feat in fc["features"]: - assert feat["geometry"]["type"] == "Polygon" - coords = feat["geometry"]["coordinates"][0] - xs = [coord[0] for coord in coords] - ys = [coord[1] for coord in coords] - feat["geometry"] = { - "type": "Point", - "coordinates": [ - (min(xs) + max(xs)) / 2, - (min(ys) + max(ys)) / 2, - ], - } - - with out_fname.open("w") as f: - json.dump(fc, f) + merge_and_upload_points(projection, list(tile_to_window.values()), out_fname) class PredictTaskArgs: From 9b86ca4d1cbb61bc3dacaffc65c7b02f9478612c Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 11 Feb 2025 14:04:39 -0800 Subject: [PATCH 52/58] fix tests --- tests/conftest.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..cb1fc115 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import logging + +import pytest + +from rslp.utils.mp import init_mp + +logging.basicConfig() + + +@pytest.fixture(scope="session", autouse=True) +def always_init_mp() -> None: + init_mp() From cbf1f1e9730f3820411e62acf3c2034ebd81cf10 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 11 Feb 2025 23:34:14 -0500 Subject: [PATCH 53/58] working tests --- .github/workflows/build_test.yaml | 2 +- .../forest_loss_driver/inference/test_model_predict.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index f7cdccc7..5222dd4e 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -127,7 +127,7 @@ jobs: -e TEST_PUBSUB_PROJECT=earthsystem-dev-c3po \ -e TEST_PUBSUB_TOPIC=rslearn_projects_test_topic \ -e TEST_PUBSUB_SUBSCRIPTION=rslearn_projects_test_subscription \ - test pytest tests/integration/ --ignore tests/integration_slow/ -vv + test pytest tests/integration/ -vv - name: Clean up diff --git a/tests/integration/forest_loss_driver/inference/test_model_predict.py b/tests/integration/forest_loss_driver/inference/test_model_predict.py index f78b9c8f..c65edfbc 100644 --- a/tests/integration/forest_loss_driver/inference/test_model_predict.py +++ b/tests/integration/forest_loss_driver/inference/test_model_predict.py @@ -95,7 +95,7 @@ def test_forest_loss_driver_model_predict( output_json = json.load(f) # TODO: Ideally we would have a pydantic model for this output perhaps that we could subclass from rslearn? # Check properties except probs - tol = 0.01 + tol = 0.1 assert output_json["type"] == expected_output_json["type"] assert output_json["properties"] == expected_output_json["properties"] assert len(output_json["features"]) == len(expected_output_json["features"]) From a56c3f56c83d2647c59bcfb45c487d21d2d25a09 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 18 Feb 2025 11:52:58 -0800 Subject: [PATCH 54/58] fix broken Dockerfile (missing wget) --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 4238b27d..1cf9ac90 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:2.5.0-cuda11.8-cudnn9-runtime@sha256:d15e9803095e462e351f097fb1f5e7cdaa4f5e855d7ff6d6f36ec4c2aa2938ea RUN apt update -RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git +RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git wget # Install rslearn. # We use git clone and then git checkout instead of git clone -b so that the user could From b23985b1800c0779e7a1953b0901e3a3875bf6d3 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Tue, 18 Feb 2025 12:32:30 -0800 Subject: [PATCH 55/58] fix import for updated planetary computer data source --- rslp/satlas/data_sources.py | 4 ++-- tests/integration/satlas/test_data_sources.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py index a20ab3d2..0403b149 100644 --- a/rslp/satlas/data_sources.py +++ b/rslp/satlas/data_sources.py @@ -6,10 +6,10 @@ import shapely from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig, SpaceMode from rslearn.const import WGS84_PROJECTION -from rslearn.data_sources.azure_sentinel1 import Sentinel1 -from rslearn.data_sources.azure_sentinel2 import Sentinel2 as AzureSentinel2 from rslearn.data_sources.data_source import DataSource, Item from rslearn.data_sources.gcp_public_data import Sentinel2 as GcpSentinel2 +from rslearn.data_sources.planetary_computer import Sentinel1 +from rslearn.data_sources.planetary_computer import Sentinel2 as AzureSentinel2 from rslearn.data_sources.utils import match_candidate_items_to_window from rslearn.dataset import Window from rslearn.tile_stores import TileStore diff --git a/tests/integration/satlas/test_data_sources.py b/tests/integration/satlas/test_data_sources.py index 9210d26a..c35d2d56 100644 --- a/tests/integration/satlas/test_data_sources.py +++ b/tests/integration/satlas/test_data_sources.py @@ -12,9 +12,9 @@ ) from rslearn.const import WGS84_PROJECTION from rslearn.data_sources import DataSource -from rslearn.data_sources.azure_sentinel1 import Sentinel1 -from rslearn.data_sources.azure_sentinel2 import Sentinel2 as AzureSentinel2 from rslearn.data_sources.gcp_public_data import Sentinel2 as Sentinel2 +from rslearn.data_sources.planetary_computer import Sentinel1 +from rslearn.data_sources.planetary_computer import Sentinel2 as AzureSentinel2 from rslearn.utils.geometry import STGeometry from upath import UPath From 436bc00bd5e37bb9727a163b6a6ca5d25b2240e4 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 20 Feb 2025 13:28:30 -0800 Subject: [PATCH 56/58] fix tests --- .github/workflows/build_test.yaml | 8 ++++---- tests/integration/satlas/test_data_sources.py | 8 ++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index ff84d09a..368905be 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -120,10 +120,10 @@ jobs: -e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcp-credentials.json \ -e RSLP_BUCKET=rslearn-eai \ -e RSLP_PREFIX=gs://rslearn-eai \ - -e BKT_PROJECT_ID=ai2-prior-satlas \ - -e BKT_BUCKET_NAME=satlas-explorer-data \ - -e BKT_BIGTABLE_PROJECT_ID=ai2-prior-satlas \ - -e BKT_BIGTABLE_INSTANCE_ID=satlas \ + -e BKT_PROJECT_ID=earthsystem-dev-c3po \ + -e BKT_BUCKET_NAME=rslp-tests \ + -e BKT_BIGTABLE_PROJECT_ID=earthsystem-dev-c3po \ + -e BKT_BIGTABLE_INSTANCE_ID=rslp-bigtable-test-instance-c1 \ -e TEST_PUBSUB_PROJECT=earthsystem-dev-c3po \ -e TEST_PUBSUB_TOPIC=rslearn_projects_test_topic \ -e TEST_PUBSUB_SUBSCRIPTION=rslearn_projects_test_subscription \ diff --git a/tests/integration/satlas/test_data_sources.py b/tests/integration/satlas/test_data_sources.py index c35d2d56..fe97f3c6 100644 --- a/tests/integration/satlas/test_data_sources.py +++ b/tests/integration/satlas/test_data_sources.py @@ -85,9 +85,7 @@ def apply_test(self, data_source: DataSource) -> None: def test_sentinel1(self) -> None: """Run apply_test with MonthlySentinel1.""" sentinel1 = MonthlySentinel1( - sentinel1=Sentinel1( - RasterLayerConfig(LayerType.RASTER, []), - ), + sentinel1=Sentinel1(["vv"]), period_days=PERIOD_DAYS, ) self.apply_test(sentinel1) @@ -107,9 +105,7 @@ def test_sentinel2(self, tmp_path: pathlib.Path) -> None: def test_azure_sentinel2(self) -> None: """Run apply_test with MonthlyAzureSentinel2.""" sentinel2 = MonthlyAzureSentinel2( - sentinel2=AzureSentinel2( - RasterLayerConfig(LayerType.RASTER, []), - ), + sentinel2=AzureSentinel2(["B04"]), period_days=PERIOD_DAYS, ) self.apply_test(sentinel2) From 8c713ca8458746537421b57869ee7a4090ad5780 Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Thu, 20 Feb 2025 13:28:51 -0800 Subject: [PATCH 57/58] fix tests x2 --- .github/workflows/build_test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index 368905be..1c3b97fb 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -123,7 +123,7 @@ jobs: -e BKT_PROJECT_ID=earthsystem-dev-c3po \ -e BKT_BUCKET_NAME=rslp-tests \ -e BKT_BIGTABLE_PROJECT_ID=earthsystem-dev-c3po \ - -e BKT_BIGTABLE_INSTANCE_ID=rslp-bigtable-test-instance-c1 \ + -e BKT_BIGTABLE_INSTANCE_ID=rslp-bigtable-test-instance \ -e TEST_PUBSUB_PROJECT=earthsystem-dev-c3po \ -e TEST_PUBSUB_TOPIC=rslearn_projects_test_topic \ -e TEST_PUBSUB_SUBSCRIPTION=rslearn_projects_test_subscription \ From f73761dbfcd706e1a89fd295979b0873d2c8531f Mon Sep 17 00:00:00 2001 From: Favyen Bastani Date: Mon, 24 Feb 2025 10:05:00 -0800 Subject: [PATCH 58/58] don't run bkt test in ci since bigtable instance is expensive to keep just for test --- .github/workflows/build_test.yaml | 2 ++ tests/integration/satlas/test_bkt.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index 1c3b97fb..48068eb2 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -109,11 +109,13 @@ jobs: - name: Run unit tests with Docker Compose run: | docker compose -f docker-compose.yaml run \ + -e CI="true" \ test pytest tests/unit/ - name: Run tests with Docker Compose run: | docker compose -f docker-compose.yaml run \ + -e CI="true" \ -e AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }} \ -e AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }} \ -v ${{env.GOOGLE_GHA_CREDS_PATH}}:/tmp/gcp-credentials.json:ro \ diff --git a/tests/integration/satlas/test_bkt.py b/tests/integration/satlas/test_bkt.py index c790ade5..97123a2a 100644 --- a/tests/integration/satlas/test_bkt.py +++ b/tests/integration/satlas/test_bkt.py @@ -1,7 +1,10 @@ """Test bkt file operations.""" +import os import pathlib +import pytest + from rslp.satlas.bkt import ( BktDownloadRequest, download_from_bkt, @@ -9,7 +12,10 @@ make_bkt, ) +RUNNING_IN_CI = os.environ.get("CI", "false").lower() == "true" + +@pytest.mark.skipif(RUNNING_IN_CI, reason="Skipping in CI environment") def test_make_and_download_bkt(tmp_path: pathlib.Path) -> None: """Test making and downloading a bkt file.