Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
yawenzzzz committed Nov 26, 2024
1 parent 2a0ca68 commit 936aa91
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 27 deletions.
2 changes: 1 addition & 1 deletion data/landsat_vessels/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ trainer:
class_path: rslp.utils.nms.NMSDistanceMerger
init_args:
grid_size: 64
distance_threshold: 15
distance_threshold: 10
property_name: "category" # same as task.property_name
class_agnostic: false
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
Expand Down
2 changes: 1 addition & 1 deletion docs/landsat_vessels/model_summary.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Note: The evaluation metrics are reported for the two-stage model (detector + cl

| Date | Version | Precision | Recall | F1-Score |
|------------|---------|-----------|--------|----------|
| 2024-11-15 | 0.0.1 | 0.77 | 0.60 | 0.67 |
| 2024-11-15 | 0.0.1 | 0.72 | 0.53 | 0.61 |
| YYYY-MM-DD | TBD | TBD | TBD | TBD |

## Offline Scenario Checks
Expand Down
2 changes: 1 addition & 1 deletion docs/landsat_vessels/train_eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ The whole pipeline is evaluated with two approaches.

This will launch multiple Beaker jobs. Each job will evaluate the model on one window and save the results in the `jsons` directory.

2. Compute the evaluation metrics:
2. Compute the metrics:

```python
python rslp/landsat_vessels/evaluation/get_metrics.py --ground_truth_dir gs://rslearn-eai/datasets/landsat_vessel_detection/detector/dataset_20240924/windows/labels_utm --predictions_dir gs://rslearn-eai/projects/landsat_evaluation/pipeline_results/jsons/
Expand Down
14 changes: 7 additions & 7 deletions rslp/landsat_vessels/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import json

# Landsat configuration
# Landsat config
LANDSAT_LAYER_NAME = "landsat"
LANDSAT_RESOLUTION = 15

# Detector configuration
# Detector config
LOCAL_FILES_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config.json"
AWS_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config_aws.json"
DETECT_MODEL_CONFIG = "data/landsat_vessels/config.yaml"
Expand All @@ -21,12 +21,12 @@
band["bands"][0] for band in json_data["layers"][LANDSAT_LAYER_NAME]["band_sets"]
]

# Classifier configuration
# Classifier config
CLASSIFY_MODEL_CONFIG = "landsat/recheck_landsat_labels/phase123_config.yaml"
CLASSIFY_WINDOW_SIZE = 64

# Filter configuration
INFRA_DISTANCE_THRESHOLD = 0.03 # unit: km, 30 meters
# Filter config
INFRA_THRESHOLD_KM = 0.03 # unit: km (30 meters); max distance between marine infrastructure and a prediction

# Evaluation configuration
MATCH_DISTANCE_THRESHOLD = 0.1 # unit: km, 100 meters
# Evaluation config
MATCH_THRESHOLD_KM = 0.1 # unit: km (100 meters); max distance between a ground-truth label and a prediction
2 changes: 1 addition & 1 deletion rslp/landsat_vessels/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ services:
- base-image
ports:
- 5555:5555
shm_size: '15G' # This adds the shared memory size
shm_size: '15G' # Set the shared memory size
environment:
- NVIDIA_VISIBLE_DEVICES=all # Make all GPUs visible
deploy:
Expand Down
10 changes: 8 additions & 2 deletions rslp/landsat_vessels/evaluation/get_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rslearn.utils import Projection, STGeometry
from upath import UPath

from rslp.landsat_vessels.config import MATCH_DISTANCE_THRESHOLD
from rslp.landsat_vessels.config import MATCH_THRESHOLD_KM
from rslp.utils.mp import init_mp


Expand Down Expand Up @@ -74,7 +74,7 @@ def process_window(
matched = False
for exp in expected_detections:
distance = haversine(pred, exp, unit=Unit.KILOMETERS)
if distance <= MATCH_DISTANCE_THRESHOLD:
if distance <= MATCH_THRESHOLD_KM:
matches += 1
if exp in current_missed_expected:
current_missed_expected.remove(exp)
Expand All @@ -83,6 +83,12 @@ def process_window(
if not matched:
unmatched_predicted += 1
missed_expected += len(current_missed_expected)
print(
f"window {window_name}, "
f"matches: {matches}, "
f"missed_expected: {missed_expected}, "
f"unmatched_predicted: {unmatched_predicted}"
)

return matches, missed_expected, unmatched_predicted

Expand Down
6 changes: 2 additions & 4 deletions rslp/landsat_vessels/evaluation/scenario_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import pandas as pd
from upath import UPath

RESULT_JSON_DIR = "gs://rslearn-eai/projects/2024_10_check_landsat/evaluation/jsons/"
TARGET_CSV_PATH = (
"gs://rslearn-eai/projects/2024_10_check_landsat/evaluation/csv/landsat_targets.csv"
)
RESULT_JSON_DIR = "gs://rslearn-eai/projects/landsat_evaluation/scenario_checks/jsons/"
TARGET_CSV_PATH = "gs://rslearn-eai/projects/landsat_evaluation/scenario_checks/csv/landsat_targets.csv"


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion rslp/landsat_vessels/job_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def launch_job(

with beaker.session():
env_vars = get_base_env_vars(use_weka_prefix=use_weka_prefix)
# Add AWS credentials for downloading data
env_vars.append(
EnvVar(
name="AWS_ACCESS_KEY_ID",
Expand Down Expand Up @@ -161,7 +162,7 @@ def launch_job(
if args.zip_dir:
try:
zip_dir_upath = UPath(args.zip_dir)
zip_paths = list(zip_dir_upath.glob("*.zip"))
paths = list(zip_dir_upath.glob("*.zip"))
except Exception:
# Accessing WEKA through the S3 protocol is only supported on AI2 clusters.
# As a workaround on other machines, we load the corresponding GCS bucket first.
Expand Down
11 changes: 2 additions & 9 deletions rslp/landsat_vessels/predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
CLASSIFY_MODEL_CONFIG,
CLASSIFY_WINDOW_SIZE,
DETECT_MODEL_CONFIG,
INFRA_DISTANCE_THRESHOLD,
INFRA_THRESHOLD_KM,
LANDSAT_BANDS,
LANDSAT_LAYER_NAME,
LANDSAT_RESOLUTION,
Expand Down Expand Up @@ -397,9 +397,7 @@ def predict_pipeline(
crop_upath.mkdir(parents=True, exist_ok=True)

json_data = []
near_infra_filter = NearInfraFilter(
infra_distance_threshold=INFRA_DISTANCE_THRESHOLD
)
near_infra_filter = NearInfraFilter(infra_distance_threshold=INFRA_THRESHOLD_KM)
infra_detections = 0
for idx, detection in enumerate(detections):
# Get longitude/latitude.
Expand Down Expand Up @@ -468,11 +466,6 @@ def predict_pipeline(
b8_fname=str(b8_fname),
),
)
# Clean up
if tmp_scratch_dir:
tmp_scratch_dir.cleanup()
if tmp_zip_dir:
tmp_zip_dir.cleanup()

if json_path:
json_upath = UPath(json_path)
Expand Down

0 comments on commit 936aa91

Please sign in to comment.