Skip to content

Commit

Permalink
write crops only if specified path
Browse files Browse the repository at this point in the history
  • Loading branch information
yawenzzzz committed Oct 21, 2024
1 parent 2fd1847 commit 5a1149c
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions rslp/landsat_vessels/predict_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,10 @@ def predict_pipeline(

# Use temporary directory if scratch_path or crop_path are not specified.
if scratch_path is None:
with tempfile.TemporaryDirectory() as tmp_scratch_dir:
scratch_path = tmp_scratch_dir
with tempfile.TemporaryDirectory() as tmp_dir:
scratch_path = tmp_dir
else:
tmp_scratch_dir = None

if crop_path is None:
with tempfile.TemporaryDirectory() as tmp_crop_dir:
crop_path = tmp_crop_dir
else:
tmp_crop_dir = None
tmp_dir = None

ds_path = UPath(scratch_path)
ds_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -296,8 +290,9 @@ def predict_pipeline(

# Write JSON and crops.
step_start_time = time.time()
crop_path = UPath(crop_path)
crop_path.mkdir(parents=True, exist_ok=True)
if crop_path:
crop_path = UPath(crop_path)
crop_path.mkdir(parents=True, exist_ok=True)

json_data = []
for idx, detection in enumerate(detections):
Expand Down Expand Up @@ -329,13 +324,14 @@ def predict_pipeline(
[images["B4_sharp"], images["B3_sharp"], images["B2_sharp"]], axis=2
)

rgb_fname = crop_path / f"{idx}_rgb.png"
with rgb_fname.open("wb") as f:
Image.fromarray(rgb).save(f, format="PNG")
if crop_path:
rgb_fname = crop_path / f"{idx}_rgb.png"
with rgb_fname.open("wb") as f:
Image.fromarray(rgb).save(f, format="PNG")

b8_fname = crop_path / f"{idx}_b8.png"
with b8_fname.open("wb") as f:
Image.fromarray(images["B8"]).save(f, format="PNG")
b8_fname = crop_path / f"{idx}_b8.png"
with b8_fname.open("wb") as f:
Image.fromarray(images["B8"]).save(f, format="PNG")

# Get longitude/latitude.
src_geom = STGeometry(
Expand All @@ -361,10 +357,8 @@ def predict_pipeline(
time_profile["total"] = elapsed_time

# Clean up any temporary directories.
if tmp_scratch_dir:
tmp_scratch_dir.cleanup()
if tmp_crop_dir:
tmp_crop_dir.cleanup()
if tmp_dir:
tmp_dir.cleanup()

if json_path:
json_path = UPath(json_path)
Expand Down

0 comments on commit 5a1149c

Please sign in to comment.