diff --git a/.github/workflows/build_test.yaml b/.github/workflows/build_test.yaml index f52bca4f..48068eb2 100644 --- a/.github/workflows/build_test.yaml +++ b/.github/workflows/build_test.yaml @@ -109,18 +109,27 @@ jobs: - name: Run unit tests with Docker Compose run: | docker compose -f docker-compose.yaml run \ + -e CI="true" \ test pytest tests/unit/ - name: Run tests with Docker Compose run: | docker compose -f docker-compose.yaml run \ + -e CI="true" \ -e AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }} \ -e AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }} \ -v ${{env.GOOGLE_GHA_CREDS_PATH}}:/tmp/gcp-credentials.json:ro \ -e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcp-credentials.json \ -e RSLP_BUCKET=rslearn-eai \ -e RSLP_PREFIX=gs://rslearn-eai \ - test pytest tests/integration/ --ignore tests/integration_slow/ -vv + -e BKT_PROJECT_ID=earthsystem-dev-c3po \ + -e BKT_BUCKET_NAME=rslp-tests \ + -e BKT_BIGTABLE_PROJECT_ID=earthsystem-dev-c3po \ + -e BKT_BIGTABLE_INSTANCE_ID=rslp-bigtable-test-instance \ + -e TEST_PUBSUB_PROJECT=earthsystem-dev-c3po \ + -e TEST_PUBSUB_TOPIC=rslearn_projects_test_topic \ + -e TEST_PUBSUB_SUBSCRIPTION=rslearn_projects_test_subscription \ + test pytest tests/integration/ -vv - name: Clean up diff --git a/Dockerfile b/Dockerfile index 15a57267..1cf9ac90 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM pytorch/pytorch:2.5.0-cuda11.8-cudnn9-runtime@sha256:d15e9803095e462e351f097fb1f5e7cdaa4f5e855d7ff6d6f36ec4c2aa2938ea RUN apt update -RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git +RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git wget # Install rslearn. # We use git clone and then git checkout instead of git clone -b so that the user could @@ -27,4 +27,12 @@ ENV PYTHONPATH="${PYTHONPATH}:." COPY . /opt/rslearn_projects/ # install rslp package RUN pip install --no-cache-dir /opt/rslearn_projects + +# Build Satlas smooth_point_labels_viterbi.go program. +WORKDIR /opt/rslearn_projects/rslp/satlas/scripts +RUN wget https://go.dev/dl/go1.22.12.linux-amd64.tar.gz +RUN rm -rf /usr/local/go && tar -C /usr/local -xzf go1.22.12.linux-amd64.tar.gz +ENV PATH="${PATH}:/usr/local/go/bin" +RUN go build smooth_point_labels_viterbi.go + WORKDIR /opt/rslearn_projects diff --git a/ai2_docs/README.md b/ai2_docs/README.md index 563fe89f..2d69cf92 100644 --- a/ai2_docs/README.md +++ b/ai2_docs/README.md @@ -33,6 +33,10 @@ rslearn datasets, model checkpoints, etc. The easiest way is to create a `.env` RSLP_PREFIX=gs://rslearn-eai RSLP_WEKA_PREFIX=weka://dfive-default/rslearn-eai + BKT_PROJECT_ID=ai2-prior-satlas + BKT_BUCKET_NAME=satlas-explorer-data + BKT_BIGTABLE_PROJECT_ID=ai2-prior-satlas + BKT_BIGTABLE_INSTANCE_ID=satlas You will also need to setup GCP credentials that have access to this bucket. @@ -42,19 +46,6 @@ launcher like `rslp.launch_beaker`, then it isn't needed since the credentials a already configured as secrets on the platform, but you would need to setup your Beaker or other platform credentials to be able to launch the jobs. -TODO: update GCP/W&B to use service accounts. - -Currently, until https://github.com/allenai/rslearn/issues/33 is resolved, model config -files use S3-compatable API to access GCS rather than GCS directly. Therefore, you need -to set up environment variables to provide the appropriate credentials: - - S3_ACCESS_KEY_ID=GOOG... - S3_SECRET_ACCESS_KEY=... 
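The updated workflow forwards several new environment variables into the integration-test container: `CI`, the `BKT_*` settings (a GCS bucket plus a Bigtable project/instance), and the `TEST_PUBSUB_*` settings (a Pub/Sub topic and subscription). The fixtures that consume these variables are not part of this diff; the sketch below is only one plausible way a test module could pick them up and skip cleanly when they are absent — the fixture name and skip behaviour are assumptions, while the variable names come from the workflow.

```python
# Hypothetical sketch (not from this repository): read the BKT_* settings that the
# workflow exports and skip integration tests that need them when they are unset.
# The TEST_PUBSUB_* variables could be handled the same way.
import os

import pytest


@pytest.fixture(scope="session")
def bkt_settings() -> dict[str, str]:
    """Collect bucket/Bigtable settings from the environment, or skip."""
    required = [
        "BKT_PROJECT_ID",
        "BKT_BUCKET_NAME",
        "BKT_BIGTABLE_PROJECT_ID",
        "BKT_BIGTABLE_INSTANCE_ID",
    ]
    missing = [name for name in required if name not in os.environ]
    if missing:
        pytest.skip(f"integration test requires environment variables: {missing}")
    return {name: os.environ[name] for name in required}
```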
- -You can create these credentials at -https://console.cloud.google.com/storage/settings;tab=interoperability?hl=en&project=skylight-proto-1 -under "Access keys for your user account". - Usage ----- diff --git a/data/satlas/marine_infra/config.json b/data/satlas/marine_infra/config.json new file mode 100644 index 00000000..bbb765a5 --- /dev/null +++ b/data/satlas/marine_infra/config.json @@ -0,0 +1,111 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "format": { + "coordinate_mode": "pixel", + "name": "geojson" + }, + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + } + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "format": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + }, + "name": "geotiff" + }, + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 4 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + } + } + } +} diff --git a/data/satlas/marine_infra/config.yaml b/data/satlas/marine_infra/config.yaml new file mode 100644 index 00000000..90c50357 --- /dev/null +++ b/data/satlas/marine_infra/config.yaml @@ -0,0 +1,219 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 3 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + 
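In the dataset `config.json` above, the Sentinel-2 bands are grouped into three band sets by native resolution: the 10 m bands (B02, B03, B04, B08) are stored at the window resolution, the 20 m bands (B05, B06, B07, B8A, B11, B12) at `zoom_offset: -1`, and the 60 m bands (B01, B09, B10) at `zoom_offset: -2`. Assuming each zoom step is a factor of two, as in web-map tiling, the stored pixel sizes work out as in this small illustrative sketch (not rslearn code):

```python
# Stored pixel size per band set, assuming a 10 m window resolution and that each
# zoom_offset step scales the grid by a factor of two (an assumption, not verified
# against the rslearn implementation).
window_pixel_size_m = 10
band_sets = {
    "B02/B03/B04/B08 (10 m native)": 0,
    "B05/B06/B07/B8A/B11/B12 (20 m native)": -1,
    "B01/B09/B10 (60 m native)": -2,
}
for bands, zoom_offset in band_sets.items():
    print(f"{bands}: stored at {window_pixel_size_m * 2 ** -zoom_offset} m/pixel")
```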
dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: FLOAT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslp.satlas.train.MarineInfraTask + init_args: + property_name: "category" + classes: ["unknown", "platform", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]] + skip_unknown_categories: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + tags: + split: train + val_config: + patch_size: 512 + tags: + split: val + test_config: + patch_size: 512 + tags: + split: val + predict_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + 
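The Normalize/Concatenate transforms above appear three times in this model config (default, train, and predict): each of the four monthly mosaics has nine bands ordered B04, B03, B02, B05, B06, B07, B08, B11, B12; the first three bands are scaled with `std: 3000` and the remaining six with `std: 8160`, then all four mosaics are concatenated into a single 36-channel `image`, which `SimpleTimeSeries` (with `image_channels: 9`) presumably splits back into four 9-channel timesteps for the shared Swin encoder. A minimal numpy sketch of the equivalent preprocessing, assuming Normalize computes `(x - mean) / std` and clamps to `valid_range`:

```python
# Illustrative preprocessing only, not the rslearn transforms themselves. Band order
# per mosaic: B04, B03, B02, B05, B06, B07, B08, B11, B12 (matching the inputs above).
import numpy as np


def normalize_mosaic(mosaic: np.ndarray) -> np.ndarray:
    """Scale a (9, H, W) uint16 mosaic into [0, 1] floats."""
    out = mosaic.astype(np.float32)
    out[0:3] = np.clip(out[0:3] / 3000.0, 0.0, 1.0)  # bands [0, 1, 2], std=3000
    out[3:9] = np.clip(out[3:9] / 8160.0, 0.0, 1.0)  # bands [3..8], std=8160
    return out


mosaics = [np.zeros((9, 512, 512), dtype=np.uint16) for _ in range(4)]
image = np.concatenate([normalize_mosaic(m) for m in mosaics], axis=0)
assert image.shape == (36, 512, 512)  # 4 timesteps x 9 channels for SimpleTimeSeries
```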
groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/20241210/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_marine_infra +rslp_experiment: data_20241210_run_20241210_00 diff --git a/data/satlas/solar_farm/config.json b/data/satlas/solar_farm/config.json new file mode 100644 index 00000000..cd0ba082 --- /dev/null +++ b/data/satlas/solar_farm/config.json @@ -0,0 +1,88 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "label_raster": { + "band_sets": [ + { + "bands": [ + "label" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 4 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + } +} diff --git a/data/satlas/solar_farm/config.yaml b/data/satlas/solar_farm/config.yaml new file mode 100644 index 00000000..cbc6c71c --- /dev/null +++ b/data/satlas/solar_farm/config.yaml @@ -0,0 +1,207 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + decoders: + segment: + - class_path: rslearn.models.unet.UNetDecoder + init_args: + in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]] + out_channels: 2 + conv_layers_per_resolution: 2 + - class_path: rslearn.train.tasks.segmentation.SegmentationHead + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + 
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "raster" + layers: ["label_raster"] + bands: ["label"] + dtype: INT32 + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + segment: + class_path: rslearn.train.tasks.segmentation.SegmentationTask + init_args: + num_classes: 2 + metric_kwargs: + average: "micro" + remap_values: [[0, 1], [0, 255]] + input_mapping: + segment: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 256 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image", "target/segment/classes", "target/segment/valid"] + tags: + split: train + val_config: + patch_size: 256 + tags: + split: val + test_config: + patch_size: 256 + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + 
logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_segment/accuracy + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_solar_farm +rslp_experiment: data_20250108_satlaspretrain_patch256_00 diff --git a/data/satlas/wind_turbine/README.md b/data/satlas/wind_turbine/README.md new file mode 100644 index 00000000..f2d4700d --- /dev/null +++ b/data/satlas/wind_turbine/README.md @@ -0,0 +1,8 @@ +`config.json` and `config.yaml` contain the active dataset and model configurations for +the wind turbine point detection model. + +It uses six monthly Sentinel-2 L1C mosaics, which can be spread over up to nine months +(in case some months don't have enough cloud-free images to form a mosaic). + +`config_azure.json` is for testing a model that inputs Sentinel-2 L2A + Sentinel-1 +vv+vh sourced from Microsoft Azure. diff --git a/data/satlas/wind_turbine/config.json b/data/satlas/wind_turbine/config.json new file mode 100644 index 00000000..da90e09f --- /dev/null +++ b/data/satlas/wind_turbine/config.json @@ -0,0 +1,73 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_cloud_cover": 50, + "max_time_delta": "0d", + "modality": "L1C", + "name": "rslp.satlas.data_sources.MonthlySentinel2", + "query_config": { + "max_matches": 6 + }, + "sort_by": "cloud_cover", + "use_rtree_index": false + }, + "type": "raster" + } + } +} diff --git a/data/satlas/wind_turbine/config.yaml b/data/satlas/wind_turbine/config.yaml new file mode 100644 index 00000000..28a0aead --- /dev/null +++ b/data/satlas/wind_turbine/config.yaml @@ -0,0 +1,238 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: 
https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: 
["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241002_satlaspretrainold_patch384_03 diff --git a/data/satlas/wind_turbine/config_azure.json b/data/satlas/wind_turbine/config_azure.json new file mode 100644 index 00000000..419dcbe3 --- /dev/null +++ b/data/satlas/wind_turbine/config_azure.json @@ -0,0 +1,110 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "mask": { + "band_sets": [ + { + "bands": [ + "mask" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + } + ], + "type": "raster" + }, + "output": { + "type": "vector" + }, + "sentinel1": { + "band_sets": [ + { + "bands": [ + "vv", + "vh" + ], + "dtype": "float32" + } + ], + "data_source": { + "ingest": false, + "name": "rslp.satlas.data_sources.MonthlySentinel1", + "query": { + "sar:instrument_mode": { + "eq": "IW" + }, + "sar:polarizations": { + "eq": [ + "VV", + "VH" + ] + } + }, + "query_config": { + "max_matches": 6 + } + }, + "type": "raster" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "B02", + "B03", + "B04", + "B08" + ], + "dtype": "uint16" + }, + { + "bands": [ + "B05", + "B06", + "B07", + "B8A", + "B11", + "B12" + ], + "dtype": "uint16", + "zoom_offset": -1 + }, + { + "bands": [ + "B01", + "B09", + "B10" + ], + "dtype": "uint16", + "zoom_offset": -2 + } + ], + "data_source": { + "harmonize": true, + "ingest": false, + "max_cloud_cover": 50, + "name": "rslp.satlas.data_sources.MonthlyAzureSentinel2", + "query_config": { + "max_matches": 6 + }, + "sort_by": "eo:cloud_cover" + }, + "type": "raster" + } + }, + "tile_store": { + "class_path": "rslearn.tile_stores.default.DefaultTileStore", + "init_args": { + "geotiff_options": { + "compress": "zstd", + "predictor": 2, + "zstd_level": 1 + } + } + } +} diff --git a/data/satlas/wind_turbine/config_azure.yaml b/data/satlas/wind_turbine/config_azure.yaml new file mode 100644 index 
00000000..25eb5c68 --- /dev/null +++ b/data/satlas/wind_turbine/config_azure.yaml @@ -0,0 +1,271 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 11 + output_layers: [1, 3, 5, 7] + image_channels: 11 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] + ignore_prefixes: + - "backbone.backbone.backbone.features.0.0." +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2.2"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2.3"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2.4"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2.5"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + sar1: + data_type: "raster" + layers: ["sentinel1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar2: + data_type: "raster" + layers: ["sentinel1.1"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar3: + data_type: "raster" + layers: ["sentinel1.2"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar4: + data_type: "raster" + layers: ["sentinel1.3"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar5: + data_type: "raster" + layers: ["sentinel1.4"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + sar6: + data_type: "raster" + layers: ["sentinel1.5"] + bands: ["vv", "vh"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + is_target: true + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + image_bands: [2, 1, 
0] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 10000 + valid_range: [0, 1] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + sar1: [] + image2: [] + sar2: [] + image3: [] + sar3: [] + image4: [] + sar4: [] + image5: [] + sar5: [] + image6: [] + sar6: [] + output_selector: image +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241212/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241212_satlaspretrainold_patch384_00 diff --git a/convert_satlas_webmercator_to_rslearn/README.md b/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md similarity index 82% rename from convert_satlas_webmercator_to_rslearn/README.md rename to one_off_projects/convert_satlas_webmercator_to_rslearn/README.md index 7258f227..c7df1dc0 100644 --- a/convert_satlas_webmercator_to_rslearn/README.md +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/README.md @@ -4,6 +4,9 @@ About This project is for converting the various Satlas application training datasets (wind turbines, solar farms, 
marine infrastructure) to rslearn format. +The conversion has completed so it is now in one_off_projects and the code does not +need to be maintained anymore. + Wind Turbines ------------- diff --git a/convert_satlas_webmercator_to_rslearn/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/assign_split.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/assign_split.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/assign_split.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/assign_split.py diff --git a/convert_satlas_webmercator_to_rslearn/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/config.json diff --git a/convert_satlas_webmercator_to_rslearn/lib/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py similarity index 77% rename from convert_satlas_webmercator_to_rslearn/lib/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py index 17e18a83..067aa5e3 100644 --- a/convert_satlas_webmercator_to_rslearn/lib/__init__.py +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/lib/__init__.py @@ -14,7 +14,11 @@ from rasterio.crs import CRS from rslearn.const import WGS84_PROJECTION from rslearn.dataset import Window -from rslearn.utils import Projection, STGeometry, get_utm_ups_crs +from rslearn.utils.geometry import Projection, STGeometry +from rslearn.utils.get_utm_ups_crs import get_utm_ups_crs +from rslearn.utils.feature import Feature +from rslearn.utils.vector_format import GeojsonVectorFormat +from rslearn.utils.raster_format import SingleImageRasterFormat from upath import UPath src_crs = CRS.from_epsg(3857) @@ -30,7 +34,7 @@ def convert_window( time_range: tuple[datetime, datetime], dst_pixel_size: float = 10, window_name: str | None = None, -): +) -> Window: """Create an rslearn window from a multisat window with the specified properties. Args: @@ -81,7 +85,7 @@ def convert_window( int(dst_polygon.bounds[2]), int(dst_polygon.bounds[3]), ] - window_root = root_dir / "windows" / group / window_name + window_root = Window.get_window_root(root_dir, group, window_name) window = Window( path=window_root, group=group, @@ -93,7 +97,7 @@ def convert_window( window.save() # (2) Write the turbine positions. - features = [] + features: list[Feature] = [] for shp, properties in labels: # Similar to with bounds, subtract the WebMercator pixel offset between # multisat and rslearn. 
@@ -101,26 +105,12 @@ def convert_window( src_geom = STGeometry(src_projection, shp, None) dst_geom = src_geom.to_projection(dst_projection) - features.append( - { - "type": "Feature", - "geometry": json.loads(shapely.to_geojson(dst_geom.shp)), - "properties": properties, - } - ) - layer_dir = window_root / "layers" / "label" - label_fname = layer_dir / "data.geojson" - layer_dir.mkdir(parents=True, exist_ok=True) - with label_fname.open("w") as f: - json.dump( - { - "type": "FeatureCollection", - "features": features, - "properties": dst_projection.serialize(), - }, - f, - ) - (layer_dir / "completed").touch() + features.append(Feature(dst_geom, properties)) + + layer_name = "label" + layer_dir = window.get_layer_dir(layer_name) + GeojsonVectorFormat().encode_vector(layer_dir, dst_projection, features) + window.mark_layer_completed(layer_name) # (3) Write mask corresponding to old window projected onto new window. mask = np.zeros((bounds[3] - bounds[1], bounds[2] - bounds[0]), dtype=np.uint8) @@ -130,13 +120,9 @@ def convert_window( polygon_cols = [coord[0] - bounds[0] for coord in dst_polygon.exterior.coords] rr, cc = skimage.draw.polygon(polygon_rows, polygon_cols, shape=mask.shape) mask[rr, cc] = 255 - layer_dir = window_root / "layers" / "mask" - mask_fname = layer_dir / "mask" / "image.png" - mask_fname.parent.mkdir(parents=True, exist_ok=True) - with mask_fname.open("wb") as f: - Image.fromarray(mask).save(f) - with (mask_fname.parent / "bounds.json").open("w") as f: - json.dump(bounds, f) - (layer_dir / "completed").touch() + layer_name = "mask" + layer_dir = window.get_raster_dir(layer_name, ["mask"]) + SingleImageRasterFormat().encode_raster(layer_dir, dst_projection, bounds, mask[None, :, :]) + window.mark_layer_completed(layer_name) return window diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/__init__.py diff --git a/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/marine_infra/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/config.json diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/convert_siv_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py similarity index 100% rename from 
convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/delete_bad_images.py diff --git a/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/sentinel2_vessel/reformat_multisat.py diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py new file mode 100644 index 00000000..97044b00 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/set_single_image_metadata.py @@ -0,0 +1,33 @@ +import json +import multiprocessing +import sys + +import tqdm +from rslearn.dataset import Dataset, Window +from upath import UPath + + +def handle_window(window: Window): + image_fnames = window.path.glob("layers/*/*/image.png") + for image_fname in image_fnames: + metadata_fname = image_fname.parent / "metadata.json" + if metadata_fname.exists(): + continue + with metadata_fname.open("w") as f: + json.dump({"bounds": window.bounds}, f) + + +def set_single_image_metadata(ds_root: str, workers: int = 32): + ds_path = UPath(ds_root) + dataset = Dataset(ds_path) + windows = dataset.load_windows(show_progress=True, workers=workers) + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(handle_window, windows) + for _ in tqdm.tqdm(outputs, total=len(windows), desc="Set single image metadata"): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + set_single_image_metadata(ds_root=sys.argv[1]) diff --git a/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py similarity index 69% rename from convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py index b449bccc..5c423fc2 100644 --- a/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/solar_farm/convert_siv_labels.py @@ -9,8 +9,10 @@ import numpy as np import rasterio.features import shapely -from PIL import Image +from upath import UPath +from rslearn.utils.vector_format import GeojsonVectorFormat +from rslearn.utils.raster_format import SingleImageRasterFormat from ..lib import convert_window db_path = "/home/ubuntu/siv_renewable/data/siv.sqlite3" @@ -36,7 +38,7 @@ ts = ts.replace(tzinfo=timezone.utc) time_range = ( ts - timedelta(days=120), - ts + timedelta(days=30), + ts + timedelta(days=60), ) db.execute( @@ -55,7 +57,7 @@ labels.append((polygon, properties)) window = convert_window( - root_dir=out_dir, + root_dir=UPath(out_dir), group=group, zoom=15, bounds=bounds, @@ -64,15 +66,18 @@ ) # Create raster version of the label. 
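The new `set_single_image_metadata.py` script above backfills a `metadata.json` (containing the window bounds) next to every `layers/*/*/image.png`, presumably so that `SingleImageRasterFormat` can locate rasters without the old `bounds.json` sidecars; it takes the dataset root as its only command-line argument. The snippet below is an illustrative follow-up check, not part of the repository — it reuses the same rslearn calls as the script and simply verifies that each backfilled file matches its window:

```python
# Illustrative consistency check for the backfill done by set_single_image_metadata.py.
# The function name and the idea of running this check are assumptions; the rslearn
# calls mirror the script above.
import json
import multiprocessing
import sys

from rslearn.dataset import Dataset
from upath import UPath


def check_metadata(ds_root: str, workers: int = 32) -> None:
    dataset = Dataset(UPath(ds_root))
    for window in dataset.load_windows(show_progress=True, workers=workers):
        for image_fname in window.path.glob("layers/*/*/image.png"):
            with (image_fname.parent / "metadata.json").open() as f:
                assert json.load(f)["bounds"] == list(window.bounds)


if __name__ == "__main__":
    multiprocessing.set_start_method("forkserver")
    check_metadata(sys.argv[1])
```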
+ layer_dir = window.get_layer_dir("label") + features = GeojsonVectorFormat().decode_vector(layer_dir, bounds) + shapes = [] - with window.file_api.open("layers/label/data.geojson", "r") as f: - for feat in json.load(f)["features"]: - geometry = feat["geometry"] - assert geometry["type"] == "Polygon" - geometry["coordinates"] = ( - np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] - ).tolist() - shapes.append((geometry, 255)) + for feat in features: + assert feat.geometry.projection == window.projection + geometry = json.loads(shapely.to_geojson(feat.geometry.shp)) + assert geometry["type"] == "Polygon" + geometry["coordinates"] = ( + np.array(geometry["coordinates"]) - [window.bounds[0], window.bounds[1]] + ).tolist() + shapes.append((geometry, 1)) if shapes: mask = rasterio.features.rasterize( shapes, @@ -87,5 +92,7 @@ (window.bounds[3] - window.bounds[1], window.bounds[2] - window.bounds[0]), dtype=np.uint8, ) - with window.file_api.get_folder("layers/label_raster").open("image.png", "wb") as f: - Image.fromarray(mask).save(f, format="PNG") + layer_name = "label_raster" + raster_dir = window.get_raster_dir(layer_name, ["label"]) + SingleImageRasterFormat().encode_raster(raster_dir, window.projection, window.bounds, mask[None, :, :]) + window.mark_layer_completed(layer_name) diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/__init__.py diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py new file mode 100644 index 00000000..0e77f638 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/assign_old_splits.py @@ -0,0 +1,49 @@ +"""Assign properties matching the old splits from multisat.""" + +import json +import multiprocessing +import sys + +import tqdm +from rslearn.dataset import Dataset +from upath import UPath + +in_fnames = { + "train": "/multisat/mosaic/splits/turbine_naip_supervision/train.json", + "val": "/multisat/mosaic/splits/turbine_naip_supervision/val.json", +} + + +def process(job): + window, split = job + if "old_split" in window.options and window.options["old_split"] == split: + return + window.options["old_split"] = split + window.save() + + +def assign_split(ds_root: str, workers: int = 32): + ds_path = UPath(ds_root) + dataset = Dataset(ds_path) + windows = dataset.load_windows(show_progress=True, workers=workers) + windows_by_name = {window.name: window for window in windows} + + jobs = [] + for split, fname in in_fnames.items(): + with open(fname) as f: + for col, row in json.load(f): + expected_window_name = f"{col*512}_{row*512}" + if expected_window_name not in windows_by_name: + continue + jobs.append((windows_by_name[expected_window_name], split)) + + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(process, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs), desc="Assign old split"): + pass + p.close() + + +if __name__ == "__main__": + multiprocessing.set_start_method("forkserver") + assign_split(ds_root=sys.argv[1]) diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py 
b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py new file mode 100644 index 00000000..6bbdc281 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/compare_webmercator_utm_together.py @@ -0,0 +1,28 @@ +"""Simple script to copy the images. + +Just need to run `model predict` with visualize_dir set to utm and webm respectively. +""" + +import os +import shutil + +webm_images = {} +for fname in os.listdir("webm"): + if not fname.endswith("_gt.png"): + continue + parts = fname.split("_gt.png")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + webm_images[tile] = fname +utm_images = {} +for fname in os.listdir("utm"): + if not fname.endswith("_gt.png"): + continue + parts = fname.split("_gt.png")[0].split("_") + tile = (int(parts[0]) // 512, int(parts[1]) // 512) + utm_images[tile] = fname +good_keys = set(webm_images.keys()).intersection(utm_images.keys()) +for tile in good_keys: + shutil.copyfile( + f"webm/{webm_images[tile]}", f"out/{tile[0]}_{tile[1]}_webmercator.png" + ) + shutil.copyfile(f"utm/{utm_images[tile]}", f"out/{tile[0]}_{tile[1]}_utm.png") diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json similarity index 68% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json index 047907a7..8eec25e4 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config.json @@ -21,7 +21,8 @@ "output": { "type": "vector" }, - "sentinel2": { + "sentinel2_a": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -56,19 +57,21 @@ ], "data_source": { "harmonize": true, - "index_cache_dir": "/data/favyenb/rslearn_datasets_satlas/solar_farm/cache/sentinel2", + "index_cache_dir": "cache/sentinel2", "max_time_delta": "1d", "modality": "L1C", "name": "rslearn.data_sources.gcp_public_data.Sentinel2", "query_config": { - "max_matches": 3, + "max_matches": 2, "space_mode": "CONTAINS" }, - "sort_by": "cloud_cover" + "sort_by": "cloud_cover", + "use_rtree_index": false }, "type": "raster" }, - "sentinel2.1": { + "sentinel2_b": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -101,9 +104,24 @@ "zoom_offset": -2 } ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-90d", + "use_rtree_index": false + }, "type": "raster" }, - "sentinel2.2": { + "sentinel2_c": { + "alias": "sentinel2", "band_sets": [ { "bands": [ @@ -136,6 +154,20 @@ "zoom_offset": -2 } ], + "data_source": { + "harmonize": true, + "index_cache_dir": "cache/sentinel2", + "max_time_delta": "1d", + "modality": "L1C", + "name": "rslearn.data_sources.gcp_public_data.Sentinel2", + "query_config": { + "max_matches": 2, + "space_mode": "CONTAINS" + }, + "sort_by": "cloud_cover", + "time_offset": "-180d", + "use_rtree_index": false + }, "type": "raster" } }, diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml similarity index 64% rename from 
convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml index d96ef035..d6d022bb 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config.yaml +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml @@ -26,11 +26,12 @@ model: num_channels: 128 num_classes: 2 anchor_sizes: [[32], [64], [128], [256]] - lr: 0.0001 - plateau_factor: 0.1 - plateau_patience: 10 + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 plateau_min_lr: 0 - plateau_cooldown: 0 + plateau_cooldown: 10 restore_config: restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth remap_prefixes: @@ -38,26 +39,44 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ inputs: image1: data_type: "raster" - layers: ["sentinel2"] + layers: ["sentinel2_a"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image2: data_type: "raster" - layers: ["sentinel2.1"] + layers: ["sentinel2_a.1"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 image3: data_type: "raster" - layers: ["sentinel2.2"] + layers: ["sentinel2_b"] bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] passthrough: true - dtype: INT32 + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 mask: data_type: "raster" layers: ["mask"] @@ -93,43 +112,73 @@ data: num_workers: 32 default_config: transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: image1: [] image2: [] image3: [] + image4: [] + image5: [] + image6: [] output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 std: 3000 - - class_path: rslp.transforms.mask.Mask - train_config: - patch_size: 256 - transforms: + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: selections: 
image1: [] image2: [] image3: [] + image4: [] + image5: [] + image6: [] output_selector: image - - class_path: rslearn.train.transforms.normalize.Normalize - init_args: - mean: 0 - std: 3000 - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] groups: ["label", "naip"] tags: split: train val_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: split: val test_config: - patch_size: 256 + patch_size: 384 groups: ["label", "naip"] tags: split: val @@ -146,7 +195,7 @@ trainer: logging_interval: "epoch" - class_path: rslearn.train.prediction_writer.RslearnWriter init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ output_layer: output selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint @@ -155,5 +204,9 @@ trainer: save_last: true monitor: val_detect/mAP mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_satlaspretrainold_patch256_noflip_satlasbands3000_3image_02 +rslp_experiment: data_20241002_satlaspretrainold_patch384_03 diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml new file mode 100644 index 00000000..1648b771 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip_oldsplit.yaml @@ -0,0 +1,212 @@ +model: + class_path: rslearn.train.lightning_module.RslearnLightningModule + init_args: + model: + class_path: rslearn.models.multitask.MultiTaskModel + init_args: + encoder: + - class_path: rslearn.models.simple_time_series.SimpleTimeSeries + init_args: + encoder: + class_path: rslearn.models.swin.Swin + init_args: + pretrained: true + input_channels: 9 + output_layers: [1, 3, 5, 7] + image_channels: 9 + - class_path: rslearn.models.fpn.Fpn + init_args: + in_channels: [128, 256, 512, 1024] + out_channels: 128 + decoders: + detect: + - class_path: rslearn.models.faster_rcnn.FasterRCNN + init_args: + downsample_factors: [4, 8, 16, 32] + num_channels: 128 + num_classes: 2 + anchor_sizes: [[32], [64], [128], [256]] + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 + plateau_min_lr: 0 + plateau_cooldown: 10 + restore_config: + restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth + remap_prefixes: + - ["backbone.backbone.backbone.", "encoder.0.encoder.model."] +data: + class_path: rslearn.train.data_module.RslearnDataModule + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + inputs: + image1: + data_type: "raster" + layers: ["sentinel2_a"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image2: + data_type: "raster" + layers: ["sentinel2_a.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image3: + data_type: "raster" + layers: ["sentinel2_b"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image4: + data_type: "raster" + layers: ["sentinel2_b.1"] + bands: ["B04", 
"B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image5: + data_type: "raster" + layers: ["sentinel2_c"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + image6: + data_type: "raster" + layers: ["sentinel2_c.1"] + bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + passthrough: true + dtype: FLOAT32 + mask: + data_type: "raster" + layers: ["mask"] + bands: ["mask"] + passthrough: true + dtype: INT32 + targets: + data_type: "vector" + layers: ["label"] + is_target: true + task: + class_path: rslearn.train.tasks.multi_task.MultiTask + init_args: + tasks: + detect: + class_path: rslearn.train.tasks.detection.DetectionTask + init_args: + property_name: "category" + classes: ["unknown", "turbine"] + box_size: 15 + remap_values: [[0, 1], [0, 255]] + exclude_by_center: true + enable_map_metric: true + enable_f1_metric: true + f1_metric_kwargs: + cmp_mode: "distance" + cmp_threshold: 15 + flatten_classes: true + input_mapping: + detect: + targets: "targets" + batch_size: 8 + num_workers: 32 + default_config: + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + train_config: + patch_size: 384 + transforms: + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 3000 + valid_range: [0, 1] + bands: [0, 1, 2] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.normalize.Normalize + init_args: + mean: 0 + std: 8160 + valid_range: [0, 1] + bands: [3, 4, 5, 6, 7, 8] + selectors: ["image1", "image2", "image3", "image4", "image5", "image6"] + - class_path: rslearn.train.transforms.concatenate.Concatenate + init_args: + selections: + image1: [] + image2: [] + image3: [] + image4: [] + image5: [] + image6: [] + output_selector: image + - class_path: rslp.transforms.mask.Mask + - class_path: rslearn.train.transforms.flip.Flip + init_args: + image_selectors: ["image"] + box_selectors: ["target/detect"] + groups: ["label", "naip"] + tags: + old_split: train + val_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + old_split: val + test_config: + patch_size: 384 + groups: ["label", "naip"] + tags: + old_split: val + predict_config: + groups: ["predict"] + load_all_patches: true + skip_targets: true + patch_size: 512 +trainer: + max_epochs: 500 + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: "epoch" + - class_path: rslearn.train.prediction_writer.RslearnWriter + init_args: + path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/20241002/ + output_layer: output + selector: ["detect"] + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + save_top_k: 1 + save_last: true + monitor: val_detect/mAP + mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + 
init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 +rslp_project: satlas_wind_turbine +rslp_experiment: data_20241002_satlaspretrainold_patch384_oldsplits_03 diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/config_predict.json diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_naip_labels.py diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/convert_siv_labels.py diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py new file mode 100644 index 00000000..2a1e3257 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/create_webmercator_rslearn_dataset.py @@ -0,0 +1,182 @@ +"""Create WebMercator version of the wind turbine dataset as rslearn dataset. + +This is just used to make sure performance is the same when training in rslearn versus +training in multisat. + +The projection is not set correctly since it's just for testing. +""" + +import argparse +import json +import multiprocessing +import os +import shutil + +import shapely +import tqdm +from rasterio.crs import CRS +from rslearn.dataset import Window +from rslearn.utils import Feature, Projection, STGeometry +from rslearn.utils.mp import star_imap_unordered +from rslearn.utils.vector_format import GeojsonVectorFormat +from upath import UPath + +BANDS = ["tci", "b05", "b06", "b07", "b08", "b11", "b12"] + +# How many images model will be trained with, so can't have examples with fewer than +# this many images. +REQUIRED_IMAGES = 4 + + +def process_example( + ds_path: UPath, + label_dir: str, + image_dir: str, + tile: tuple[int, int], + image_ids: list[str], + split: str, +): + projection = Projection(CRS.from_epsg(3857), 10, -10) + + window_name = f"{tile[0]}_{tile[1]}" + group = "default" + window_root = ds_path / "windows" / group / window_name + window = Window( + path=window_root, + group=group, + name=window_name, + projection=projection, + bounds=[0, 0, 512, 512], + time_range=None, + options={"split": split}, + ) + window.save() + + # Image layers. 
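    # For each available image ID, the loop below creates one rslearn raster
    # layer: the first image goes into "sentinel2", later ones into
    # "sentinel2.1", "sentinel2.2", and so on. Band PNGs are copied unchanged
    # from the multisat mosaic directory, except "tci", which is stored under
    # the band name "R_G_B".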
+ for idx, image_id in enumerate(image_ids): + if idx == 0: + layer_name = "sentinel2" + else: + layer_name = f"sentinel2.{idx}" + + for band in BANDS: + src_fname = os.path.join( + image_dir, image_id, band, f"{tile[0]}_{tile[1]}.png" + ) + if band == "tci": + dst_band_name = "R_G_B" + else: + dst_band_name = band + dst_fname = ( + window_root / "layers" / layer_name / dst_band_name / "image.png" + ) + dst_fname.parent.mkdir(parents=True, exist_ok=True) + with open(src_fname, "rb") as src: + with dst_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) + + (window_root / "layers" / layer_name / "completed").touch() + + # Label layer. + features = [] + with open(os.path.join(label_dir, f"{tile[0]}_{tile[1]}.json")) as f: + for x1, y1, x2, y2, category in json.load(f): + geom = STGeometry(projection, shapely.box(x1, y1, x2, y2), None) + props = dict(category=category) + features.append(Feature(geom, props)) + layer_dir = window_root / "layers" / "label" + GeojsonVectorFormat().encode_vector(layer_dir, projection, features) + (layer_dir / "completed").touch() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--label_dir", + help="The multisat label directory", + type=str, + default="/multisat/labels/renewable_infra_point_naip_supervision", + ) + parser.add_argument( + "--image_dir", + help="The multisat image directory", + type=str, + default="/multisat/mosaic/turbine_naip_supervision/", + ) + parser.add_argument( + "--ds_path", + help="The path to write output rslearn dataset", + type=str, + default="gs://rslearn-eai/datasets/wind_turbine/webmercator_dataset/20240927/", + ) + parser.add_argument( + "--train_split", + help="The JSON file containing train split", + type=str, + default="/multisat/mosaic/splits/turbine_naip_supervision/train.json", + ) + parser.add_argument( + "--workers", + help="Number of parallel workers", + type=int, + default=32, + ) + args = parser.parse_args() + + # Create map from tile to image IDs that have it available. + tile_to_image_ids = {} + for image_id in os.listdir(args.image_dir): + for fname in os.listdir(os.path.join(args.image_dir, image_id, BANDS[0])): + # Make sure the other bands exist. + bands_exist = True + for band in BANDS: + if os.path.exists(os.path.join(args.image_dir, image_id, band, fname)): + continue + bands_exist = False + break + if not bands_exist: + continue + + parts = fname.split(".")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + if tile not in tile_to_image_ids: + tile_to_image_ids[tile] = [] + tile_to_image_ids[tile].append(image_id) + + # Identify tiles in train split. 
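    # The train split JSON is a list of [col, row] tile coordinates; labeled
    # tiles that do not appear in it are assigned to the val split below.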
+ train_split = set() + with open(args.train_split) as f: + for col, row in json.load(f): + train_split.add((col, row)) + + ds_path = UPath(args.ds_path) + + jobs = [] + for fname in os.listdir(args.label_dir): + parts = fname.split(".")[0].split("_") + tile = (int(parts[0]), int(parts[1])) + image_ids = tile_to_image_ids.get(tile, []) + if len(image_ids) < REQUIRED_IMAGES: + continue + + if tile in train_split: + split = "train" + else: + split = "val" + + jobs.append( + dict( + ds_path=ds_path, + label_dir=args.label_dir, + image_dir=args.image_dir, + tile=tile, + image_ids=image_ids, + split=split, + ) + ) + + p = multiprocessing.Pool(args.workers) + outputs = star_imap_unordered(p, process_example, jobs) + for _ in tqdm.tqdm(outputs, total=len(jobs)): + pass + p.close() diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py similarity index 100% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/quick_vis_script.py diff --git a/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json new file mode 100644 index 00000000..707e5593 --- /dev/null +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.json @@ -0,0 +1,342 @@ +{ + "layers": { + "label": { + "type": "vector" + }, + "output": { + "type": "vector" + }, + "sentinel2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + }, + "sentinel2.1": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": 
"single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + }, + "sentinel2.2": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + }, + "sentinel2.3": { + "band_sets": [ + { + "bands": [ + "R", + "G", + "B" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b05" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b06" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b07" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b08" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + } + }, + { + "bands": [ + "b11" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + }, + { + "bands": [ + "b12" + ], + "dtype": "uint8", + "format": { + "format": "png", + "name": "single_image" + }, + "zoom_offset": -1 + } + ], + "type": "raster" + } + }, + "tile_store": { + "name": "file", + "root_dir": "tiles" + } +} diff --git a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml similarity index 76% rename from convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml rename to one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml index 4cd6ffa4..478579f4 100644 --- a/convert_satlas_webmercator_to_rslearn/wind_turbine/config_flip.yaml +++ b/one_off_projects/convert_satlas_webmercator_to_rslearn/wind_turbine/webmercator_config.yaml @@ -26,11 +26,12 @@ model: num_channels: 128 num_classes: 2 anchor_sizes: [[32], [64], [128], [256]] - lr: 0.0001 - plateau_factor: 0.1 - plateau_patience: 10 + lr: 0.00002 + plateau: true + plateau_factor: 0.2 + plateau_patience: 2 plateau_min_lr: 0 - plateau_cooldown: 0 + plateau_cooldown: 10 restore_config: restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth remap_prefixes: @@ -38,30 +39,30 @@ model: data: class_path: rslearn.train.data_module.RslearnDataModule init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ + path: gs://rslearn-eai/datasets/wind_turbine/webmercator_dataset/20240927/ inputs: image1: data_type: "raster" layers: ["sentinel2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + bands: ["R", "G", "B", 
"b05", "b06", "b07", "b08", "b11", "b12"] passthrough: true dtype: INT32 image2: data_type: "raster" layers: ["sentinel2.1"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] passthrough: true dtype: INT32 image3: data_type: "raster" layers: ["sentinel2.2"] - bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] passthrough: true dtype: INT32 - mask: + image4: data_type: "raster" - layers: ["mask"] - bands: ["mask"] + layers: ["sentinel2.3"] + bands: ["R", "G", "B", "b05", "b06", "b07", "b08", "b11", "b12"] passthrough: true dtype: INT32 targets: @@ -76,7 +77,7 @@ data: class_path: rslearn.train.tasks.detection.DetectionTask init_args: property_name: "category" - classes: ["unknown", "turbine"] + classes: ["unknown", "wind_turbine"] box_size: 15 remap_values: [[0, 1], [0, 255]] exclude_by_center: true @@ -89,7 +90,7 @@ data: input_mapping: detect: targets: "targets" - batch_size: 8 + batch_size: 4 num_workers: 32 default_config: transforms: @@ -99,14 +100,14 @@ data: image1: [] image2: [] image3: [] + image4: [] output_selector: image - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 - std: 3000 - - class_path: rslp.transforms.mask.Mask + std: 255 train_config: - patch_size: 256 + patch_size: 512 transforms: - class_path: rslearn.train.transforms.concatenate.Concatenate init_args: @@ -114,50 +115,41 @@ data: image1: [] image2: [] image3: [] + image4: [] output_selector: image - class_path: rslearn.train.transforms.normalize.Normalize init_args: mean: 0 - std: 3000 - - class_path: rslp.transforms.mask.Mask + std: 255 - class_path: rslearn.train.transforms.flip.Flip init_args: image_selectors: ["image"] box_selectors: ["target/detect"] - groups: ["label", "naip"] tags: split: train val_config: - patch_size: 256 - groups: ["label", "naip"] + patch_size: 512 tags: split: val test_config: - patch_size: 256 - groups: ["label", "naip"] + patch_size: 512 tags: split: val - predict_config: - groups: ["predict"] - load_all_patches: true - skip_targets: true - patch_size: 512 trainer: max_epochs: 500 callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: logging_interval: "epoch" - - class_path: rslearn.train.prediction_writer.RslearnWriter - init_args: - path: gs://rslearn-eai/datasets/wind_turbine/dataset_v1/live/ - output_layer: output - selector: ["detect"] - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: save_top_k: 1 save_last: true monitor: val_detect/mAP mode: max + - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze + init_args: + module_selector: ["model", "encoder", 0, "encoder", "model"] + unfreeze_at_epoch: 2 rslp_project: satlas_wind_turbine -rslp_experiment: debug_20240806_satlaspretrainold_patch256_noflip_satlasbands3000_3image_flip_01 +rslp_experiment: data_20240927_satlaspretrainold_patch512_flip_4image_02 diff --git a/requirements.txt b/requirements.txt index 4ab931dd..f692edbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,12 @@ beaker-py>=1.32 fastapi>=0.115 +google-cloud-bigtable>=2.18 +google-cloud-pubsub>=2.18 interrogate>=1.7 pydantic>=2.8 pytest>=8.2 python-dotenv>=1.0 ruff>=0.7 +scikit-image>=0.23 typing-extensions>=4.11 uvicorn>=0.32 diff --git a/rslp/common/README.md b/rslp/common/README.md new file mode 100644 index 00000000..3de1c95c --- /dev/null +++ b/rslp/common/README.md @@ -0,0 +1,27 
@@ +This contains infrastructure intended to be shared across several rslp projects. + + +Worker +------ + +`worker.py` provides a system for launching Beaker jobs that run workers to execute +tasks from a queue. + +Each task specifies an rslp project, pipeline (workflow), and arguments to pass to the +pipeline. The queue is implemented with Google Cloud Pub/Sub. + +First create the topic and subscription via CLI: + + gcloud pubsub topics create --project skylight-proto-1 rslp-job-queue-YOURNAME + gcloud pubsub subscriptions create --project skylight-proto-1 rslp-job-queue-YOURNAME-sub --topic rslp-job-queue-YOURNAME + +You will then need to write code that writes tasks to the topic. +See `satlas/write_jobs.py` for an example of this. + +Then you can launch the worker. To test on one machine: + + python -m rslp.main common worker skylight-proto-1 rslp-job-queue-YOURNAME-sub + +And to launch 100 workers on Beaker: + + python -m rslp.main common launch BEAKER_IMAGE_NAME skylight-proto-1 rslp-job-queue-YOURNAME-sub 100 --gpus 1 --shared_memory 256GiB diff --git a/rslp/common/__init__.py b/rslp/common/__init__.py new file mode 100644 index 00000000..f94dc097 --- /dev/null +++ b/rslp/common/__init__.py @@ -0,0 +1,10 @@ +"""Pipelines common across projects.""" + +from .worker import launch_workers, worker_pipeline +from .write_file import write_file + +workflows = { + "worker": worker_pipeline, + "launch": launch_workers, + "write_file": write_file, +} diff --git a/rslp/common/worker.py b/rslp/common/worker.py new file mode 100644 index 00000000..05623246 --- /dev/null +++ b/rslp/common/worker.py @@ -0,0 +1,303 @@ +"""Worker to process jobs in a list of jobs.""" + +import json +import os +import shutil +import signal +import sys +import tempfile +import threading +import time +import uuid +from collections.abc import Callable +from concurrent import futures +from typing import Any + +import tqdm +from beaker import ( + Beaker, + Constraints, + DataMount, + DataSource, + ExperimentSpec, + Priority, + TaskResources, +) +from google.cloud import pubsub_v1 + +from rslp.launch_beaker import BUDGET, DEFAULT_WORKSPACE +from rslp.launcher_lib import get_base_env_vars +from rslp.log_utils import get_logger +from rslp.main import run_workflow + +logger = get_logger(__name__) + +# Maximum expected duration of a job in hours. We use this to limit how long we care +# about a pending claim that hasn't completed yet. +MAX_JOB_HOURS = 4 + +# Scratch directory that jobs can use and it will be managed by this module. +SCRATCH_DIRECTORY = "/tmp/scratch" + +# Directory to store SCRATCH_DIRECTORY (via symlink) in case +# manage_scratch_dir_on_data_disk is used. +# This is because some Beaker machines have much bigger /data disk than what's +# available for ephemeral storage within the Docker container, so we need to use that +# for disk-intensive tasks to avoid running out of disk space. But we also need to make +# sure we delete everything we wrote, so worker.py manages the folder. +DATA_DISK = "/data/rslearn_projects" + + +def get_cleanup_signal_handler(tmp_dir: str) -> Callable[[int, Any], None]: + """Make a signal handler that cleans up the specified directory before exiting. + + This should be passed as the handler to signal.signal. + + Args: + tmp_dir: the directory to delete when the signal is received. 
+ """ + + def cleanup_signal_handler(signo: int, stack_frame: Any) -> None: + logger.error(f"cleanup_signal_handler: caught signal {signo}") + shutil.rmtree(tmp_dir) + sys.exit(1) + + return cleanup_signal_handler + + +def worker_pipeline( + project_id: str, + subscription_id: str, + retries: int = 3, + retry_sleep: int = 60, + idle_timeout: int = 10, + manage_scratch_dir_on_data_disk: bool = False, +) -> None: + """Start a worker to run jobs from a Pub/Sub subscription. + + The job dict including rslp project, workflow, and arguments to pass must be + written to the topic. + + Args: + project_id: the Google Cloud project ID. + subscription_id: the Pub/Sub subscription ID. + retries: retry for this many consecutive errors before terminating. A "retry" + may run a different job than the one that originally caused failure. This + ensures workers will complete most of the jobs before they terminate due to + errors. + retry_sleep: sleep for this many seconds between retries. Sleeping helps in + case there is an error due to rate limiting. + idle_timeout: seconds before we terminate if there is no activity. + manage_scratch_dir_on_data_disk: whether to create SCRATCH_DIRECTORY on the + /data disk and manage it to ensure it is deleted in case of SIGTERM. + """ + if manage_scratch_dir_on_data_disk: + # Some tasks use SCRATCH_DIRECTORY, and if management is enabled, it means we + # should put the SCRATCH_DIRECTORY on the /data/ disk (via symlink), and that + # we must ensure it is deleted in case SIGTERM is received (i.e. if the Beaker + # job is cancelled or pre-empted. + os.makedirs(DATA_DISK, exist_ok=True) + tmp_dir_on_data_disk = tempfile.TemporaryDirectory(dir=DATA_DISK) + os.symlink(tmp_dir_on_data_disk.name, SCRATCH_DIRECTORY) + signal.signal( + signal.SIGTERM, get_cleanup_signal_handler(tmp_dir_on_data_disk.name) + ) + + subscriber = pubsub_v1.SubscriberClient() + subscription_path = subscriber.subscription_path(project_id, subscription_id) + + # Callback to run the workflow indicated in the message. + def process_message(message: pubsub_v1.subscriber.message.Message) -> None: + logger.debug("worker received message %s", message) + json_data = json.loads(message.data.decode()) + rslp_project = json_data["project"] + rslp_workflow = json_data["workflow"] + workflow_args = json_data["args"] + run_workflow(rslp_project, rslp_workflow, workflow_args) + + # Callback that wraps process_message to keep track of: + # 1. Whether a message is currently being processed. + # 2. The last time that a message finished processing. + # 3. The number of consecutive errors. If there is an error in process_message, it + # will sleep for retry_sleep unless it exceeds retries in which case we exit. + lock = threading.Lock() + is_processing = False + last_message_time = time.time() + consecutive_errors = 0 + + def callback(message: pubsub_v1.subscriber.message.Message) -> None: + nonlocal is_processing, last_message_time, consecutive_errors + try: + with lock: + is_processing = True + + process_message(message) + message.ack() + + with lock: + consecutive_errors = 0 + except Exception as e: + logger.error( + "encountered error while processing message %s: %s (%d/%d consecutive errors)", + message, + e, + consecutive_errors, + retries, + ) + with lock: + consecutive_errors += 1 + time.sleep(retry_sleep) + # Pub/Sub will catch this error and print it so we just re-raise it here. + # But in our monitoring loop below we will check for more errors than + # retries and cancel the subscription if so. 
+ raise + finally: + with lock: + is_processing = False + last_message_time = time.time() + + # We limit to a single message at a time and a single worker since tasks should use + # all of the available CPU/GPU resources. + flow_control = pubsub_v1.types.FlowControl( + max_messages=1, + # Tasks may take several hours so we allow extending the lease for up to a day. + max_lease_duration=24 * 3600, + ) + executor = futures.ThreadPoolExecutor(max_workers=1) + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) + streaming_pull_future = subscriber.subscribe( + subscription_path, + callback=callback, + flow_control=flow_control, + scheduler=scheduler, + ) + logger.info("worker listening for messages on %s", subscription_path) + # We use the loop below to make the worker exit if there are no more messages (by + # guessing based on the provided idle timeout). We also need to exit if there have + # been too many consecutive errors. + try: + while True: + time.sleep(idle_timeout) + + with lock: + if consecutive_errors > retries: + logger.info( + "worker exiting due to %d consecutive errors", + consecutive_errors, + ) + break + + if is_processing: + logger.debug( + "worker continuing since a message is currently being processed" + ) + continue + + time_since_last_activity = time.time() - last_message_time + if time_since_last_activity < idle_timeout: + logger.debug( + "worker continuing since time since last activity %d is less than idle timeout %d", + time_since_last_activity, + idle_timeout, + ) + continue + + logger.info("worker exiting due to idle timeout") + break + finally: + # Exit the worker process. + streaming_pull_future.cancel() + streaming_pull_future.result() + + +def launch_workers( + image_name: str, + project_id: str, + subscription_id: str, + num_workers: int, + gpus: int = 0, + shared_memory: str | None = None, + priority: Priority = Priority.low, + cluster: list[str] = ["ai2/augusta-google-1"], + manage_scratch_dir_on_data_disk: bool = False, +) -> None: + """Start workers for the prediction jobs. + + Args: + image_name: the Beaker image name to use for the jobs. + project_id: the Google Cloud project ID. + subscription_id: the Pub/Sub subscription ID. + num_workers: number of workers to launch + gpus: number of GPUs to request per worker. + shared_memory: shared memory string like "256GiB". + priority: priority to assign the Beaker jobs. + cluster: clusters to target. + manage_scratch_dir_on_data_disk: see worker_pipeline. 
+ """ + beaker = Beaker.from_env(default_workspace=DEFAULT_WORKSPACE) + + with beaker.session(): + for _ in tqdm.tqdm(range(num_workers)): + env_vars = get_base_env_vars(use_weka_prefix=False) + + spec = ExperimentSpec.new( + budget=BUDGET, + description="worker", + beaker_image=image_name, + priority=priority, + command=["python", "-m", "rslp.main"], + arguments=[ + "common", + "worker", + project_id, + subscription_id, + "--manage_scratch_dir_on_data_disk", + str(manage_scratch_dir_on_data_disk), + ], + constraints=Constraints( + cluster=cluster, + ), + preemptible=True, + datasets=[ + DataMount( + source=DataSource(secret="RSLEARN_GCP_CREDENTIALS"), # nosec + mount_path="/etc/credentials/gcp_credentials.json", # nosec + ), + DataMount( + source=DataSource(host_path="/data"), + mount_path="/data", + ), + ], + env_vars=env_vars, + resources=TaskResources(gpu_count=gpus, shared_memory=shared_memory), + ) + unique_id = str(uuid.uuid4())[0:8] + beaker.experiment.create(f"worker_{unique_id}", spec) + + +def write_jobs( + project_id: str, + topic_id: str, + rslp_project: str, + rslp_workflow: str, + args_list: list[list[str]], +) -> None: + """Write tasks to the PubSub topic. + + Args: + project_id: the project ID for the PubSub topic. + topic_id: the topic ID for the PubSub topic. + rslp_project: the rslp project to run. + rslp_workflow: the workflow in the project to run. + args_list: list of arguments fo reach task. + """ + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(project_id, topic_id) + for args in tqdm.tqdm(args_list, desc="Writing jobs to Pub/Sub topic"): + json_data = dict( + project=rslp_project, + workflow=rslp_workflow, + args=args, + ) + data = json.dumps(json_data).encode() + publisher.publish(topic_path, data).result() diff --git a/rslp/common/write_file.py b/rslp/common/write_file.py new file mode 100644 index 00000000..47330673 --- /dev/null +++ b/rslp/common/write_file.py @@ -0,0 +1,16 @@ +"""An example workflow that just writes a file. + +This is used for testing. It needs to be under rslp/ so that it can be used as a +workflow. +""" + + +def write_file(fname: str, contents: str) -> None: + """Write the contents to the file. + + Args: + fname: the filename to write. + contents: the data to write to the file. + """ + with open(fname, "w") as f: + f.write(contents) diff --git a/rslp/launch_beaker.py b/rslp/launch_beaker.py index 2a634e09..8d39ac53 100644 --- a/rslp/launch_beaker.py +++ b/rslp/launch_beaker.py @@ -13,10 +13,6 @@ DEFAULT_WORKSPACE = "ai2/earth-systems" BUDGET = "ai2/prior" -# I should make a docker image specifc to this project -# Need to add the following functionality -# upload a specified image - def launch_job( config_path: str, @@ -104,7 +100,9 @@ def launch_job( config_path, "--autoresume=true", ], - constraints=Constraints(cluster=["ai2/jupiter-cirrascale-2"]), + constraints=Constraints( + cluster=["ai2/jupiter-cirrascale-2", "ai2/augusta-google-1"] + ), preemptible=True, datasets=[launcher_lib.create_gcp_credentials_mount()], env_vars=env_vars, diff --git a/rslp/main.py b/rslp/main.py index 701e7434..ce19fc3c 100644 --- a/rslp/main.py +++ b/rslp/main.py @@ -3,9 +3,11 @@ import argparse import importlib import sys +from datetime import datetime import dotenv import jsonargparse +import jsonargparse.typing from rslp.log_utils import get_logger from rslp.utils.mp import init_mp @@ -13,6 +15,45 @@ logger = get_logger(__name__) +def datetime_serializer(v: datetime) -> str: + """Serialize datetime for jsonargparse. 
+ + Args: + v: the datetime object. + + Returns: + the datetime encoded to string + """ + return v.isoformat() + + +def datetime_deserializer(v: str) -> datetime: + """Deserialize datetime for jsonargparse. + + Args: + v: the encoded datetime. + + Returns: + the decoded datetime object + """ + return datetime.fromisoformat(v) + + +def run_workflow(project: str, workflow: str, args: list[str]) -> None: + """Run the specified workflow. + + Args: + project: the project that the workflow is in. This is the name of the module. + workflow: the workflow name. + args: arguments to pass to jsonargparse for running the workflow function. + """ + module = importlib.import_module(f"rslp.{project}") + workflow_fn = module.workflows[workflow] + logger.info(f"running {workflow} for {project}") + logger.info(f"args: {args}") + jsonargparse.CLI(workflow_fn, args=args) + + def main() -> None: """Main entrypoint function for rslp.""" dotenv.load_dotenv() @@ -20,14 +61,15 @@ def main() -> None: parser.add_argument("project", help="The project to execute a workflow for.") parser.add_argument("workflow", help="The name of the workflow.") args = parser.parse_args(args=sys.argv[1:3]) - - module = importlib.import_module(f"rslp.{args.project}") - workflow_fn = module.workflows[args.workflow] - logger.info(f"running {args.workflow} for {args.project}") - logger.info(f"args: {sys.argv[3:]}") - jsonargparse.CLI(workflow_fn, args=sys.argv[3:]) + run_workflow(args.project, args.workflow, sys.argv[3:]) if __name__ == "__main__": init_mp() + + # Setup jsonargparse. + jsonargparse.typing.register_type( + datetime, datetime_serializer, datetime_deserializer + ) + main() diff --git a/rslp/satlas/README.md b/rslp/satlas/README.md new file mode 100644 index 00000000..02766d68 --- /dev/null +++ b/rslp/satlas/README.md @@ -0,0 +1,124 @@ +This contains training, inference, and post-processing pipelines for the models served +at https://satlas.allen.ai/. + +## Marine Infrastructure + +The marine infrastructure model detects off-shore infrastructure in two categories: +off-shore wind turbines and off-shore platforms. The latter category essentially +includes any manmade object in the ocean that is stationary and would not normally be +considered an artificial island. + +The model inputs four Sentinel-2 images, where each image should be a mosaic that uses +Sentinel-2 scenes from a distinct month. The dataset configuration uses +`MonthlySentinel2` in `rslp/satlas/data_sources.py` to achieve this, and only uses +Sentinel-2 scenes with at most 50% cloud cover. If a given month does not have enough +matching images under the cloud threshold, then images from earlier months may be used. + +The model is meant to be run on a quarterly basis, using images up to 4 months +before the start of the quarter (giving 7 possible months to pick 4 mosaics from). If +all of the months are cloudy, then it is okay to skip that inference and come back to +it in a later season that may be less cloudy. At the same time, we don't want to limit +to just 4 months because a region may never have 4 consecutive months with cloud-free +images available. + +### Training + +The model is trained using the rslearn dataset in `gs://rslearn-eai`. See the model +configuration file for more details. + + python -m rslp.rslearn_main model fit --config data/satlas/marine_infra/config.yaml + +### Inference + +For inference, the world is split up into ~10K tiles in each of the 60 UTM zones, +yielding 600K inference tasks in total. 
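Each of these tasks is queued as a small JSON message naming the rslp project, the
workflow, and the arguments to pass to it (this is the format consumed by
`rslp/common/worker.py`). The sketch below shows how one such message could be
published by hand; the topic name and workflow arguments are placeholders, and in
practice the `write_jobs_for_year_months` command shown further below builds and
publishes these messages for you.

    # Hypothetical sketch of publishing a single task to the job queue; the
    # topic and the workflow arguments here are placeholders.
    import json

    from google.cloud import pubsub_v1

    publisher = pubsub_v1.PublisherClient()
    topic_path = publisher.topic_path("skylight-proto-1", "rslp-job-queue-YOURNAME")
    task = {"project": "satlas", "workflow": "predict", "args": ["--placeholder_arg=value"]}
    publisher.publish(topic_path, json.dumps(task).encode()).result()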
For each task, an inference worker will: + +1. Create an rslearn dataset with a single window corresponding to the UTM zone/tile. +2. Execute the data ingestion pipeline to populate that window with Sentinel-2 images. +3. Apply the model on the window to create an output layer in the rslearn dataset. +4. Copy the contents of that output layer to a location on GCS. + +The task queue system is implemented using `rslp.common.worker`, see +`rslp/common/README.md` for details. Essentially, we first write tasks to a Google +Cloud Pub/Sub topic, and then launch workers that will read from the topic. + +Then, we start by writing the tasks: + + python -m rslp.main satlas write_jobs_for_year_months '[[2024, 7]]' MARINE_INFRA 'gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen --days_before 120 --days_after 90 + +Here: + +- `[[2024, 7]]` is a list of year-month pairs that we want to run the model on. +- MARINE_INFRA is the application we want to apply. This is an enum for "marine_infra" + and it will automatically use the dataset configuration at + `data/satlas/marine_infra/config.json` and the model configuration at + `data/satlas/marine_infra/config.yaml`. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/{year:04d}-{month:02d}/` + is the path where outputs should be written. Outputs will be named like + `EPSG:32601_65536_-524288.geojson` where `EPSG:32601` is the UTM projection, and + 65536 and -524288 are the starting column and row (respectively) of the tile. The + path should have a year and month placeholder. +- `skylight-proto-1` is the project of the Pub/Sub topic. +- `rslp-job-queue-favyen` is the name of the Pub/Sub topic. +- The inference tasks should create a window spanning 120 days before the specified + timestamp (to use images before the quarter when necessary) and 90 days after the + timestamp (corresponding to the duration of the quarter). + +Then start the workers. See `rslp/common/README.md` for details. In this example, +`rslp-job-queue-favyen-sub` should be a subscription for the topic to which the tasks +were written. Here we start 100 workers. + + python -m rslp.main common launch skylight-proto-1 rslp-job-queue-favyen-sub 100 --gpus 1 --shared_memory 256GiB + +### Post-processing. + +Post-processing for point tasks occurs locally (does not require starting jobs in parallel). + +First, merge the points computed across all of the different tasks: + + python -m rslp.main satlas merge_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/2024-07/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ + +Here: + +- MARINE_INFRA is the application we want to apply. +- 2024-07 is the timestep label. All timestep labels are YYYY-MM for the Satlas + systems. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/2024-07/` is the + folder containing inference outputs that we want to merge. +- `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/` is the + folder to write merged outputs. The output filename will be + `gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/2024-07.geojson`. + +Second, smooth the points across timesteps. This runs a Viterbi smoothing operation. +Note that the Viterbi smoothing is implemented in a separate Go application at +`rslp/satlas/scripts/smooth_point_labels_viterbi.go`. 
+ + python -m rslp.main satlas smooth_points MARINE_INFRA 2024-07 gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/merged/ gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ + +Finally, publish the outputs to Cloudflare R2. This requires +[tippecanoe](https://github.com/mapbox/tippecanoe) since it is used to generate the +vector tiles. + + python -m rslp.main satlas publish_points MARINE_INFRA gs://rslearn-eai/projects/satlas/marine_infra/version-20241212/smoothed/ 'marine-default-cluster@v4' + +## Wind Turbine + +The wind turbine model detects wind turbines on land. It is meant to be run on a +semi-annual basis, and inputs six Sentinel-2 images. As with the marine infrastructure +model, each is a mosaic using images within a 30-day period. + +Training: + + python -m rslp.rslearn_main model fit --config data/satlas/wind_turbine/config.yaml + +Inference: + + python -m rslp.main satlas write_jobs_for_year_months '[[2024, 1]]' WIND_TURBINE 'gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/{year:04d}-{month:02d}/' skylight-proto-1 rslp-job-queue-favyen --days_before 90 --days_after 181 + +Post-processing: + + python -m rslp.main satlas merge_points WIND_TURBINE 2024-01 gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/2024-01/ gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/merged/ + python -m rslp.main satlas smooth_points WIND_TURBINE 2024-01 gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/merged/ gs://rslearn-eai/projects/satlas/wind_turbine/version-20241210/smoothed/ + +Publishing for wind turbine is not supported yet since it needs to be combined with the +detected solar farms and published as "renewable energy" GeoJSON. diff --git a/rslp/satlas/__init__.py b/rslp/satlas/__init__.py new file mode 100644 index 00000000..2bb65635 --- /dev/null +++ b/rslp/satlas/__init__.py @@ -0,0 +1,23 @@ +"""Satlas batch jobs. + +Specifically, training and inference for these fine-tuned models on satlas.allen.ai: +- Marine infrastructure +- On-shore wind turbines +- Solar farms +- Tree cover +""" + +from .postprocess import merge_points, smooth_points +from .predict_pipeline import predict_multi, predict_pipeline +from .publish import publish_points +from .write_jobs import write_jobs, write_jobs_for_year_months + +workflows = { + "predict": predict_pipeline, + "predict_multi": predict_multi, + "write_jobs": write_jobs, + "write_jobs_for_year_months": write_jobs_for_year_months, + "merge_points": merge_points, + "smooth_points": smooth_points, + "publish_points": publish_points, +} diff --git a/rslp/satlas/bkt.py b/rslp/satlas/bkt.py new file mode 100644 index 00000000..b33f8977 --- /dev/null +++ b/rslp/satlas/bkt.py @@ -0,0 +1,540 @@ +"""Manage bucket files on GCS. + +We bucket together small (10-200 KB) files at high zoom levels (e.g. zoom 13) into a +single file at a lower zoom level (e.g. zoom 9) to save on GCS insert fee. + +This is similar to https://github.com/mactrem/com-tiles. + +The .bkt is just a concatenation of the small files, which here we call items. This is +specialized for storing tiles, so each item falls on a grid at a particular zoom level +and is associated with a column and row. The bkt itself is on a grid at a zoom level +equal to or lower than that of its items, i.e., a coarser grid. + +The bkt should contain every item (on the finer grid) that is contained within its tile +(on the coarser grid); if an item is missing, that should mean that there is just no +data there. 
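Concretely, each item's location is recorded as a 16-byte index record, four
big-endian uint32 values (column, row, byte offset, length), stored (as described
below) in a Bigtable row keyed by the bkt filename, after an 8-byte
(bkt_zoom, zoom) header. A minimal reader sketch, assuming the encoded row value
has already been fetched into `value`:

    import struct

    bkt_zoom, zoom = struct.unpack(">II", value[:8])
    index = {}
    for i in range(8, len(value), 16):
        col, row, offset, length = struct.unpack(">IIII", value[i : i + 16])
        index[(col, row)] = (offset, length)
    # A single item can then be read with a ranged GCS download of the item's
    # bytes from the concatenated bkt file.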
+ +Then there will actually be a set of bkt files at the coarser zoom level to cover the +entire region. + +We record the columns, rows, and byte offsets of items in a Google Cloud Bigtable +database. Readers can first query Bigtable, then make a range read request to GCS. +""" + +import functools +import io +import multiprocessing.pool +import os +import struct +import time +from collections.abc import Generator +from dataclasses import dataclass +from enum import Enum +from typing import Any + +import google.cloud.bigtable.row +import google.cloud.bigtable.row_filters +import google.cloud.bigtable.table +import numpy.typing as npt +import skimage.io +from google.cloud import bigtable, storage +from rslearn.utils.mp import star_imap_unordered + +from rslp.log_utils import get_logger + +logger = get_logger(__name__) + +# Number of retries for reading from BigTable, since sometimes there are transient +# errors. +BIGTABLE_RETRIES = 8 + + +@dataclass +class BktItemMetadata: + """Metadata about an item (small file) stored within the bkt.""" + + # The column and row of the item on the item (fine-grained) grid. + col: int + row: int + + # The byte offset and length of the item within the concatenated bkt file. + offset: int + length: int + + def pack(self) -> bytes: + """Pack the metadata into bytes.""" + return struct.pack(">IIII", self.col, self.row, self.offset, self.length) + + @staticmethod + def unpack(b: bytes) -> "BktItemMetadata": + """Unpack a BktItemMetadata from a 16-byte string.""" + col, row, offset, length = struct.unpack(">IIII", b) + return BktItemMetadata(col, row, offset, length) + + +class BktInserter: + """A helper class that stores metadata about bkt files so it can be inserted later. + + The BktInserter is a separate class from BktWriter so that it can be pickled to + support use with multiprocessing. + + Normal usage is to create BktWriter in parallel, call get_inserter to get the + BktInserter objects, then collect them in the main thread and finally call + BktInserter.run. This way the worker threads do not need to make additional + connections to Bigtable, which really becomes a problem. + """ + + def __init__( + self, + item_metadatas: list[BktItemMetadata], + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ): + """Create a new BktInserter. + + It stores information that will be needed to insert into Bigtable. + + Args: + item_metadatas: metadata about items within the bkt. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt. + zoom: the zoom level of the tiles within the bkt. + """ + self.item_metadatas = item_metadatas + self.bkt_fname = bkt_fname + self.bkt_zoom = bkt_zoom + self.zoom = zoom + + def run(self, bkt_files_table: google.cloud.bigtable.table.Table) -> None: + """Insert the metadata into BigTable. + + Args: + bkt_files_table: the BigTable object + """ + # Row key in the table is just the bkt fname. + # Value is [4 byte bkt_zoom][4 byte zoom][indexes]. + # [indexes] is list of indexes encoded as [4 byte col][4 byte row][4 byte offset][4 byte length]. + buf = io.BytesIO() + buf.write(struct.pack(">II", self.bkt_zoom, self.zoom)) + for item_metadata in self.item_metadatas: + buf.write(item_metadata.pack()) + db_row = bkt_files_table.direct_row(self.bkt_fname) + db_row.set_cell(b"d", b"d", buf.getvalue()) + db_row.commit() + + +class BktWriter: + """Writer for bkt files. + + Call add to add one item at a time. Then call get_bytes and write the data to GCS. 
+ Finally call insert (or use get_inserter and then pass to main thread and call run + on the BktInserter object). + + Callers must write to GCS before Bigtable so that when clients read from Bigtable + they can expect the files to already be written. + + upload_bkts can help with inserting multiple BktWriters, e.g.: + + bkt_writers = {} + for col, row in item_tiles: + # bkt_factor is 2^(item zoom - bkt zoom), i.e. the difference in scale + # between the item grid and the bkt grid. + bkt_tile = (col//bkt_factor, row//bkt_factor) + if bkt_tile not in bkt_writers: + bkt_writers[bkt_tile] = bkt.BktWriter() + contents = ... + bkt_writers[bkt_tile].add(col, row, contents) + + bkt_uploads = [] + for bkt_tile, bkt_writer in bkt_writers.items(): + out_fname = '.../{}/{}/{}.bkt'.format(bkt_zoom, bkt_tile[0], out_tile[1]) + bkt_uploads.append((bkt_writer, out_fname, args.out_zoom, args.zoom)) + # p is a multiprocessing.Pool. + bkt.upload_bkts(bkt_files_table, p, bkt_uploads) + """ + + def __init__(self) -> None: + """Create a new BktWriter.""" + self.item_metadatas: list[BktItemMetadata] = [] + self.buf = io.BytesIO() + self.offset = 0 + + def add(self, col: int, row: int, bytes: bytes) -> None: + """Add a file to the bkt. + + Args: + col: the tile column. + row: the tile row. + bytes: the data at this tile. + """ + offset = self.offset + length = len(bytes) + self.item_metadatas.append(BktItemMetadata(col, row, offset, length)) + self.buf.write(bytes) + self.offset += length + + def get_bytes(self) -> bytes: + """Returns the bytes of the whole bkt file. + + This is what should be uploaded to GCS. + """ + return self.buf.getvalue() + + def get_inserter(self, bkt_fname: str, bkt_zoom: int, zoom: int) -> "BktInserter": + """Creates a BktInserter that manages inserting the byte offsets to BigTable. + + Args: + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + + Returns: + a corresponding BktInserter + """ + return BktInserter(self.item_metadatas, bkt_fname, bkt_zoom, zoom) + + def insert( + self, + bkt_files_table: google.cloud.bigtable.table.Table, + bkt_fname: str, + bkt_zoom: int, + zoom: int, + ) -> None: + """Insert the byte offsets for this bkt to BigTable. + + Args: + bkt_files_table: the BigTable table object. + bkt_fname: the filename where the bkt will be written. + bkt_zoom: the zoom level of the bkt file. + zoom: the zoom of the tiles within the bkt file. + """ + self.get_inserter(bkt_fname, bkt_zoom, zoom).run(bkt_files_table) + + +@functools.cache +def get_bucket() -> storage.Bucket: + """Get the GCS bucket where bkt files should be stored. + + This comes from the environment variables: + - BKT_PROJECT_ID: GCP project + - BKT_BUCKET_NAME: the GCS bucket within that project. + """ + storage_client = storage.Client(project=os.environ["BKT_PROJECT_ID"]) + bucket = storage_client.bucket(os.environ["BKT_BUCKET_NAME"]) + return bucket + + +@functools.cache +def get_bigtable() -> google.cloud.bigtable.table.Table: + """Get the BigTable table storing bkt metadata.""" + bigtable_client = bigtable.Client(project=os.environ["BKT_BIGTABLE_PROJECT_ID"]) + bigtable_instance = bigtable_client.instance(os.environ["BKT_BIGTABLE_INSTANCE_ID"]) + bkt_files_table = bigtable_instance.table("bkt_files") + return bkt_files_table + + +class DecodeMode(str, Enum): + """Mode indicating how items should be decoded when downloading in parallel. 
+ + This is used in functions like download_bkts so that the worker processes can + handle decoding rather than the caller needing to decode in the main thread. + """ + + # Decode it from image bytes to numpy array. + IMAGE = "image" + + # Yield the bytes directly. + RAW = "raw" + + +@dataclass +class BktDownloadRequest: + """A request to download an item in a bkt file to pass to download_bkts.""" + + # Name of bkt file for this job. + bkt_fname: str + + # Column and row on item (fine-grained) grid to read. + col: int + row: int + + # Arbitrary metadata for use by caller. + # It will be returned with the decoded item data. + metadata: Any = None + + +def _download_bkt( + bkt_fname: str, + item_metadatas: list[BktItemMetadata], + wanted: list[BktDownloadRequest], + decode_mode: DecodeMode, +) -> list[tuple[Any, npt.NDArray | bytes]]: + """Download tiles in a bkt file. + + Args: + bkt_fname: the bkt filename in the bucket to download from. + item_metadatas: the item metadatas for this bkt file. + wanted: list of BktDownloadRequest that specify the tiles to download. Note + that if a tile does not exist within the bkt, it will not be returned at + all. + decode_mode: how the items should be decoded. + + Returns: + a list of (metadata, contents) where contents is a numpy array with + DecodeMode.IMAGE or a byte string with DecodeMode.RAW. + """ + bucket = get_bucket() + output = [] + + # Helper to postprocess an output based on the specified return mode. + def add_output(metadata: Any, contents: npt.NDArray | bytes) -> None: + if decode_mode == DecodeMode.IMAGE: + buf = io.BytesIO(contents) + image = skimage.io.imread(buf) + output.append((metadata, image)) + + elif decode_mode == DecodeMode.RAW: + output.append((metadata, contents)) + + else: + raise ValueError(f"invalid decode mode {decode_mode}") + + # Convert item_metadatas to a map from (col, row) -> (offset, length). + idx_map = {(m.col, m.row): (m.offset, m.length) for m in item_metadatas} + + # Filter for just the requested tiles that actually exist. + # The caller should assume that tiles that are not returned simply don't have data. + wanted = [request for request in wanted if (request.col, request.row) in idx_map] + + if len(wanted) == 1: + # If there is just one requested item within this bkt, we can do a range read + # to read only that item. + request = wanted[0] + offset, length = idx_map[(request.col, request.row)] + blob = bucket.blob(bkt_fname) + contents = blob.download_as_bytes(start=offset, end=offset + length) + add_output(request.metadata, contents) + + elif len(wanted) > 1: + # Otherwise, we read the entire bkt file and then extract the segments + # corresponding to the requested items. + blob = bucket.blob(bkt_fname) + bkt_bytes = blob.download_as_bytes() + for request in wanted: + offset, length = idx_map[(request.col, request.row)] + contents = bkt_bytes[offset : offset + length] + add_output(request.metadata, contents) + + # We return a list of (metadata, contents) from this bkt file. + # In download_from_bkt, it will combine these tuples across all of the bkt files. + return output + + +def _bkt_retry_loop( + bkt_files_table: google.cloud.bigtable.table.Table, bkt_fname: str +) -> google.cloud.bigtable.row.PartialRowData: + """Retry loop to read the bkt_fname metadata from BigTable. + + This is used because sometimes there are transient errors reading. 
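    Up to BIGTABLE_RETRIES attempts are made, sleeping one second between them;
    if they all fail, a final attempt is made without catching the exception.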
+ """ + + def attempt_read() -> google.cloud.bigtable.row.PartialRowData: + return bkt_files_table.read_row( + bkt_fname, + filter_=google.cloud.bigtable.row_filters.CellsColumnLimitFilter(1), + ) + + for _ in range(BIGTABLE_RETRIES): + try: + return attempt_read() + except Exception as e: + logger.warning( + f"got error reading bkt_files_table for {bkt_fname} (trying again): {e}" + ) + time.sleep(1) + + # One last read, if it fails then we let the exception go. + return attempt_read() + + +def download_from_bkt( + bkt_files_table: google.cloud.bigtable.table.Table, + download_requests: list[BktDownloadRequest], + pool: multiprocessing.pool.Pool | None = None, + decode_mode: DecodeMode = DecodeMode.RAW, +) -> Generator[tuple[Any, npt.NDArray | bytes], None, None]: + """Download tile contents in parallel from several bkt files. + + Args: + bkt_files_table: the BigTable table containing byte offsets. + download_requests: list of BktDownloadRequest indicating the bkt filenames to + read from along with the item tiles to read. Download requests from the + same bkt_fname will be grouped together so we don't read the same bkt + file multiple times. + pool: the multiprocessing pool to use for parallelism, or None to read in + current process. + decode_mode: how the items should be decoded. + + Yields: + the (metadata, contents) tuples across all of the jobs. Only items that exist + in the bkt files will be returned; if non-existing items are requested, + they would be skipped and caller should assume they have no data. + """ + # Get tiles to read grouped by each distinct bkt_fname. + requests_by_bkt_fname: dict[str, list[BktDownloadRequest]] = {} + for request in download_requests: + if request.bkt_fname not in requests_by_bkt_fname: + requests_by_bkt_fname[request.bkt_fname] = [] + requests_by_bkt_fname[request.bkt_fname].append(request) + + # Read from BigTable to identify the offset and length of each requested + # (col, row) item. + # We use this to populate a list of jobs (arguments to pass to _download_bkt + # helper). + bkt_jobs: list[dict[str, Any]] = [] + for bkt_fname, requests in requests_by_bkt_fname.items(): + db_row = _bkt_retry_loop(bkt_files_table, bkt_fname) + + # Ignore requested files that don't exist. + if not db_row: + continue + + # Skip 8-byte header with bkt_zoom/zoom. + encoded_indexes = db_row.cell_value("d", b"d")[8:] + + item_metadatas = [] + for i in range(0, len(encoded_indexes), 16): + item_metadatas.append(BktItemMetadata.unpack(encoded_indexes[i : i + 16])) + bkt_jobs.append( + dict( + bkt_fname=bkt_fname, + item_metadatas=item_metadatas, + wanted=requests, + decode_mode=decode_mode, + ) + ) + + if pool is None: + for job in bkt_jobs: + for metadata, image in _download_bkt(**job): + yield (metadata, image) + else: + outputs = star_imap_unordered(pool, _download_bkt, bkt_jobs) + for output in outputs: + for metadata, image in output: + yield (metadata, image) + + +def upload_bkt(bkt_fname: str, contents: bytes) -> None: + """Upload a bkt file to GCS bucket. + + This is primarily intended to be used as a helper function for multiprocessing. + + Args: + bkt_fname: the bkt filename within the bucket. + contents: the data to upload. + """ + bucket = get_bucket() + blob = bucket.blob(bkt_fname) + blob.upload_from_string(contents) + + +def upload_bkts( + bkt_files_table: google.cloud.bigtable.table.Table, + pool: multiprocessing.pool.Pool, + jobs: list[tuple[BktWriter, str, int, int]], +) -> None: + """Upload several bkt files to GCS in parallel. 
+ + Args: + bkt_files_table: the BigTable table to store byte offsets. + pool: a multiprocessing pool for parallelism. + jobs: list of (bkt_writer, bkt_fname, bkt_zoom, zoom) tuples. bkt_writer is the + BktWriter where the bkt contents and metadata are stored. bkt_fname is the + path where the bkt should be written. bkt_zoom in the zoom level of the bkt + file. zoom is the zoom level of tiles within the bkt. + """ + # Upload. We upload first since reader will assume that anything existing in + # BigTable already exists on GCS. + upload_jobs: list[tuple[str, bytes]] = [] + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + upload_jobs.append((bkt_fname, bkt_writer.get_bytes())) + outputs = star_imap_unordered(pool, upload_bkt, upload_jobs) + for _ in outputs: + pass + # Now we insert the metadata. + for bkt_writer, bkt_fname, bkt_zoom, zoom in jobs: + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=bkt_zoom, + zoom=zoom, + ) + + +def make_bkt(src_dir: str, dst_path: str) -> None: + """Make a bkt file from the specified local source directory. + + The source directory must contain files of the form zoom/col/row.ext (the extension + is ignored). + + A single bkt file is created, so the zoom level of the bkt is always 0. + + Args: + src_dir: the local directory to turn into a single bkt file. + dst_path: the bkt filename in the bkt GCS bucket to write to. It must have a + {zoom} placeholder where the zoom goes. + """ + bucket = get_bucket() + bkt_files_table = get_bigtable() + + for zoom_str in os.listdir(src_dir): + zoom_dir = os.path.join(src_dir, zoom_str) + if not os.path.isdir(zoom_dir): + continue + zoom = int(zoom_str) + logger.debug( + "make_bkt(%s, %s): start collecting files at zoom level %d", + src_dir, + dst_path, + zoom, + ) + + # Read all files at this zoom level from local path into bkt (in memory). + bkt_writer = BktWriter() + num_files = 0 + for col_str in os.listdir(zoom_dir): + col_dir = os.path.join(zoom_dir, col_str) + col = int(col_str) + for fname in os.listdir(col_dir): + row = int(fname.split(".")[0]) + num_files += 1 + with open(os.path.join(col_dir, fname), "rb") as f: + contents = f.read() + bkt_writer.add(col, row, contents) + logger.debug( + "make_bkt(%s, %s): processed %d files at zoom %d", + src_dir, + dst_path, + num_files, + zoom, + ) + + # Now upload to GCS. 
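+        # For example (hypothetical path), dst_path="tiles/{zoom}/0/0.bkt" with
+        # zoom=7 would upload to "tiles/7/0/0.bkt"; the BigTable row for that
+        # filename then records bkt_zoom=0 and zoom=7 so readers can locate items.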
+ bkt_fname = dst_path.format(zoom=zoom) + logger.debug( + "make_bkt(%s, %s) uploading bkt for zoom level %d to %s", + src_dir, + dst_path, + zoom, + bkt_fname, + ) + blob = bucket.blob(bkt_fname) + blob.upload_from_string(bkt_writer.get_bytes()) + bkt_writer.insert( + bkt_files_table=bkt_files_table, + bkt_fname=bkt_fname, + bkt_zoom=0, + zoom=zoom, + ) diff --git a/rslp/satlas/data_sources.py b/rslp/satlas/data_sources.py new file mode 100644 index 00000000..0403b149 --- /dev/null +++ b/rslp/satlas/data_sources.py @@ -0,0 +1,416 @@ +"""Customized data sources for Satlas models.""" + +from datetime import timedelta +from typing import Any + +import shapely +from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig, SpaceMode +from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources.data_source import DataSource, Item +from rslearn.data_sources.gcp_public_data import Sentinel2 as GcpSentinel2 +from rslearn.data_sources.planetary_computer import Sentinel1 +from rslearn.data_sources.planetary_computer import Sentinel2 as AzureSentinel2 +from rslearn.data_sources.utils import match_candidate_items_to_window +from rslearn.dataset import Window +from rslearn.tile_stores import TileStore +from rslearn.utils.geometry import STGeometry +from upath import UPath + + +def _find_monthly_matches( + geometry: STGeometry, item_list: list[Item], period_days: int, max_matches: int +) -> list[list[Item]]: + """Match items to the geometry with one mosaic per period. + + We divide the time range of the geometry into shorter periods. Within each period, + we use the items corresponding to that period to create a mosaic. The returned item + groups include one group per period, starting from the most recent periods, up to + the provided max_matches. + + This is used e.g. when a model should process three mosaics, where each mosaic + should come from a different month. This gives more diversity of images, since + simply searching for the least cloudy images could result in selecting all of the + images from the same month. + + max_matches may be smaller than the total number of periods in the given time + range. In this case, we prefer to use mosaics of the most recent periods. However, + sometimes there may be no items in a period; in that case, the older periods are + used as a fallback. + + Args: + geometry: the window geometry to match items to. + item_list: the list of items. + period_days: the length of one period in days. + max_matches: the number of per-period mosaics to create. + + Returns: + the matched item groups, where each group contains items that yield a + per-period mosaic. + """ + # For each period, we create an STGeometry with modified time range matching that + # period, and use it with match_candidate_items_to_window to get a mosaic. + cur_groups: list[list[Item]] = [] + period_end = geometry.time_range[1] + while period_end > geometry.time_range[0] and len(cur_groups) < max_matches: + period_time_range = ( + period_end - timedelta(days=period_days), + period_end, + ) + period_end -= timedelta(period_days) + period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range) + + # We modify the QueryConfig here since caller should be asking for + # multiple mosaics, but we just want one mosaic per period. + period_groups = match_candidate_items_to_window( + period_geom, + item_list, + QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1), + ) + + # There should be zero on one groups depending on whether there were + # any items that matched. 
We keep the group if it is there. + if len(period_groups) == 0 or len(period_groups[0]) == 0: + # No matches for this period. + continue + cur_groups.append(period_groups[0]) + + # If there are not enough matching mosaics, then we eliminate all the + # matches since we aren't going to use this window then anyway. + if len(cur_groups) < max_matches: + return [] + + return cur_groups + + +class MonthlySentinel2(DataSource): + """Sentinel2 data source where each match is a mosaic from a different month. + + It looks at the geometry time range, identifies matching items within each 30-day + period, and then picks the most recent {num_matches} months that have at least one + item. + + It also imposes a scene-level cloud cover limit. + """ + + def __init__( + self, + sentinel2: GcpSentinel2, + max_cloud_cover: float | None = None, + period_days: int = 30, + ): + """Create a new MonthlySentinel2. + + Args: + sentinel2: the Sentinel2 data source to wrap. + max_cloud_cover: cloud cover limit for scenes. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel2 = sentinel2 + self.max_cloud_cover = max_cloud_cover + self.period_days = period_days + + @staticmethod + def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel2": + """Creates a new MonthlySentinel2 instance from a configuration dictionary.""" + sentinel2 = GcpSentinel2.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["max_cloud_cover", "period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlySentinel2(sentinel2, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel2.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + # This part is the same as in base Sentinel2 class. + wgs84_geometries = [ + geometry.to_projection(WGS84_PROJECTION) for geometry in geometries + ] + + if self.sentinel2.rtree_index: + candidates = self.sentinel2._get_candidate_items_index(wgs84_geometries) + else: + candidates = self.sentinel2._get_candidate_items_direct(wgs84_geometries) + + groups = [] + + for geometry, item_list in zip(wgs84_geometries, candidates): + item_list.sort(key=lambda item: item.cloud_cover) + + # Apply cloud cover limit. + if self.max_cloud_cover is not None: + item_list = [ + item + for item in item_list + if item.cloud_cover <= self.max_cloud_cover + ] + + cur_groups = _find_monthly_matches( + geometry=geometry, + item_list=item_list, + period_days=self.period_days, + max_matches=query_config.max_matches, + ) + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. 
+ + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel2.ingest(tile_store, items, geometries) + + +class MonthlyAzureSentinel2(DataSource): + """Similar to MonthlySentinel2 but for Sentinel-2 L2A on Azure.""" + + def __init__( + self, + sentinel2: AzureSentinel2, + period_days: int = 30, + ): + """Create a new MonthlyAzureSentinel2. + + Args: + sentinel2: the Sentinel2 data source to wrap. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel2 = sentinel2 + self.period_days = period_days + + @staticmethod + def from_config( + config: RasterLayerConfig, ds_path: UPath + ) -> "MonthlyAzureSentinel2": + """Creates a new MonthlyAzureSentinel2 instance from a configuration dictionary.""" + sentinel2 = AzureSentinel2.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlyAzureSentinel2(sentinel2, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel2.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + groups = [] + for geometry in geometries: + # This part is the same as in base Sentinel2 class. + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + result = self.sentinel2.client.search( + collections=[self.sentinel2.COLLECTION_NAME], + intersects=shapely.to_geojson(wgs84_geometry.shp), + datetime=wgs84_geometry.time_range, + query=self.sentinel2.query, + ) + stac_items = [item for item in result.item_collection()] + + if self.sentinel2.sort_by is not None: + stac_items.sort( + key=lambda stac_item: stac_item.properties[self.sentinel2.sort_by], + reverse=not self.sentinel2.sort_ascending, + ) + + candidate_items = [ + self.sentinel2._stac_item_to_item(stac_item) for stac_item in stac_items + ] + + # Now we use _find_monthly_matches. + cur_groups = _find_monthly_matches( + geometry, candidate_items, self.period_days, query_config.max_matches + ) + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. + + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel2.ingest(tile_store, items, geometries) + + def materialize( + self, + window: Window, + item_groups: list[list[Item]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + """Materialize data for the window. 
+ + Args: + window: the window to materialize + item_groups: the items from get_items + layer_name: the name of this layer + layer_cfg: the config of this layer + """ + self.sentinel2.materialize(window, item_groups, layer_name, layer_cfg) + + +class MonthlySentinel1(DataSource): + """Similar to MonthlySentinel2 but for Sentinel-1 on Azure.""" + + def __init__( + self, + sentinel1: Sentinel1, + period_days: int = 30, + ): + """Create a new MonthlySentinel1. + + Args: + sentinel1: the Sentinel1 data source to wrap. + period_days: create mosaics for intervals of this many days within the + geometry time range. + """ + self.sentinel1 = sentinel1 + self.period_days = period_days + + @staticmethod + def from_config(config: RasterLayerConfig, ds_path: UPath) -> "MonthlySentinel1": + """Creates a new MonthlySentinel1 instance from a configuration dictionary.""" + sentinel1 = Sentinel1.from_config(config, ds_path) + kwargs = {} + d = config.data_source.config_dict + for k in ["period_days"]: + if k not in d: + continue + kwargs[k] = d[k] + return MonthlySentinel1(sentinel1, **kwargs) + + def deserialize_item(self, serialized_item: Any) -> Item: + """Deserializes an item from JSON-decoded data.""" + return self.sentinel1.deserialize_item(serialized_item) + + def get_items( + self, geometries: list[STGeometry], query_config: QueryConfig + ) -> list[list[list[Item]]]: + """Get a list of items in the data source intersecting the given geometries. + + Args: + geometries: the spatiotemporal geometries + query_config: the query configuration + + Returns: + List of groups of items that should be retrieved for each geometry. + """ + # This only makes sense for mosaic space mode. + assert query_config.space_mode == SpaceMode.MOSAIC + + groups = [] + for geometry in geometries: + # This part is the same as in base Sentinel1 class. + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + result = self.sentinel1.client.search( + collections=[self.sentinel1.COLLECTION_NAME], + intersects=shapely.to_geojson(wgs84_geometry.shp), + datetime=wgs84_geometry.time_range, + query=self.sentinel1.query, + ) + stac_items = [item for item in result.item_collection()] + + if self.sentinel1.sort_by is not None: + stac_items.sort( + key=lambda stac_item: stac_item.properties[self.sentinel1.sort_by], + reverse=not self.sentinel1.sort_ascending, + ) + + candidate_items = [ + self.sentinel1._stac_item_to_item(stac_item) for stac_item in stac_items + ] + + # Now we use _find_monthly_matches. + cur_groups = _find_monthly_matches( + geometry, candidate_items, self.period_days, query_config.max_matches + ) + groups.append(cur_groups) + + return groups + + def ingest( + self, + tile_store: TileStore, + items: list[Item], + geometries: list[list[STGeometry]], + ) -> None: + """Ingest items into the given tile store. + + Args: + tile_store: the tile store to ingest into + items: the items to ingest + geometries: a list of geometries needed for each item + """ + self.sentinel1.ingest(tile_store, items, geometries) + + def materialize( + self, + window: Window, + item_groups: list[list[Item]], + layer_name: str, + layer_cfg: LayerConfig, + ) -> None: + """Materialize data for the window. 
+ + Args: + window: the window to materialize + item_groups: the items from get_items + layer_name: the name of this layer + layer_cfg: the config of this layer + """ + self.sentinel1.materialize(window, item_groups, layer_name, layer_cfg) diff --git a/rslp/satlas/postprocess.py b/rslp/satlas/postprocess.py new file mode 100644 index 00000000..f0f6696d --- /dev/null +++ b/rslp/satlas/postprocess.py @@ -0,0 +1,229 @@ +"""Postprocessing outputs from Satlas models.""" + +import json +import multiprocessing +import shutil +import subprocess # nosec +import tempfile +from typing import Any + +import shapely +import tqdm +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import Projection, STGeometry +from upath import UPath + +from rslp.log_utils import get_logger + +from .predict_pipeline import Application + +# Individual Satlas applications use different category names than the global ones that +# we want to serve. We should adjust this but for now this map helps to rename the +# categories. +APP_CATEGORY_MAPS = { + Application.MARINE_INFRA: { + "platform": "offshore_platform", + "turbine": "offshore_wind_turbine", + }, + Application.WIND_TURBINE: { + "turbine": "wind_turbine", + }, +} + +logger = get_logger(__name__) + + +def _get_fc(fname: UPath) -> tuple[UPath, dict[str, Any]]: + """Read the FeatureCollection from the specified file. + + This is intended to be used as a handler for multiprocessing. + + Args: + fname: the filename to read. + + Returns: + a tuple (fname, fc) of the filename and the decoded FeatureCollection JSON. + """ + with fname.open() as f: + return fname, json.load(f) + + +def merge_points( + application: Application, + label: str, + predict_path: str, + merged_path: str, + workers: int = 32, +) -> None: + """Merge Satlas point outputs. + + This merges the outputs across different prediction tasks for this timestamp. + + Args: + application: the application. + label: YYYY-MM representation of the time range used for this prediction run. + predict_path: output path of the prediction pipeline where GeoJSONs from all + the different tasks have been written. + merged_path: folder to write merged predictions. The filename will be + YYYY-MM.geojson. + workers: number of worker processes. + """ + predict_upath = UPath(predict_path) + merged_features = [] + merged_patches: dict[str, list[tuple[int, int]]] = {} + + fnames = [fname for fname in predict_upath.iterdir() if fname.name != "index"] + p = multiprocessing.Pool(workers) + outputs = p.imap_unordered(_get_fc, fnames) + + # Get category remapping in case one is specified for this application. + category_map = APP_CATEGORY_MAPS.get(application, {}) + + # Iterate over each of the files produced by a prediction task. + # We merge both the predicted points along with the valid patches (patches + # processed by the task that had available input images). + for fname, cur_fc in tqdm.tqdm(outputs, total=len(fnames)): + logger.debug("merging points from %s", fname) + + # The projection information may be missing if there are no valid patches. + # In that case we can skip the file since it has neither valid patches that we + # need to track nor any predicted points. + if "crs" not in cur_fc["properties"]: + # Just do some sanity checks, there should be no features and no valid + # patches. 
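+            # Such a degenerate FeatureCollection looks roughly like (illustrative):
+            #   {"type": "FeatureCollection", "features": [],
+            #    "properties": {"valid_patches": {"EPSG:32610": []}}}
+            # i.e. exactly one CRS key whose patch list is empty.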
+ assert len(cur_fc["features"]) == 0 + patch_list = list(cur_fc["properties"]["valid_patches"].values()) + assert len(patch_list) == 1 and len(patch_list[0]) == 0 + continue + + src_projection = Projection.deserialize(cur_fc["properties"]) + crs_str = str(src_projection.crs) + + # We ultimately want to store longitude/latitude but + # smooth_point_labels_viterbi.go needs to know the projection and x/y so we + # write them as properties of the feature, while converting the geometry + # coordinates to WGS84. + for feat in cur_fc["features"]: + col, row = feat["geometry"]["coordinates"] + feat["properties"]["col"] = int(col) + feat["properties"]["row"] = int(row) + feat["properties"]["projection"] = crs_str + + src_geom = STGeometry(src_projection, shapely.Point(col, row), None) + dst_geom = src_geom.to_projection(WGS84_PROJECTION) + feat["geometry"]["coordinates"] = [dst_geom.shp.x, dst_geom.shp.y] + + category = feat["properties"]["category"] + if category in category_map: + feat["properties"]["category"] = category_map[category] + + merged_features.append(feat) + + # Merge the valid patches too, these indicate which portions of the world + # actually had image content for the current timestep. + assert len(cur_fc["properties"]["valid_patches"]) == 1 + if crs_str not in merged_patches: + merged_patches[crs_str] = [] + merged_patches[crs_str].extend(cur_fc["properties"]["valid_patches"][crs_str]) + + p.close() + + merged_upath = UPath(merged_path) + merged_upath.mkdir(parents=True, exist_ok=True) + merged_fname = merged_upath / f"{label}.geojson" + with merged_fname.open("w") as f: + json.dump( + { + "type": "FeatureCollection", + "features": merged_features, + "properties": { + "valid_patches": merged_patches, + }, + }, + f, + ) + + +def smooth_points( + application: Application, + label: str, + merged_path: str, + smoothed_path: str, +) -> None: + """Smooth the Satlas point outputs. + + It applies Viterbi smoothing that takes into account merged outputs from previous + time ranges, and uploads the results. + + Args: + application: the application. + label: YYYY-MM representation of the time range used for this prediction run. + merged_path: folder to write merged predictions. The filename will be + YYYY-MM.geojson. + smoothed_path: folder to write smoothed predictions. The filename will be + YYYY-MM.geojson. + """ + merged_upath = UPath(merged_path) + # Download the merged prediction history (ending with the one we just wrote) and + # run smoothing. + smoothed_upath = UPath(smoothed_path) + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_upath = UPath(tmp_dir) + tmp_merged_dir = tmp_upath / "merged" + tmp_smoothed_dir = tmp_upath / "smoothed" + tmp_hist_fname = tmp_upath / "history.geojson" + + tmp_merged_dir.mkdir() + tmp_smoothed_dir.mkdir() + + labels: list[str] = [] + for merged_fname in merged_upath.iterdir(): + # Get the label like 2024-01 from 2024-01.geojson. + if not merged_fname.name.endswith(".geojson"): + continue + label = merged_fname.name.split(".")[0] + + local_fname = tmp_merged_dir / merged_fname.name + with merged_fname.open("rb") as src: + with local_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) + labels.append(label) + + # Sort by YYYY-MM, since the smoothing function expects us to provide all of + # the labels in temporal order. + labels.sort() + + # Smoothing is handled by a Go script. 
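+        # The call below is roughly equivalent to running (illustrative labels):
+        #   smooth_point_labels_viterbi --labels 2024-01,2024-02 \
+        #     --fname merged/LABEL.geojson --out smoothed/LABEL.geojson \
+        #     --hist history.geojson
+        # where the program substitutes each label for the LABEL placeholder.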
+        subprocess.check_call(
+            [
+                "rslp/satlas/scripts/smooth_point_labels_viterbi",
+                "--labels",
+                ",".join(labels),
+                "--fname",
+                (tmp_merged_dir / "LABEL.geojson").path,
+                "--out",
+                (tmp_smoothed_dir / "LABEL.geojson").path,
+                "--hist",
+                tmp_hist_fname.path,
+            ],
+        )  # nosec
+
+        # Now we can upload the smoothed per-timestep files.
+        smoothed_upath.mkdir(parents=True, exist_ok=True)
+        for label in labels:
+            src_path = tmp_smoothed_dir / f"{label}.geojson"
+            dst_path = smoothed_upath / f"{label}.geojson"
+            with src_path.open("rb") as src:
+                with dst_path.open("wb") as dst:
+                    shutil.copyfileobj(src, dst)
+
+        # The smoothing also produces a history GeoJSON containing all of the points
+        # annotated with start/end properties indicating the first and last timesteps
+        # when the point was detected. (In this case, points detected over time are
+        # merged into a single GeoJSON feature.) So we upload that too.
+        # This history file is the one used to create vector tiles for the web
+        # application.
+        dst_path = smoothed_upath / "history.geojson"
+        with tmp_hist_fname.open("rb") as src:
+            with dst_path.open("wb") as dst:
+                shutil.copyfileobj(src, dst)
diff --git a/rslp/satlas/predict_pipeline.py b/rslp/satlas/predict_pipeline.py
new file mode 100644
index 00000000..6d874736
--- /dev/null
+++ b/rslp/satlas/predict_pipeline.py
@@ -0,0 +1,386 @@
+"""Prediction pipeline for Satlas models."""
+
+import json
+import os
+import shutil
+import tempfile
+from datetime import datetime
+from enum import Enum
+from typing import Any
+
+from rslearn.const import WGS84_PROJECTION
+from rslearn.dataset import Window
+from rslearn.utils.geometry import PixelBounds, Projection
+from upath import UPath
+
+from rslp.log_utils import get_logger
+from rslp.utils.rslearn import (
+    ApplyWindowsArgs,
+    IngestArgs,
+    MaterializeArgs,
+    MaterializePipelineArgs,
+    PrepareArgs,
+    materialize_dataset,
+    run_model_predict,
+)
+
+DATASET_CONFIG_FNAME = "data/satlas/{application}/config.json"
+MODEL_CONFIG_FNAME = "data/satlas/{application}/config.yaml"
+SENTINEL2_LAYER = "sentinel2"
+PATCH_SIZE = 2048
+
+# Layers not to use when determining which patches are valid.
+VALIDITY_EXCLUDE_LAYERS = ["mask", "output", "label"]
+
+PREDICTION_GROUP = "predict"
+SATLAS_MATERIALIZE_PIPELINE_ARGS = MaterializePipelineArgs(
+    disabled_layers=[],
+    # Use initial job for prepare since it involves locally caching the tile index
+    # and other steps that should only be performed once.
+    prepare_args=PrepareArgs(
+        apply_windows_args=ApplyWindowsArgs(
+            group=PREDICTION_GROUP, workers=32, use_initial_job=True
+        )
+    ),
+    ingest_args=IngestArgs(
+        ignore_errors=False,
+        apply_windows_args=ApplyWindowsArgs(
+            group=PREDICTION_GROUP, workers=32, use_initial_job=False
+        ),
+    ),
+    materialize_args=MaterializeArgs(
+        ignore_errors=False,
+        apply_windows_args=ApplyWindowsArgs(
+            group=PREDICTION_GROUP, workers=32, use_initial_job=False
+        ),
+    ),
+)
+
+logger = get_logger(__name__)
+
+
+class Application(Enum):
+    """Specifies the various Satlas applications."""
+
+    SOLAR_FARM = "solar_farm"
+    WIND_TURBINE = "wind_turbine"
+    MARINE_INFRA = "marine_infra"
+    TREE_COVER = "tree_cover"
+
+
+APP_IS_RASTER = {
+    Application.SOLAR_FARM: True,
+    Application.WIND_TURBINE: False,
+    Application.MARINE_INFRA: False,
+    Application.TREE_COVER: True,
+}
+
+
+def get_output_fname(
+    application: Application, out_path: str, projection: Projection, bounds: PixelBounds
+) -> UPath:
+    """Get output filename to use for this application and task.
+ + Args: + application: the application. + out_path: the output path. + projection: the projection of this task. + bounds: the bounds of this task. + + Returns: + the output filename. + """ + if APP_IS_RASTER[application]: + out_fname = ( + UPath(out_path) / f"{str(projection.crs)}_{bounds[0]}_{bounds[1]}.tif" + ) + else: + out_fname = ( + UPath(out_path) / f"{str(projection.crs)}_{bounds[0]}_{bounds[1]}.geojson" + ) + return out_fname + + +def merge_and_upload_points( + projection: Projection, + windows: list[Window], + out_fname: UPath, +) -> None: + """Helper function to merge and upload point data after prediction is complete. + + Args: + projection: the UTM projection that we are working in. + windows: the windows that were used for prediction. + out_fname: the filename to write the merged result. + """ + # Merge the features across the windows. + # Here we also add valid patches attribute indicating which windows (patches) + # were non-zero. This is used to distinguish a point not being detected because + # it wasn't there vs not being detected just because there was no image + # available there. + fc = None + valid_patches = [] + for window in windows: + window_output_fname = window.path / "layers" / "output" / "data.geojson" + + if not window_output_fname.exists(): + continue + + with window_output_fname.open() as f: + cur_fc = json.load(f) + + if fc is None: + fc = cur_fc + else: + fc["features"].extend(cur_fc["features"]) + + valid_patches.append( + (window.bounds[0] // PATCH_SIZE, window.bounds[1] // PATCH_SIZE) + ) + + if fc is None: + # So there was no image here. + # We still want to write an empty GeoJSON so the job is marked completed. + fc = { + "type": "FeatureCollection", + "features": [], + } + + if "properties" not in fc: + fc["properties"] = {} + fc["properties"]["valid_patches"] = { + str(projection.crs): list(valid_patches), + } + + # The object detector predicts bounding boxes but we want to make all features + # just points. + for feat in fc["features"]: + assert feat["geometry"]["type"] == "Polygon" + coords = feat["geometry"]["coordinates"][0] + xs = [coord[0] for coord in coords] + ys = [coord[1] for coord in coords] + feat["geometry"] = { + "type": "Point", + "coordinates": [ + (min(xs) + max(xs)) / 2, + (min(ys) + max(ys)) / 2, + ], + } + + with out_fname.open("w") as f: + json.dump(fc, f) + + +def predict_pipeline( + application: Application, + projection_json: str, + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + out_path: str, + scratch_path: str, + use_rtree_index: bool = True, +) -> None: + """Compute outputs of a Satlas model on this tile. + + The tile is one part of a UTM zone. + + Args: + application: the application for which to compute outputs. + projection_json: JSON-encoded projection, normally a UTM zone with 10 m/pixel + resolution. + bounds: pixel coordinates within the projection on which to compute outputs. + time_range: time range to apply model on. + out_path: directory to write the outputs. It will either be a GeoTIFF or + GeoJSON, named based on the bounds. + scratch_path: where to store the dataset. + use_rtree_index: whether to prepare using rtree index. This is recommended when + applying the model globally but can be disabled for small regions to avoid + the time to create the index. + """ + dataset_config_fname = DATASET_CONFIG_FNAME.format(application=application.value) + model_config_fname = MODEL_CONFIG_FNAME.format(application=application.value) + + # Check if the output was already computed. 
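+    # The output name encodes the CRS and the first two bounds coordinates, e.g.
+    # (hypothetically) "EPSG:32610_327680_4718592.geojson" for a vector application.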
+ projection = Projection.deserialize(json.loads(projection_json)) + out_fname = get_output_fname(application, out_path, projection, bounds) + if out_fname.exists(): + logger.info(f"output file {out_fname} already exists") + return + + # Initialize an rslearn dataset. + ds_path = UPath(scratch_path) + ds_path.mkdir(parents=True) + with open(dataset_config_fname) as f: + ds_cfg = json.load(f) + + # Set the time range to use for the rtree. + # And also make sure the rtree will be cached based on the out_path. + index_cache_dir = ds_path / "index_cache_dir" + index_cache_dir.mkdir() + image_layer_names = [] + for layer_name, layer_cfg in ds_cfg["layers"].items(): + if "data_source" not in layer_cfg: + continue + layer_source_cfg = layer_cfg["data_source"] + if not layer_source_cfg["name"].endswith("MonthlySentinel2"): + continue + layer_source_cfg["index_cache_dir"] = str(index_cache_dir) + layer_source_cfg["rtree_cache_dir"] = str(UPath(out_path) / "index") + layer_source_cfg["use_rtree_index"] = use_rtree_index + layer_source_cfg["rtree_time_range"] = [ + time_range[0].isoformat(), + time_range[1].isoformat(), + ] + image_layer_names.append(layer_name) + + with (ds_path / "config.json").open("w") as f: + json.dump(ds_cfg, f) + + # Create windows corresponding to the specified projection and bounds. + # Each window is PATCH_SIZE x PATCH_SIZE, we create multiple of smaller patch size + # than the bounds instead of one big window for better parallelism and memory + # usage. + # It also helps with creating the mosaic -- depending on the dataset configuration, + # if there are portions of large window that are not fully covered by scenes, then + # only one mosaic layer would be created. (TODO: actually that seems like an issue + # with the match_candidate_items_to_window logic. Maybe we should build all the + # mosaics simultaneously instead of discarding scenes that don't match with the + # current mosaic.) + # Note that bounds must be multiple of patch size. + for value in bounds: + assert value % PATCH_SIZE == 0 + tile_to_window = {} + for tile_col in range(bounds[0] // PATCH_SIZE, bounds[2] // PATCH_SIZE): + for tile_row in range(bounds[1] // PATCH_SIZE, bounds[3] // PATCH_SIZE): + window_name = f"{tile_col}_{tile_row}" + window_bounds = ( + tile_col * PATCH_SIZE, + tile_row * PATCH_SIZE, + (tile_col + 1) * PATCH_SIZE, + (tile_row + 1) * PATCH_SIZE, + ) + window_path = ds_path / "windows" / PREDICTION_GROUP / window_name + window = Window( + path=window_path, + group=PREDICTION_GROUP, + name=window_name, + projection=projection, + bounds=window_bounds, + time_range=time_range, + ) + + # Skip if the window is too close to 0 longitude. + # Or if it crosses it. + epsilon = 1e-4 + wgs84_geom = window.get_geometry().to_projection(WGS84_PROJECTION) + wgs84_bounds = wgs84_geom.shp.bounds + if wgs84_bounds[0] <= -180 + epsilon or wgs84_bounds[2] >= 180 - epsilon: + logger.debug( + "skipping window at column %d row %d because it is out of bounds (wgs84_bounds=%s)", + tile_col, + tile_row, + wgs84_bounds, + ) + continue + if wgs84_bounds[0] < -90 and wgs84_bounds[2] > 90: + logger.debug( + "skipping window at column %d row %d because it seems to cross 0 longitude (wgs84_bounds=%s)", + tile_col, + tile_row, + wgs84_bounds, + ) + continue + + window.save() + tile_to_window[(tile_col, tile_row)] = window + + # Populate the windows. 
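+    # This runs the prepare/ingest/materialize steps configured in
+    # SATLAS_MATERIALIZE_PIPELINE_ARGS over the windows in PREDICTION_GROUP, so
+    # afterwards each window should have its image layers materialized on disk
+    # wherever matching scenes were available.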
+ logger.info("materialize dataset") + materialize_dataset( + ds_path, materialize_pipeline_args=SATLAS_MATERIALIZE_PIPELINE_ARGS + ) + + # Run the model, only if at least one window has some data. + completed_fnames = ds_path.glob( + f"windows/{PREDICTION_GROUP}/*/layers/{image_layer_names[0]}/completed" + ) + if len(list(completed_fnames)) == 0: + logger.info("skipping prediction since no windows seem to have data") + else: + run_model_predict(model_config_fname, ds_path) + + # Merge and upload the outputs. + if APP_IS_RASTER[application]: + src_fname = window_path / "layers" / "output" / "output" / "geotiff.tif" + + with src_fname.open("rb") as src: + with out_fname.open("wb") as dst: + shutil.copyfileobj(src, dst) + # TODO: implement valid patches and such. + raise NotImplementedError + + else: + merge_and_upload_points(projection, list(tile_to_window.values()), out_fname) + + +class PredictTaskArgs: + """Represents one prediction task among a set that shares application and paths.""" + + def __init__( + self, + projection_json: dict[str, Any], + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + ): + """Create a new PredictTaskArgs. + + Args: + projection_json: serialized projection. + bounds: the bounds of this task. + time_range: the time range of this task. + """ + self.projection_json = projection_json + self.bounds = bounds + self.time_range = time_range + + def serialize(self) -> dict[str, Any]: + """Serialize the task to a dictionary. + + Returns: + JSON-encodable dictionary. + """ + return dict( + projection_json=self.projection_json, + bounds=json.dumps(self.bounds), + time_range=json.dumps( + (self.time_range[0].isoformat(), self.time_range[1].isoformat()) + ), + ) + + +def predict_multi( + application: Application, + out_path: str, + scratch_path: str, + tasks: list[PredictTaskArgs], +) -> None: + """Run multiple prediction tasks. + + Args: + application: the application. + out_path: directory to write outputs. + scratch_path: local directory to use for scratch space. + tasks: list of tasks to execute. + """ + os.makedirs(scratch_path, exist_ok=True) + for task in tasks: + with tempfile.TemporaryDirectory(dir=scratch_path) as tmp_dir: + logger.info(f"running task {task} in temporary directory {tmp_dir}") + predict_pipeline( + application=application, + projection_json=json.dumps(task.projection_json), + bounds=task.bounds, + time_range=task.time_range, + out_path=out_path, + scratch_path=os.path.join(tmp_dir, "scratch"), + ) diff --git a/rslp/satlas/publish.py b/rslp/satlas/publish.py new file mode 100644 index 00000000..12049f77 --- /dev/null +++ b/rslp/satlas/publish.py @@ -0,0 +1,195 @@ +"""Publish Satlas outputs.""" + +import json +import os +import shutil +import subprocess # nosec +import tempfile +import zipfile +from typing import Any + +import boto3 +import boto3.s3 +from upath import UPath + +from rslp.log_utils import get_logger +from rslp.satlas.bkt import make_bkt + +from .predict_pipeline import Application + +logger = get_logger(__name__) + +# Number of timesteps to re-publish. +# Smoothing for points changes all of the outputs but we only upload outputs for this +# many of the most recent timesteps. +NUM_RECOMPUTE = 6 + +# Name on Cloudflare R2 for each application. 
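+# Only marine infrastructure is mapped so far; publishing another application
+# requires adding an entry here (and in APP_TIPPECANOE_LAYERS below) first.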
+APP_NAME_ON_R2 = { + Application.MARINE_INFRA: "marine", +} + +APP_TIPPECANOE_LAYERS = { + Application.MARINE_INFRA: "marine", +} + +SHP_EXTENSIONS = [ + ".shp", + ".dbf", + ".prj", + ".shx", +] + +BKT_TILE_PATH = "output_mosaic/" + + +def get_cloudflare_r2_bucket() -> Any: + """Returns the Cloudflare R2 bucket where outputs are published.""" + s3 = boto3.resource( + "s3", + endpoint_url=os.environ["SATLAS_R2_ENDPOINT"], + aws_access_key_id=os.environ["SATLAS_R2_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["SATLAS_R2_SECRET_ACCESS_KEY"], + ) + bucket = s3.Bucket(os.environ["SATLAS_R2_BUCKET_NAME"]) + return bucket + + +def make_shapefile_zip(fname: str) -> str: + """Create zip file of the shapefile and its supporting files. + + If filename is "x" (for x.shp and supporting files) then output is "x.shp.zip". + + Args: + fname: fname without .shp extension + + Returns: + the local filename of the resulting zip file. + """ + zip_fname = fname + ".shp.zip" + basename = os.path.basename(fname) + with zipfile.ZipFile(zip_fname, "w") as z: + for ext in SHP_EXTENSIONS: + z.write(fname + ext, arcname=basename + ext) + return zip_fname + + +def update_index(bucket: Any, prefix: str) -> None: + """Update index file on Cloudflare R2. + + The index file just has list of filenames, last modified time, and md5. + + There is one index for each application folder. + + Args: + bucket: the Cloudflare R2 bucket. + prefix: the folder's prefix in the bucket. + """ + index_lines = [] + for obj in bucket.objects.filter(Prefix=prefix): + if obj.key.endswith("/index.txt"): + continue + line = "{},{},{}".format( + obj.key, obj.last_modified, obj.e_tag.split("-")[0].replace('"', "") + ) + index_lines.append(line) + index_lines.append("") + index_data = "\n".join(index_lines) + bucket.put_object( + Body=index_data.encode(), + Key=prefix + "index.txt", + ) + + +def publish_points( + application: Application, + smoothed_path: str, + version: str, + workers: int = 32, +) -> None: + """Publish Satlas point outputs. + + The points are added to two locations: GeoJSONs are added to Cloudflare R2, while + tippecanoe is used to generate vector tiles that are uploaded to GCS for use by the + satlas.allen.ai website. + + Args: + application: the application. + smoothed_path: folder containing smoothed predictions (including + history.geojson file). + version: current model version for use to distinguish different outputs on GCS. + workers: number of worker processes. + """ + smoothed_upath = UPath(smoothed_path) + + # First upload files to R2. + bucket = get_cloudflare_r2_bucket() + with tempfile.TemporaryDirectory() as tmp_dir: + # Upload history. + logger.info("upload history") + local_hist_fname = os.path.join(tmp_dir, "history.geojson") + with (smoothed_upath / "history.geojson").open("rb") as src: + with open(local_hist_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + app_name_on_r2 = APP_NAME_ON_R2[application] + bucket.upload_file(local_hist_fname, f"outputs/{app_name_on_r2}/marine.geojson") + + # Upload the latest outputs too. 
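+        # For example (illustrative), with monthly files 2023-08.geojson through
+        # 2024-01.geojson and NUM_RECOMPUTE=6, all six are re-uploaded and
+        # 2024-01.geojson is additionally copied to latest.geojson.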
+ available_fnames: list[UPath] = [] + for fname in smoothed_upath.iterdir(): + if fname.name == "history.geojson": + continue + available_fnames.append(fname) + available_fnames.sort(key=lambda fname: fname.name) + for fname in available_fnames[-NUM_RECOMPUTE:]: + logger.info("upload %s", str(fname)) + local_geojson_fname = os.path.join(tmp_dir, "data.geojson") + + with fname.open("rb") as src: + with open(local_geojson_fname, "wb") as dst: + shutil.copyfileobj(src, dst) + + fname_prefix = fname.name.split(".")[0] + + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/{fname_prefix}.geojson", + ) + if fname == available_fnames[-1]: + bucket.upload_file( + local_geojson_fname, + f"outputs/{app_name_on_r2}/latest.geojson", + ) + + update_index(bucket, f"outputs/{app_name_on_r2}/") + + # Generate the tippecanoe tiles. + # We set tippecanoe layer via property of each feature. + with tempfile.TemporaryDirectory() as tmp_dir: + tippecanoe_layer = APP_TIPPECANOE_LAYERS[application] + with (smoothed_upath / "history.geojson").open("rb") as f: + fc = json.load(f) + for feat in fc["features"]: + feat["tippecanoe"] = {"layer": tippecanoe_layer} + local_geojson_fname = os.path.join(tmp_dir, "history.geojson") + with open(local_geojson_fname, "w") as f: + json.dump(fc, f) + + local_tile_dir = os.path.join(tmp_dir, "tiles") + logger.info("run tippecanoe on history in local tmp dir %s", local_tile_dir) + subprocess.check_call( + [ + "tippecanoe", + "-z13", + "-r1", + "--cluster-densest-as-needed", + "--no-tile-compression", + "-e", + local_tile_dir, + local_geojson_fname, + ] + ) # nosec + + tile_dst_path = f"{BKT_TILE_PATH}{version}/history/{{zoom}}/0/0.bkt" + logger.info("make bkt at %s", tile_dst_path) + make_bkt(src_dir=local_tile_dir, dst_path=tile_dst_path) diff --git a/rslp/satlas/scripts/go.mod b/rslp/satlas/scripts/go.mod new file mode 100644 index 00000000..dc30d1fb --- /dev/null +++ b/rslp/satlas/scripts/go.mod @@ -0,0 +1,12 @@ +module main + +go 1.22.5 + +require github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112 + +require ( + github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect + github.com/dhconnelly/rtreego v1.1.0 // indirect + github.com/qedus/osmpbf v1.2.0 // indirect + google.golang.org/protobuf v1.26.0 // indirect +) diff --git a/rslp/satlas/scripts/go.sum b/rslp/satlas/scripts/go.sum new file mode 100644 index 00000000..69204293 --- /dev/null +++ b/rslp/satlas/scripts/go.sum @@ -0,0 +1,43 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= +github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/dhconnelly/rtreego v1.1.0 h1:ejMaqN03N1s6Bdg6peGkNgBnYYSBHzcK8yhSPCB+rHE= +github.com/dhconnelly/rtreego v1.1.0/go.mod h1:SDozu0Fjy17XH1svEXJgdYq8Tah6Zjfa/4Q33Z80+KM= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/kisielk/gotool v1.0.0/go.mod 
h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112 h1:/msesy1s0b8A/xU7MSCzcDvDcrVaqvE/7Bckq1y9kAM= +github.com/mitroadmaps/gomapinfer v0.0.0-20210917033103-4e3dcc98a112/go.mod h1:60dnxKwUjhRjPfhvtjJQccZOo741GN+5WymEEc+Aa0c= +github.com/qedus/osmpbf v1.2.0 h1:yRm5ECkiUsN9sA+UN9yNnm64AVW2OYhOCb+gBa1FYCU= +github.com/qedus/osmpbf v1.2.0/go.mod h1:Cfv6JyqTZ72BjoW9FyFBQOC2DYJbL78yw+DLhBvSH+M= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= diff --git a/rslp/satlas/scripts/smooth_point_labels_viterbi.go b/rslp/satlas/scripts/smooth_point_labels_viterbi.go new file mode 100644 index 00000000..94cdd975 --- /dev/null +++ b/rslp/satlas/scripts/smooth_point_labels_viterbi.go @@ -0,0 +1,497 @@ 
+package main + +import ( + "encoding/json" + "flag" + "fmt" + "log" + "math" + "os" + "strconv" + "strings" + + "github.com/mitroadmaps/gomapinfer/common" +) + +const FUTURE_LABEL = "2030-01" +const TILE_SIZE = 2048 + +// Don't consider groups with fewer than this many valid timesteps. +// Note that the point doesn't need to be detected in all the timesteps, this is just +// timesteps where we have image coverage. +const MIN_VALID_TIMESTEPS = 8 + +type Tile struct { + Projection string + Column int + Row int +} + +type Point struct { + Type string `json:"type"` + Geometry struct { + Type string `json:"type"` + Coordinates [2]float64 `json:"coordinates"` + } `json:"geometry"` + label string + Properties struct { + Category *string `json:"category"` + Score *float64 `json:"score"` + Projection *string `json:"projection,omitempty"` + Column *int `json:"col,omitempty"` + Row *int `json:"row,omitempty"` + Start string `json:"start,omitempty"` + End string `json:"end,omitempty"` + } `json:"properties"` +} + +type PointData struct { + Type string `json:"type"` + Features []Point `json:"features"` + Properties struct { + ValidPatches map[string][][2]int `json:"valid_patches,omitempty"` + } `json:"properties"` +} + +type Group []Point + +func (g Group) Center() [2]int { + var sum [2]int + for _, p := range g { + sum[0] += *p.Properties.Column + sum[1] += *p.Properties.Row + } + return [2]int{ + sum[0] / len(g), + sum[1] / len(g), + } +} + +func decrementLabel(label string) string { + parts := strings.Split(label, "-") + year, _ := strconv.Atoi(parts[0]) + month, _ := strconv.Atoi(parts[1]) + month -= 1 + if month == 0 { + year -= 1 + month = 12 + } + return fmt.Sprintf("%04d-%02d", year, month) +} + +// Returns the end label for idx. +// If idx >= len(labels), we return a label far in the future. +// Otherwise, we just return labels[idx]. +func getEndLabel(labels []string, idx int) string { + if idx >= len(labels) { + return FUTURE_LABEL + } + return labels[idx] +} + +// Grid size which must be larger than the maximum expected distance threshold. +// This is currently in zoom 13 pixel coordinates so each unit is about 10 m. +const GridSize float64 = 256 + +// Factor to divide by to convert meters to point units. +const MetersPerPixel = 10 + +func main() { + labels := flag.String("labels", "", "Comma-separated list of labels") + pointFname := flag.String("fname", "", "Point filename with LABEL placeholder like in/LABEL.geojson") + outFname := flag.String("out", "", "Output filename with LABEL placeholder like out/LABEL.geojson") + histFname := flag.String("hist", "", "Merged history output filename") + distanceThreshold := flag.Float64("max_dist", 200, "Matching distance threshold in meters") + nmsDistance := flag.Float64("nms_dist", 200.0/111111, "NMS distance in degrees") + numThreads := flag.Int("threads", 32, "Number of threads") + flag.Parse() + + labelList := strings.Split(*labels, ",") + + // Read points beginning with the most recent set + // (which is likely the one that covers the most points). + var groups []Group + // Keep track of map from tiles to labels in which the tile is valid. 
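+	// For example (illustrative), a tile with imagery in 2024-01 and 2024-03 but
+	// not 2024-02 maps to ["2024-01", "2024-03"]; labels missing here are treated
+	// as "no observation" rather than "object absent" during smoothing.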
+ tileLabelValidity := make(map[Tile][]string) + for labelIdx := len(labelList) - 1; labelIdx >= 0; labelIdx-- { + label := labelList[labelIdx] + + fname := strings.ReplaceAll(*pointFname, "LABEL", label) + if _, err := os.Stat(fname); os.IsNotExist(err) { + continue + } + bytes, err := os.ReadFile(fname) + if err != nil { + panic(err) + } + + var data PointData + if err := json.Unmarshal(bytes, &data); err != nil { + panic(err) + } + + // Build grid index from the current features. + curPoints := data.Features + for idx := range curPoints { + curPoints[idx].label = label + } + gridIndexes := make(map[string]*common.GridIndex) + for idx, point := range curPoints { + projection := *point.Properties.Projection + col := float64(*point.Properties.Column) + row := float64(*point.Properties.Row) + if gridIndexes[projection] == nil { + gridIndexes[projection] = common.NewGridIndex(GridSize) + } + gridIndexes[projection].Insert(idx, common.Rectangle{ + Min: common.Point{col, row}, + Max: common.Point{col, row}, + }) + } + + log.Printf("matching %d groups with %d features at %v", len(groups), len(curPoints), label) + + // Match existing groups to the new points. + matchedIndices := make(map[int]bool) + for groupIdx, group := range groups { + projection := *group[0].Properties.Projection + center := group.Center() + + // Lookup candidate new points that could match this group using the grid index. + var indices []int + if gridIndexes[projection] != nil { + indices = gridIndexes[projection].Search(common.Rectangle{ + Min: common.Point{float64(center[0]) - GridSize, float64(center[1]) - GridSize}, + Max: common.Point{float64(center[0]) + GridSize, float64(center[1]) + GridSize}, + }) + } + + var closestIdx int = -1 + var closestDistance float64 + for _, idx := range indices { + if matchedIndices[idx] { + continue + } + + // Double check distance threshold since the index may still return + // points that are slightly outside the threshold. + // We used to check category too, but now we use the category of the + // last prediction, and just apply a distance penalty for mismatched + // category, since we noticed that sometimes there are partially + // constructed wind turbines detected as platforms but then later + // detected as turbines once construction is done, and we don't want + // that to mess up the Viterbi smoothing. Put another way, marine + // infrastructure should show up in our map even if we're not exactly + // sure about the category. + dx := center[0] - *curPoints[idx].Properties.Column + dy := center[1] - *curPoints[idx].Properties.Row + distance := math.Sqrt(float64(dx*dx + dy*dy)) + + if distance > *distanceThreshold/MetersPerPixel { + continue + } + + if *group[0].Properties.Category != *curPoints[idx].Properties.Category { + distance += *distanceThreshold / MetersPerPixel + } + + if closestIdx == -1 || distance < closestDistance { + closestIdx = idx + closestDistance = distance + } + } + + if closestIdx == -1 { + continue + } + + matchedIndices[closestIdx] = true + groups[groupIdx] = append(groups[groupIdx], curPoints[closestIdx]) + } + + // Add unmatched points in the current time as new groups. + for idx, point := range curPoints { + if matchedIndices[idx] { + continue + } + groups = append(groups, Group{point}) + } + + // Also update valid tiles/labels. 
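+		// ValidPatches is keyed by projection, e.g. (illustrative)
+		// {"EPSG:32610": [[5, 7], [5, 8]]}, where each pair is a patch
+		// column/row on the TILE_SIZE grid written by merge_points.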
+ for projection, patches := range data.Properties.ValidPatches { + for _, patch := range patches { + tile := Tile{ + Projection: projection, + Column: patch[0], + Row: patch[1], + } + tileLabelValidity[tile] = append(tileLabelValidity[tile], label) + } + } + } + + // Apply non-maximal suppression over groups. + // We prefer longer groups, or if they are the same length, the group with higher + // last score. + log.Println("applying non-maximal suppression") + nmsIndex := common.NewGridIndex(*nmsDistance * 5) + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + nmsIndex.Insert(groupIdx, common.Point{coordinates[0], coordinates[1]}.Rectangle()) + } + var newGroups []Group + for groupIdx, group := range groups { + last := group[len(group)-1] + coordinates := last.Geometry.Coordinates + results := nmsIndex.Search(common.Point{coordinates[0], coordinates[1]}.RectangleTol(*nmsDistance)) + needsRemoval := false + for _, otherIdx := range results { + if otherIdx == groupIdx { + continue + } + other := groups[otherIdx] + otherLast := other[len(other)-1] + otherCoordinates := otherLast.Geometry.Coordinates + dx := coordinates[0] - otherCoordinates[0] + dy := coordinates[1] - otherCoordinates[1] + distance := math.Sqrt(float64(dx*dx + dy*dy)) + if distance >= *nmsDistance { + continue + } + + // It is within distance threshold, so see if group is worse than other. + if len(group) < len(other) { + needsRemoval = true + } else if len(group) == len(other) && *last.Properties.Score < *otherLast.Properties.Score { + needsRemoval = true + } + } + + if !needsRemoval { + newGroups = append(newGroups, group) + } + } + log.Printf("NMS filtered from %d to %d groups", len(groups), len(newGroups)) + groups = newGroups + + // Apply Viterbi algorithm in each group. + initialProbs := []float64{0.5, 0.5} + transitionProbs := [][]float64{ + {0.95, 0.05}, + {0.01, 0.99}, + } + emissionProbs := [][]float64{ + {0.8, 0.2}, + {0.2, 0.8}, + } + // Convert as observation history to a list of ranges when the state is non-zero. + applyViterbi := func(history []int) [][2]int { + probs := make([]float64, len(initialProbs)) + copy(probs, initialProbs) + var pointers [][]int + + // Forward pass. + for _, observation := range history { + newProbs := make([]float64, len(probs)) + curPointers := make([]int, len(probs)) + // For each new state, take max over probability resulting from different prev states. + for newState := range probs { + for prevState, prevProb := range probs { + prob := prevProb * transitionProbs[prevState][newState] * emissionProbs[newState][observation] + if prob > newProbs[newState] { + newProbs[newState] = prob + curPointers[newState] = prevState + } + } + } + probs = newProbs + pointers = append(pointers, curPointers) + } + + // Backward pass: compute max and then follow the pointers. + var finalState int + var bestProb float64 + for state, prob := range probs { + if prob < bestProb { + continue + } + bestProb = prob + finalState = state + } + reversedStates := []int{finalState} + curState := finalState + for i := len(pointers) - 1; i > 0; i-- { + curState = pointers[i][curState] + reversedStates = append(reversedStates, curState) + } + states := make([]int, len(reversedStates)) + for i := range states { + states[i] = reversedStates[len(states)-i-1] + } + + // Convert to ranges. + var ranges [][2]int + var startIdx int = -1 + for idx, state := range states { + if state == 0 && startIdx >= 0 { + // Object was active but no longer. 
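+				// For example, states 0,1,1,0 yield the single half-open range [1, 3).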
+ ranges = append(ranges, [2]int{startIdx, idx}) + startIdx = -1 + } + if state == 1 && startIdx == -1 { + startIdx = idx + } + } + // Add last range if any. + if startIdx != -1 { + ranges = append(ranges, [2]int{startIdx, len(states)}) + } + return ranges + } + + // Pass each group through Viterbi algorithm. + // This yields time ranges where a group was present in the world. + // Usually there should just be one time range associated with each group, + // but there could be multiple if there really was a gap. + // Anyway we then collect those ranges into output data. + var historyData PointData + outFeatures := make(map[string]*PointData) + log.Println("processing groups") + ch := make(chan Group) + type Rng struct { + Group Group + StartIdx int + EndIdx int + } + donech := make(chan []Rng) + for i := 0; i < *numThreads; i++ { + go func() { + var myRngs []Rng + for group := range ch { + // Create set of labels where the point is present. + labelSet := make(map[string]bool) + for _, point := range group { + labelSet[point.label] = true + } + + // Also create label set where the tile containing the point was valid. + // If tile is invalid at a label, it implies there was no satellite image data at that location/time. + validLabelSet := make(map[string]bool) + center := group.Center() + tile := Tile{ + Projection: *group[0].Properties.Projection, + Column: int(math.Floor(float64(center[0]) / TILE_SIZE)), + Row: int(math.Floor(float64(center[1]) / TILE_SIZE)), + } + for _, label := range tileLabelValidity[tile] { + validLabelSet[label] = true + } + + if len(validLabelSet) < MIN_VALID_TIMESTEPS { + continue + } + + // Now make history of observations for Viterbi algorithm. + // We only include timesteps where the tile was valid. + // We also create a map from observed timesteps to original timestep index. + var observations []int + var labelIdxMap []int + for labelIdx, label := range labelList { + if !validLabelSet[label] { + continue + } + labelIdxMap = append(labelIdxMap, labelIdx) + if labelSet[label] { + observations = append(observations, 1) + } else { + observations = append(observations, 0) + } + } + + ranges := applyViterbi(observations) + for _, rng := range ranges { + startIdx := labelIdxMap[rng[0]] + var endIdx int + if rng[1] == len(observations) { + endIdx = len(labelList) + } else { + endIdx = labelIdxMap[rng[1]] + } + myRngs = append(myRngs, Rng{ + Group: group, + StartIdx: startIdx, + EndIdx: endIdx, + }) + } + } + donech <- myRngs + }() + } + for _, group := range groups { + ch <- group + } + close(ch) + for i := 0; i < *numThreads; i++ { + curRngs := <-donech + for _, rng := range curRngs { + last := rng.Group[len(rng.Group)-1] + feat := Point{} + feat.Type = "Feature" + feat.Geometry = last.Geometry + feat.Properties.Category = last.Properties.Category + feat.Properties.Score = last.Properties.Score + + // Add the feature to the monthly outputs. + for labelIdx := rng.StartIdx; labelIdx < rng.EndIdx; labelIdx++ { + label := labelList[labelIdx] + if outFeatures[label] == nil { + outFeatures[label] = &PointData{ + Type: "FeatureCollection", + } + } + outFeatures[label].Features = append(outFeatures[label].Features, feat) + } + + // Now set start and end label for this feature correctly. + // Along with the score (computed as average of the points in the group). + // And then add it to history. 
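+			// A range still open at the most recent timestep gets End = FUTURE_LABEL
+			// ("2030-01") via getEndLabel, i.e. it has not been observed to end.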
+ feat.Properties.Start = labelList[rng.StartIdx] + feat.Properties.End = getEndLabel(labelList, rng.EndIdx) + + var scoreSum float64 = 0 + for _, p := range rng.Group { + scoreSum += *p.Properties.Score + } + scoreAvg := scoreSum / float64(len(rng.Group)) + feat.Properties.Score = &scoreAvg + + historyData.Features = append(historyData.Features, feat) + } + } + + log.Println("writing outputs") + + if *histFname != "" { + bytes, err := json.Marshal(historyData) + if err != nil { + panic(err) + } + if err := os.WriteFile(*histFname, bytes, 0644); err != nil { + panic(err) + } + } + + if *outFname != "" { + for label, data := range outFeatures { + fname := strings.ReplaceAll(*outFname, "LABEL", label) + bytes, err := json.Marshal(data) + if err != nil { + panic(err) + } + if err := os.WriteFile(fname, bytes, 0644); err != nil { + panic(err) + } + } + } +} diff --git a/rslp/satlas/train.py b/rslp/satlas/train.py new file mode 100644 index 00000000..77229455 --- /dev/null +++ b/rslp/satlas/train.py @@ -0,0 +1,115 @@ +"""Satlas custom training code.""" + +from typing import Any + +import torch +from rslearn.train.tasks.detection import DetectionTask +from rslearn.utils import Feature + +CATEGORY_MAPPING = { + "power": "platform", +} + + +class MarineInfraTask(DetectionTask): + """Marine infrastructure detection task. + + We just add a category remapping pre-processing. + """ + + def process_inputs( + self, + raw_inputs: dict[str, torch.Tensor | list[Feature]], + metadata: dict[str, Any], + load_targets: bool = True, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Processes the data into targets. + + Args: + raw_inputs: raster or vector data to process + metadata: metadata about the patch being read + load_targets: whether to load the targets or only inputs + + Returns: + tuple (input_dict, target_dict) containing the processed inputs and targets + that are compatible with both metrics and loss functions + """ + if not load_targets: + return {}, {} + + for feat in raw_inputs["targets"]: + if self.property_name not in feat.properties: + continue + category = feat.properties[self.property_name] + if category not in CATEGORY_MAPPING: + continue + feat.properties[self.property_name] = CATEGORY_MAPPING[category] + + return super().process_inputs(raw_inputs, metadata, load_targets) + + +class TeePipe(torch.nn.Module): + """TeePipe passes different channels of the input image to different backbones. + + The features from the different backbones are then concatenated and returned. + """ + + def __init__( + self, + encoders: list[torch.nn.Module], + channels: list[list[int]], + ): + """Create a new TeePipe. + + Args: + encoders: the encoders to apply. + channels: the subset of channels that each encoder should input. For + example, if the input is ABCDEF and first encoder should see ABC while + second should see DEF, then the list should be [[0, 1, 2], [3, 4, 5]]. + """ + super().__init__() + self.encoders = torch.nn.ModuleList(encoders) + self.channels = channels + + def get_backbone_channels(self) -> list: + """Returns the output channels of this model when used as a backbone. + + Returns: + the output channels of the backbone as a list of (downsample_factor, depth) + tuples. + """ + # We assume that each encoder outputs features at matching resolutions. 
+ out_channels = self.encoders[0].get_backbone_channels() + + for encoder in self.encoders[1:]: + cur_channels = encoder.get_backbone_channels() + for idx, (downsample_factor, depth) in enumerate(cur_channels): + if out_channels[idx][0] != downsample_factor: + raise ValueError( + "encoders have mis-matching resolutions of output feature maps" + ) + out_channels[idx][1] += depth + + return out_channels + + def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]: + """Compute features. + + Args: + inputs: input dicts that must include "image" key containing the images to + process. + """ + # index in feature map -> encoder index -> feature map + all_features: list[list[torch.Tensor]] = [ + [] for _ in self.get_backbone_channels() + ] + + for encoder, cur_channels in zip(self.encoders, self.channels): + cur_features = encoder( + [{"image": inp["image"][cur_channels, :, :]} for inp in inputs] + ) + for idx, feat_map in enumerate(cur_features): + all_features[idx].append(feat_map) + + # Final feature map should concatenate at each scale. + return [torch.cat(feat_map_list, dim=1) for feat_map_list in all_features] diff --git a/rslp/satlas/write_jobs.py b/rslp/satlas/write_jobs.py new file mode 100644 index 00000000..6ea8bd69 --- /dev/null +++ b/rslp/satlas/write_jobs.py @@ -0,0 +1,307 @@ +"""Launch Satlas prediction jobs on Beaker.""" + +import json +import random +from collections.abc import Generator +from datetime import datetime, timedelta, timezone + +import shapely +import tqdm +from rasterio.crs import CRS +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import PixelBounds, Projection, STGeometry +from rslearn.utils.get_utm_ups_crs import get_proj_bounds +from upath import UPath + +import rslp.common.worker +from rslp.log_utils import get_logger + +from .predict_pipeline import Application, PredictTaskArgs, get_output_fname + +logger = get_logger(__name__) + +TILE_SIZE = 32768 +RESOLUTION = 10 + +# Days to add before a provided date. +DEFAULT_DAYS_BEFORE = 120 + +# Days to add after a provided date. +DEFAULT_DAYS_AFTER = 90 + + +class Task: + """Represents a task that processes one tile at one point in time.""" + + def __init__( + self, + application: Application, + projection: Projection, + bounds: PixelBounds, + time_range: tuple[datetime, datetime], + out_path: str, + ) -> None: + """Create a new Task. + + Args: + application: the application to run + projection: the projection of the tile + bounds: the bounds of the tile + time_range: the time range to process + out_path: where to write outputs + """ + self.application = application + self.projection = projection + self.bounds = bounds + self.time_range = time_range + self.out_path = out_path + + def get_output_fname(self) -> UPath: + """Get the output filename that will be used for this task.""" + # The filename format is defined by get_output_fname in predict_pipeline.py. + return get_output_fname( + self.application, self.out_path, self.projection, self.bounds + ) + + +def enumerate_tiles_in_zone(utm_zone: CRS) -> Generator[tuple[int, int], None, None]: + """List all of the tiles in the zone where outputs should be computed. + + The tiles are all TILE_SIZE x TILE_SIZE so only the column/row of the tile along + that grid are returned. + + Args: + utm_zone: the CRS which must correspond to a UTM EPSG. + + Returns: + generator of (column, row) of the tiles that are needed. + """ + # We use get_proj_bounds to get the bounds of the UTM zone in CRS units. 
+ # We then convert to pixel units in order to determine the tiles that are needed. + crs_bbox = STGeometry( + Projection(utm_zone, 1, 1), + shapely.box(*get_proj_bounds(utm_zone)), + None, + ) + projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) + pixel_bbox = crs_bbox.to_projection(projection) + + # Convert the resulting shape to integer bbox. + zone_bounds = tuple(int(value) for value in pixel_bbox.shp.bounds) + + for col in range(zone_bounds[0] // TILE_SIZE, zone_bounds[2] // TILE_SIZE + 1): + for row in range(zone_bounds[1] // TILE_SIZE, zone_bounds[3] // TILE_SIZE + 1): + yield (col, row) + + +def get_jobs( + application: Application, + time_range: tuple[datetime, datetime], + out_path: str, + epsg_code: int | None = None, + wgs84_bounds: tuple[float, float, float, float] | None = None, + batch_size: int = 1, + count: int | None = None, +) -> list[list[str]]: + """Get batches of tasks for Satlas prediction. + + Tasks where outputs have already been computed are excluded. + + Args: + application: which application to run. + time_range: the time range to run within. Must have timezone. + out_path: the output path. It should be specific to the time range. + epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default + run in all UTM zones. + wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. + batch_size: how many tasks to run in each batch. + count: limit to this many tasks. + + Returns: + the list of worker tasks where each worker task + """ + # Generate tasks. + if epsg_code: + utm_zones = [CRS.from_epsg(epsg_code)] + else: + utm_zones = [] + for epsg_code in range(32601, 32661): + utm_zones.append(CRS.from_epsg(epsg_code)) + for epsg_code in range(32701, 32761): + utm_zones.append(CRS.from_epsg(epsg_code)) + + tasks: list[Task] = [] + for utm_zone in tqdm.tqdm(utm_zones, desc="Enumerating tasks across UTM zones"): + projection = Projection(utm_zone, RESOLUTION, -RESOLUTION) + + # If the user provided WGS84 bounds, then we convert it to pixel coordinates so + # we can check each tile easily. + user_bounds_in_proj: PixelBounds | None = None + if wgs84_bounds is not None: + dst_geom = STGeometry( + WGS84_PROJECTION, shapely.box(*wgs84_bounds), None + ).to_projection(projection) + user_bounds_in_proj = ( + int(dst_geom.shp.bounds[0]), + int(dst_geom.shp.bounds[1]), + int(dst_geom.shp.bounds[2]), + int(dst_geom.shp.bounds[3]), + ) + + for col, row in enumerate_tiles_in_zone(utm_zone): + if user_bounds_in_proj is not None: + # Check if this task intersects the bounds specified by the user. + if (col + 1) * TILE_SIZE < user_bounds_in_proj[0]: + continue + if col * TILE_SIZE >= user_bounds_in_proj[2]: + continue + if (row + 1) * TILE_SIZE < user_bounds_in_proj[1]: + continue + if row * TILE_SIZE >= user_bounds_in_proj[3]: + continue + + tasks.append( + Task( + application=application, + projection=projection, + bounds=( + col * TILE_SIZE, + row * TILE_SIZE, + (col + 1) * TILE_SIZE, + (row + 1) * TILE_SIZE, + ), + time_range=time_range, + out_path=out_path, + ) + ) + + logger.info("Got %d total tasks", len(tasks)) + + # Remove tasks where outputs are already computed. + existing_output_fnames = {out_fname.name for out_fname in UPath(out_path).iterdir()} + tasks = [ + task + for task in tasks + if task.get_output_fname().name not in existing_output_fnames + ] + logger.info("Got %d tasks that are uncompleted", len(tasks)) + + # Sample tasks down to user-provided count (max # tasks to run), if provided. 
+ if count is not None and len(tasks) > count: + tasks = random.sample(tasks, count) + logger.info("Randomly sampled %d tasks", len(tasks)) + + # Convert tasks to jobs for use with rslp.common.worker. + # This is what will be written to the Pub/Sub topic. + jobs = [] + for i in range(0, len(tasks), batch_size): + cur_tasks = tasks[i : i + batch_size] + + # Get list of PredictTaskArgs that we can serialize. + # These just specify the projection, time range, and bounds. + predict_tasks = [] + for task in cur_tasks: + predict_tasks.append( + PredictTaskArgs( + projection_json=task.projection.serialize(), + bounds=task.bounds, + time_range=task.time_range, + ) + ) + + cur_args = [ + application.value.upper(), + out_path, + "/tmp/scratch/", + json.dumps([predict_task.serialize() for predict_task in predict_tasks]), + ] + jobs.append(cur_args) + + return jobs + + +def write_jobs( + application: Application, + time_range: tuple[datetime, datetime], + out_path: str, + project_id: str, + topic_id: str, + epsg_code: int | None = None, + wgs84_bounds: tuple[float, float, float, float] | None = None, + batch_size: int = 1, + count: int | None = None, +) -> None: + """Write jobs for the specified application and time range. + + Args: + application: which application to run. + time_range: the time range to run within. Must have timezone. + out_path: the output path. It should be specific to the time range. + project_id: project containing the Pub/Sub topic. + topic_id: the Pub/Sub topic to write the jobs to. + epsg_code: limit tasks to this UTM zone (specified by its EPSG code), default + run in all UTM zones. + wgs84_bounds: limit tasks to ones that intersect these WGS84 bounds. + batch_size: how many tasks to run in each batch. + count: limit to this many tasks. + """ + jobs = get_jobs( + application, + time_range, + out_path, + epsg_code=epsg_code, + wgs84_bounds=wgs84_bounds, + batch_size=batch_size, + count=count, + ) + rslp.common.worker.write_jobs(project_id, topic_id, "satlas", "predict_multi", jobs) + + +def write_jobs_for_year_months( + year_months: list[tuple[int, int]], + application: Application, + out_path: str, + project_id: str, + topic_id: str, + batch_size: int = 1, + days_before: int = DEFAULT_DAYS_BEFORE, + days_after: int = DEFAULT_DAYS_AFTER, + count: int | None = None, +) -> None: + """Write Satlas prediction jobs for the given year and month. + + Args: + year_months: list of year-month pairs. + application: the application to run. + out_path: the output path with year and month placeholders. + project_id: project containing the Pub/Sub topic. + topic_id: the Pub/Sub topic to write the jobs to. + worker_params: the worker parameters. + batch_size: the batch size. + days_before: how much to pad windows before the year/month. + days_after: how much to pad windows after the year/month. + count: limit each year-month to this many tasks. 
+ """ + jobs = [] + for year, month in year_months: + ts = datetime(year, month, 1, tzinfo=timezone.utc) + time_range = ( + ts - timedelta(days=days_before), + ts + timedelta(days=days_after), + ) + cur_out_path = out_path.format(year=year, month=month) + logger.info( + f"collecting jobs for year={year}, month={month}, time_range={time_range}, out_path={cur_out_path}" + ) + cur_jobs = get_jobs( + application=application, + time_range=time_range, + out_path=cur_out_path, + batch_size=batch_size, + count=count, + ) + logger.info("got %d jobs for %04d-%02d", len(cur_jobs), year, month) + jobs.extend(cur_jobs) + + logger.info("got a total of %d jobs across year-months", len(jobs)) + rslp.common.worker.write_jobs(project_id, topic_id, "satlas", "predict_multi", jobs) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..cb1fc115 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import logging + +import pytest + +from rslp.utils.mp import init_mp + +logging.basicConfig() + + +@pytest.fixture(scope="session", autouse=True) +def always_init_mp() -> None: + init_mp() diff --git a/tests/integration/common/__init__.py b/tests/integration/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/common/test_worker.py b/tests/integration/common/test_worker.py new file mode 100644 index 00000000..912cb1f4 --- /dev/null +++ b/tests/integration/common/test_worker.py @@ -0,0 +1,50 @@ +"""Test common.worker functionality.""" + +import os +import pathlib +import time + +from rslp.common.worker import worker_pipeline, write_jobs + + +def test_idle_timeout() -> None: + """Verify that worker exits within about the idle timeout.""" + idle_timeout = 2 + start_time = time.time() + worker_pipeline( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + subscription_id=os.environ["TEST_PUBSUB_SUBSCRIPTION"], + idle_timeout=idle_timeout, + ) + end_time = time.time() + elapsed = end_time - start_time + # Use 3 here in case worker decides to sleep twice (and then some extra time + # elapses). + assert elapsed >= idle_timeout and elapsed <= 3 * idle_timeout + + +def test_single_task(tmp_path: pathlib.Path) -> None: + """Check that worker can do one task.""" + # Use a task that writes to this filename. + dst_fname = tmp_path / "test_file" + # Write the task to the test topic. + job_args = [ + str(dst_fname), + # The contents to write. + "abc", + ] + write_jobs( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + topic_id=os.environ["TEST_PUBSUB_TOPIC"], + rslp_project="common", + rslp_workflow="write_file", + args_list=[job_args], + ) + # Run the worker. + worker_pipeline( + project_id=os.environ["TEST_PUBSUB_PROJECT"], + subscription_id=os.environ["TEST_PUBSUB_SUBSCRIPTION"], + idle_timeout=1, + ) + # Verify that the file was created. + assert dst_fname.exists() diff --git a/tests/integration/forest_loss_driver/inference/test_model_predict.py b/tests/integration/forest_loss_driver/inference/test_model_predict.py index f78b9c8f..c65edfbc 100644 --- a/tests/integration/forest_loss_driver/inference/test_model_predict.py +++ b/tests/integration/forest_loss_driver/inference/test_model_predict.py @@ -95,7 +95,7 @@ def test_forest_loss_driver_model_predict( output_json = json.load(f) # TODO: Ideally we would have a pydantic model for this output perhaps that we could subclass from rslearn? 
# Check properties except probs - tol = 0.01 + tol = 0.1 assert output_json["type"] == expected_output_json["type"] assert output_json["properties"] == expected_output_json["properties"] assert len(output_json["features"]) == len(expected_output_json["features"]) diff --git a/tests/integration/satlas/__init__.py b/tests/integration/satlas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/satlas/test_bkt.py b/tests/integration/satlas/test_bkt.py new file mode 100644 index 00000000..97123a2a --- /dev/null +++ b/tests/integration/satlas/test_bkt.py @@ -0,0 +1,54 @@ +"""Test bkt file operations.""" + +import os +import pathlib + +import pytest + +from rslp.satlas.bkt import ( + BktDownloadRequest, + download_from_bkt, + get_bigtable, + make_bkt, +) + +RUNNING_IN_CI = os.environ.get("CI", "false").lower() == "true" + + +@pytest.mark.skipif(RUNNING_IN_CI, reason="Skipping in CI environment") +def test_make_and_download_bkt(tmp_path: pathlib.Path) -> None: + """Test making and downloading a bkt file. + + We create two files and make a bkt from them. + + Then we try to use download_from_bkt to download both of those files and verify + that it is returned correctly. + """ + + # make_bkt expects a directory structure zoom/col/row. + zoom = "1" + data_col0_row0 = b"bkt1" + data_col0_row1 = b"bkt2" + bkt_fname = "tests/test_make_and_download_bkt" + (tmp_path / zoom / "0").mkdir(parents=True) + with (tmp_path / zoom / "0" / "0").open("wb") as f: + f.write(data_col0_row0) + with (tmp_path / zoom / "0" / "1").open("wb") as f: + f.write(data_col0_row1) + make_bkt(str(tmp_path), bkt_fname) + + # Now try to download both of those files. + # We use the metadata to store the expected contents. + # We also read an extra tile that shouldn't exist. + download_requests = [ + BktDownloadRequest(bkt_fname, 0, 0, metadata=data_col0_row0), + BktDownloadRequest(bkt_fname, 0, 1, metadata=data_col0_row1), + # This one should not exist. + BktDownloadRequest(bkt_fname, 1, 1), + ] + bkt_files_table = get_bigtable() + # Call download_from_bkt and populate into list so we can check the length. + results = list(download_from_bkt(bkt_files_table, download_requests)) + assert len(results) == 2 + for expected, actual in results: + assert expected == actual diff --git a/tests/integration/satlas/test_data_sources.py b/tests/integration/satlas/test_data_sources.py new file mode 100644 index 00000000..fe97f3c6 --- /dev/null +++ b/tests/integration/satlas/test_data_sources.py @@ -0,0 +1,111 @@ +"""Test Satlas data_sources.py.""" + +import pathlib +from datetime import datetime, timedelta, timezone + +import shapely +from rslearn.config import ( + LayerType, + QueryConfig, + RasterLayerConfig, + SpaceMode, +) +from rslearn.const import WGS84_PROJECTION +from rslearn.data_sources import DataSource +from rslearn.data_sources.gcp_public_data import Sentinel2 as Sentinel2 +from rslearn.data_sources.planetary_computer import Sentinel1 +from rslearn.data_sources.planetary_computer import Sentinel2 as AzureSentinel2 +from rslearn.utils.geometry import STGeometry +from upath import UPath + +from rslp.satlas.data_sources import ( + MonthlyAzureSentinel2, + MonthlySentinel1, + MonthlySentinel2, +) + +PERIOD_DAYS = 30 + + +class TestGetItems: + """Test the get_items method in per-period data source.""" + + def apply_test(self, data_source: DataSource) -> None: + """Test that the data source successfully returns per-period mosaics. 
+ + We apply it on a bbox of Seattle for three-month period with period equal to one + month. + + Args: + data_source: the data source to test. + """ + # Create a 0.002x0.002 degree bbox near Seattle for three-month time range. + seattle_point = (-122.33, 47.61) + shp = shapely.box( + seattle_point[0] - 0.001, + seattle_point[1] - 0.001, + seattle_point[0] + 0.001, + seattle_point[1] + 0.001, + ) + time_range = ( + datetime(2024, 4, 1, tzinfo=timezone.utc), + datetime(2024, 7, 1, tzinfo=timezone.utc), + ) + geometry = STGeometry(WGS84_PROJECTION, shp, time_range) + + # Look for 2 per-month mosaics. + # The first month in the three-month time range should not yield a mosaic since + # it is not needed (it would only be used if the more recent months do not have + # any scenes, but that is not the case here). + query_config = QueryConfig( + space_mode=SpaceMode.MOSAIC, + max_matches=2, + ) + groups = data_source.get_items([geometry], query_config)[0] + + # We expect to get two groups, and each one should be in a different period. + # The groups should be ordered from most recent to least recent. + # There should not be any group for the first period in the three-month time + # range since we expect there to be a mosaic available for the more recent + # periods and max_matches=2. + expected_time_ranges = [ + (time_range[1] - timedelta(days=PERIOD_DAYS), time_range[1]), + ( + time_range[1] - timedelta(days=PERIOD_DAYS * 2), + time_range[1] - timedelta(days=PERIOD_DAYS), + ), + ] + assert len(groups) == len(expected_time_ranges) + for expected_time_range, group in zip(expected_time_ranges, groups): + assert len(group) > 0 + for item in group: + item_ts = item.geometry.time_range[0] + assert expected_time_range[0] <= item_ts <= expected_time_range[1] + + def test_sentinel1(self) -> None: + """Run apply_test with MonthlySentinel1.""" + sentinel1 = MonthlySentinel1( + sentinel1=Sentinel1(["vv"]), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel1) + + def test_sentinel2(self, tmp_path: pathlib.Path) -> None: + """Run apply_test with MonthlySentinel2.""" + sentinel2 = MonthlySentinel2( + sentinel2=Sentinel2( + RasterLayerConfig(LayerType.RASTER, []), + index_cache_dir=UPath(tmp_path), + use_rtree_index=False, + ), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel2) + + def test_azure_sentinel2(self) -> None: + """Run apply_test with MonthlyAzureSentinel2.""" + sentinel2 = MonthlyAzureSentinel2( + sentinel2=AzureSentinel2(["B04"]), + period_days=PERIOD_DAYS, + ) + self.apply_test(sentinel2) diff --git a/tests/integration_slow/satlas/test_predict_pipeline.py b/tests/integration_slow/satlas/test_predict_pipeline.py new file mode 100644 index 00000000..ef2a84c3 --- /dev/null +++ b/tests/integration_slow/satlas/test_predict_pipeline.py @@ -0,0 +1,73 @@ +"""Test the Satlas prediction pipeline.""" + +import json +import pathlib +from datetime import datetime, timezone + +import shapely +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import STGeometry +from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection + +from rslp.satlas.predict_pipeline import ( + PATCH_SIZE, + Application, + get_output_fname, + predict_pipeline, +) + + +def test_predict_pipeline_point(tmp_path: pathlib.Path) -> None: + # Test the prediction pipeline runs correctly for a point detection task. 
+ # Specifically, we apply the marine infrastructure model on a window covering a + # small portion of the Hornsea 2 Offshore Wind Farm, and verify that the resulting + # detections include at least one turbine. + + # These are the coordinates of the known wind turbine. + src_geom = STGeometry(WGS84_PROJECTION, shapely.Point(1.859, 53.91), None) + + # We get the corresponding UTM window. + # The window's bounds must all be multiples of PATCH_SIZE. + projection = get_utm_ups_projection(src_geom.shp.x, src_geom.shp.y, 10, -10) + dst_geom = src_geom.to_projection(projection) + start = ( + int(dst_geom.shp.x) // PATCH_SIZE * PATCH_SIZE, + int(dst_geom.shp.y) // PATCH_SIZE * PATCH_SIZE, + ) + bounds = (start[0], start[1], start[0] + PATCH_SIZE, start[1] + PATCH_SIZE) + + # The wind farm existed since 2019 so this time range will work. + time_range = ( + datetime(2024, 1, 1, tzinfo=timezone.utc), + datetime(2024, 8, 1, tzinfo=timezone.utc), + ) + + # Output path will contain outputs, while scratch path is used as a working + # directory. Specifically, the scratch path will be populated with an rslearn + # dataset containing one window matching the projection/bounds that we provide. + out_path = tmp_path / "out" + scratch_path = tmp_path / "scratch" + out_path.mkdir() + + # Apply the pipeline. It will ingest data and apply the model. + # We disable rtree index so that it doesn't need an hour to create it. + predict_pipeline( + application=Application.MARINE_INFRA, + projection_json=json.dumps(projection.serialize()), + bounds=bounds, + time_range=time_range, + out_path=str(out_path), + scratch_path=str(scratch_path), + use_rtree_index=False, + ) + + # Now we verify that the output includes at least one turbine. + out_fname = get_output_fname( + Application.MARINE_INFRA, str(out_path), projection, bounds + ) + with out_fname.open() as f: + fc = json.load(f) + turbine_features = [ + feat for feat in fc["features"] if feat["properties"]["category"] == "turbine" + ] + assert len(turbine_features) > 0 diff --git a/tests/unit/satlas/__init__.py b/tests/unit/satlas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/satlas/test_postprocess.py b/tests/unit/satlas/test_postprocess.py new file mode 100644 index 00000000..b7dfe94a --- /dev/null +++ b/tests/unit/satlas/test_postprocess.py @@ -0,0 +1,370 @@ +"""Test Satlas post-processing.""" + +import json +import pathlib + +import shapely +from rasterio.crs import CRS +from rslearn.const import WGS84_PROJECTION +from rslearn.utils.geometry import Projection, STGeometry + +from rslp.satlas.postprocess import merge_points, smooth_points +from rslp.satlas.predict_pipeline import PATCH_SIZE, Application + + +class TestMergePoints: + """Test the merge_points function.""" + + def make_task_output( + self, + fname: pathlib.Path, + projection: Projection, + coords: list[tuple[float, float]], + valid_patches: list[tuple[int, int]], + ) -> None: + """Make a JSON matching those produced by the Satlas predict_pipeline. + + Args: + fname: the filename to write to. + projection: the projection of the prediction task. The task writes the + GeoJSON in pixel coordinates under that projection. + coords: list of point (col, row) coordinates to include. + valid_patches: list of (col, row) patches to include. These are in tiles of + PATCH_SIZE (see rslp.satlas.predict_pipeline). + """ + # Convert features to GeoJSON. 
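+        # Each (col, row) pair becomes a Point feature in pixel coordinates with a
+        # placeholder category and a score of 1.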
+ features = [] + for col, row in coords: + features.append( + { + "type": "Feature", + "properties": { + "score": 1, + "category": "placeholder", + }, + "geometry": { + "type": "Point", + "coordinates": [col, row], + }, + } + ) + + # Make the FeatureCollection. + fc = { + "type": "FeatureCollection", + "features": features, + "properties": projection.serialize(), + } + # Add the valid patches. It is a dict from CRS to tile list. + fc["properties"]["valid_patches"] = { + str(projection.crs): valid_patches, + } + + fname.parent.mkdir(parents=True, exist_ok=True) + with open(fname, "w") as f: + json.dump(fc, f) + + def test_two_crs_merged(self, tmp_path: pathlib.Path) -> None: + """Verify that when merging across two CRS it is successful.""" + proj32601 = Projection(CRS.from_epsg(32601), 10, -10) + proj32602 = Projection(CRS.from_epsg(32602), 10, -10) + + predict_path = tmp_path / "predict" + merged_path = tmp_path / "merged" + + # File 1 is in 32601 and contains one feature. + self.make_task_output(predict_path / "1.geojson", proj32601, [(0, 0)], [(0, 0)]) + # File 2 is also in 32601 and contains one different feature and patch. + self.make_task_output( + predict_path / "2.geojson", proj32601, [(2048, 2048)], [(1, 1)] + ) + # File 3 is in 32602 and contains a third feature. + self.make_task_output(predict_path / "3.geojson", proj32602, [(0, 0)], [(0, 0)]) + + # Run the merging. + merge_points( + Application.MARINE_INFRA, + # Use arbitrary YYYY-MM label. + "1234-56", + str(predict_path), + str(merged_path), + ) + + # Verify the output. + merged_fname = merged_path / "1234-56.geojson" + with merged_fname.open() as f: + fc = json.load(f) + # And the valid patches should be merged, with one in 32601 and two in 32602. + valid_patches = fc["properties"]["valid_patches"] + assert len(valid_patches) == 2 + patches32601 = valid_patches[str(proj32601.crs)] + patches32601.sort() + assert patches32601 == [[0, 0], [1, 1]] + patches32602 = valid_patches[str(proj32602.crs)] + patches32602.sort() + assert patches32602 == [[0, 0]] + + features = fc["features"] + assert len(features) == 3 + # Make sure crs property is set correctly. + feature_projections = [feat["properties"]["projection"] for feat in features] + feature_projections.sort() + assert feature_projections == [ + str(proj32601.crs), + str(proj32601.crs), + str(proj32602.crs), + ] + + +class TestSmoothPoints: + """Test the smooth_points function.""" + + # Minimum number of timesteps for a point to be considered. + # This is used in smooth_point_labels_viterbi.go to discard regions that are only + # covered by satellite imagery for a couple months over multi-year time range. + MIN_VALID_TIMESTEPS = 8 + + def make_merge_output( + self, + fname: pathlib.Path, + geometries: list[STGeometry], + additional_valid_patches: list[tuple[Projection, int, int]] = [], + scores: list[float] | None = None, + ) -> None: + """Make a GeoJSON matching those produced by merge_points. + + This is the input to smooth_points. + + Args: + fname: the filename to write to. + geometries: list of STGeometry to include. The valid patches will be set + automatically to include all of these geometries. + additional_valid_patches: additional valid patches (besides those covered + by the geometries). + scores: optional list of scores of each geometry. If set, it should be the + same length as geometries. + """ + # Convert features to GeoJSON. 
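+        # Unlike make_task_output above, features here carry WGS84 coordinates plus
+        # projection/col/row properties, and the set of valid patches covering the
+        # geometries is collected automatically.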
+ features = [] + valid_patches: dict[str, set[tuple[int, int]]] = {} + for geometry_idx, geometry in enumerate(geometries): + wgs84_geometry = geometry.to_projection(WGS84_PROJECTION) + crs_str = str(geometry.projection.crs) + score = scores[geometry_idx] if scores else 1 + features.append( + { + "type": "Feature", + "properties": { + "category": "placeholder", + "score": score, + "projection": crs_str, + "col": int(geometry.shp.x), + "row": int(geometry.shp.y), + }, + "geometry": { + "type": "Point", + "coordinates": [wgs84_geometry.shp.x, wgs84_geometry.shp.y], + }, + } + ) + if crs_str not in valid_patches: + valid_patches[crs_str] = set() + valid_patches[crs_str].add( + (int(geometry.shp.x) // PATCH_SIZE, int(geometry.shp.y) // PATCH_SIZE) + ) + + for projection, col, row in additional_valid_patches: + crs_str = str(projection.crs) + if crs_str not in valid_patches: + valid_patches[crs_str] = set() + valid_patches[crs_str].add((col, row)) + + # Make and write the FeatureCollection. + fc = { + "type": "FeatureCollection", + "features": features, + "properties": { + "valid_patches": { + crs_str: list(patch_set) + for crs_str, patch_set in valid_patches.items() + }, + }, + } + fname.parent.mkdir(parents=True, exist_ok=True) + with open(fname, "w") as f: + json.dump(fc, f) + + def test_smooth_nms(self, tmp_path: pathlib.Path) -> None: + """Verify that smoothing deletes redundant points across projections.""" + # (-174, 0) is on the border between EPSG:32601 and EPSG:32602. + # So we add it in both projections and make sure smooth_points deletes it. + wgs84_geom = STGeometry(WGS84_PROJECTION, shapely.Point(-174, 0), None) + proj32601 = Projection(CRS.from_epsg(32601), 10, -10) + proj32602 = Projection(CRS.from_epsg(32602), 10, -10) + geom32601 = wgs84_geom.to_projection(proj32601) + geom32602 = wgs84_geom.to_projection(proj32602) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # We need MIN_VALID_TIMESTEPS files for any points to be produced. + labels = [f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS)] + for label in labels: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geom32601, geom32602], + scores=[0.5, 0.6], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_discard_single_timestep_positive(self, tmp_path: pathlib.Path) -> None: + """Ignore a point of it is only detected in one timestep.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry1 = STGeometry(projection, shapely.Point(0, 0), None) + geometry2 = STGeometry( + projection, shapely.Point(PATCH_SIZE // 2, PATCH_SIZE // 2), None + ) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # We need MIN_VALID_TIMESTEPS files where the patch is valid, otherwise it will + # be ignored. + # We use two points here since timesteps without any points yield no output. + # It also makes it easier to ensure that patch is marked valid. 
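+        # geometry1 is only detected at a single timestep, so the smoothing should
+        # drop it and leave only geometry2 in the output.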
+ labels = [f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS)] + for timestep, label in enumerate(labels): + if timestep == 4: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry1, geometry2], + ) + else: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry2], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_invalid_timesteps_ignored(self, tmp_path: pathlib.Path) -> None: + """Ensure a point keeps being predicted if its patch is not observed.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry = STGeometry(projection, shapely.Point(0, 0), None) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # Create the input files. + # In the first timesteps, we detect the point. + labels = [ + f"0000-0{timestep}" for timestep in range(self.MIN_VALID_TIMESTEPS * 2) + ] + for label in labels[0 : self.MIN_VALID_TIMESTEPS]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry], + ) + # In the remaining timesteps, we leave the patch invalid. + for label in labels[self.MIN_VALID_TIMESTEPS :]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # Verify the output. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1 + + def test_point_removed(self, tmp_path: pathlib.Path) -> None: + """Verify that a point will be removed with enough negative observations.""" + projection = Projection(CRS.from_epsg(32601), 10, -10) + geometry1 = STGeometry(projection, shapely.Point(0, 0), None) + # As with test_discard_single_timestep_positive, we use a second point so it is + # easier to ensure the patch is valid and so that outputs are produced at all + # timesteps. Here, we will observe geometry2 at every timestep but cut + # geometry1 off halfway through. + geometry2 = STGeometry( + projection, shapely.Point(PATCH_SIZE // 2, PATCH_SIZE // 2), None + ) + + merged_path = tmp_path / "merged" + smoothed_path = tmp_path / "smooth" + + # We will have 30 positive observations and then 30 negative ones. + # This ensures enough because there is very low transition probability from + # positive to negative (since it is unlikely that e.g. a wind turbine would be + # torn down). + num_observations = 30 + + # Create the input files. + # In the first timesteps, we detect both points. + labels = [f"0000-0{timestep}" for timestep in range(num_observations * 2)] + for label in labels[0:num_observations]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry1, geometry2], + ) + # In the remaining timesteps, we only detect geometry2. + for label in labels[num_observations:]: + self.make_merge_output( + merged_path / f"{label}.geojson", + [geometry2], + ) + + # Run the smoothing. + smooth_points( + Application.MARINE_INFRA, + labels[-1], + str(merged_path), + str(smoothed_path), + ) + + # At the first timestep, we should see both points. 
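+        # (geometry1's presence range should cover the early timesteps where it was
+        # detected, even though it gets removed later on.)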
+ smoothed_fname = smoothed_path / f"{labels[0]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 2 + + # At the last timestep, we should only see geometry2. + smoothed_fname = smoothed_path / f"{labels[-1]}.geojson" + with smoothed_fname.open() as f: + fc = json.load(f) + assert len(fc["features"]) == 1