small qol updates
relativityhd committed Dec 13, 2024
1 parent 81f069b commit d96c932
Showing 6 changed files with 38 additions and 38 deletions.
27 changes: 9 additions & 18 deletions darts-acquisition/src/darts_acquisition/planet.py
Expand Up @@ -82,30 +82,21 @@ def load_planet_scene(fpath: str | Path) -> xr.Dataset:
# Define band names and corresponding indices
planet_da = xr.open_dataarray(ps_image)

bands = {1: "blue", 2: "green", 3: "red", 4: "nir"}

# Create a list to hold datasets
datasets = [
planet_da.sel(band=index)
.assign_attrs(
# Create a dataset with the bands
bands = ["blue", "green", "red", "nir"]
ds_planet = (
planet_da.fillna(0).rio.write_nodata(0).astype("uint16").assign_coords({"band": bands}).to_dataset(dim="band")
)
for var in ds_planet.variables:
ds_planet[var].assign_attrs(
{
"long_name": f"PLANET {var.capitalize()}",
"data_source": "planet",
"planet_type": planet_type,
"long_name": f"PLANET {name.capitalize()}",
"units": "Reflectance",
}
)
.fillna(0)
.rio.write_nodata(0)
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
for index, name in bands.items()
]

# Merge all datasets into one
ds_planet = xr.merge(datasets)
ds_planet.attrs["tile_id"] = fpath.parent.stem if planet_type == "orthotile" else fpath.stem
ds_planet.attrs = {"tile_id": fpath.parent.stem if planet_type == "orthotile" else fpath.stem}
logger.debug(f"Loaded Planet scene in {time.time() - start_time} seconds.")
return ds_planet

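For reference, a minimal self-contained sketch of the new band-splitting pattern, using synthetic data in place of a real Planet scene and plain xarray (the rioxarray write_nodata call is omitted to keep it dependency-light). Note that assign_attrs returns a new object rather than mutating in place, so the sketch writes through attrs.update:

import numpy as np
import xarray as xr

# Synthetic 4-band raster standing in for the opened Planet scene.
data = np.random.randint(0, 10000, size=(4, 8, 8)).astype("float32")
da = xr.DataArray(data, dims=("band", "y", "x"))

# Name the bands, then split each one into its own data variable.
bands = ["blue", "green", "red", "nir"]
ds = da.fillna(0).astype("uint16").assign_coords({"band": bands}).to_dataset(dim="band")

# assign_attrs returns a new object; attrs.update mutates in place.
for var in ds.data_vars:
    ds[var].attrs.update({"long_name": f"PLANET {str(var).capitalize()}", "units": "Reflectance"})

print(list(ds.data_vars))  # ['blue', 'green', 'red', 'nir']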
22 changes: 8 additions & 14 deletions darts-acquisition/src/darts_acquisition/s2.py
Expand Up @@ -41,21 +41,15 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
# Define band names and corresponding indices
s2_da = xr.open_dataarray(s2_image)

bands = {1: "blue", 2: "green", 3: "red", 4: "nir"}

# Create a list to hold datasets
datasets = [
s2_da.sel(band=index)
.assign_attrs({"data_source": "s2", "long_name": f"Sentinel 2 {name.capitalize()}", "units": "Reflectance"})
.fillna(0)
.rio.write_nodata(0)
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
for index, name in bands.items()
]
# Create a dataset with the bands
bands = ["blue", "green", "red", "nir"]
ds_s2 = s2_da.fillna(0).rio.write_nodata(0).astype("uint16").assign_coords({"band": bands}).to_dataset(dim="band")

for var in ds_s2.data_vars:
ds_s2[var].assign_attrs(
{"data_source": "s2", "long_name": f"Sentinel 2 {var.capitalize()}", "units": "Reflectance"}
)

ds_s2 = xr.merge(datasets)
planet_crop_id = fpath.stem
s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
ds_s2.attrs["planet_crop_id"] = planet_crop_id
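A subtle difference from the planet.py change above: this loop iterates ds_s2.data_vars (data variables only), whereas the Planet refactor iterates ds_planet.variables, which also yields coordinate variables such as x and y. A quick sketch of the distinction:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"blue": ("x", np.zeros(3, dtype="uint16")), "nir": ("x", np.zeros(3, dtype="uint16"))},
    coords={"x": [0, 1, 2]},
)

print(sorted(ds.data_vars))  # ['blue', 'nir'] -- data variables only
print(sorted(ds.variables))  # ['blue', 'nir', 'x'] -- coordinates included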
17 changes: 13 additions & 4 deletions darts/src/darts/legacy_pipeline/_base.py
Expand Up @@ -35,6 +35,7 @@ class _BasePipeline:
tcvis_model_name: str = "RTS_v6_tcvis_s2native.pt"
notcvis_model_name: str = "RTS_v6_notcvis_s2native.pt"
device: Literal["cuda", "cpu", "auto"] | int | None = None
dask_worker: int = min(16, mp.cpu_count() - 1) # noqa: RUF009
ee_project: str | None = None
ee_use_highvolume: bool = True
patch_size: int = 1024
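The new dask_worker field caps the default worker count. A minimal sketch of the same pattern in isolation (assuming the distributed package, which this commit adds as a dependency):

import multiprocessing as mp

from distributed import Client, LocalCluster

# Leave one core for the main process, but never spawn more than 16 workers.
n_workers = min(16, mp.cpu_count() - 1)

if __name__ == "__main__":
    with LocalCluster(n_workers=n_workers) as cluster, Client(cluster) as client:
        print(f"{n_workers} workers, dashboard at {client.dashboard_link}")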
@@ -78,13 +79,17 @@ def run(self):
         )
 
         # Init Dask stuff with a context manager
-        with LocalCluster(n_workers=mp.cpu_count() - 1) as cluster, Client(cluster) as client:
-            logger.info(f"Using Dask client: {client}")
+        with LocalCluster(n_workers=self.dask_worker) as cluster, Client(cluster) as client:
+            logger.info(f"Using Dask client: {client} on cluster {cluster}")
+            logger.info(f"Dashboard available at: {client.dashboard_link}")
             configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
             logger.info("Configured Rasterio with Dask")
 
             # Iterate over all the data (_path_generator)
-            for fpath, outpath in self._path_generator():
+            n_tiles = 0
+            paths = sorted(self._path_generator())
+            logger.info(f"Found {len(paths)} tiles to process.")
+            for i, (fpath, outpath) in enumerate(paths):
                 try:
                     aqdata = self._get_data(fpath)
                     tile = self._preprocess(aqdata)
@@ -111,12 +116,16 @@
                     writer.export_probabilities(outpath)
                     writer.export_binarized(outpath)
                     writer.export_polygonized(outpath)
+                    n_tiles += 1
+                    logger.info(f"Processed sample {i + 1} of {len(paths)} '{fpath.resolve()}'.")
                 except KeyboardInterrupt:
                     logger.warning("Keyboard interrupt detected.\nExiting...")
-                    break
+                    raise KeyboardInterrupt
                 except Exception as e:
                     logger.warning(f"Could not process folder '{fpath.resolve()}'.\nSkipping...")
                     logger.exception(e)
+            else:
+                logger.info(f"Processed {n_tiles} tiles to {self.output_data_dir.resolve()}.")
 
 
 # =============================================================================
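Two behavioral notes on the loop above: the else clause belongs to the for loop, and Python runs it only when the loop finishes without break; and because break was swapped for raise KeyboardInterrupt, an interrupt now propagates out of run() instead of quietly ending the loop, skipping the summary line. A minimal sketch of for/else semantics:

for item in ["a", "b", "c"]:
    print(f"processing {item}")
else:
    # Runs only if the loop was never exited via `break` (or an exception).
    print("all items processed")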
2 changes: 1 addition & 1 deletion darts/src/darts/legacy_pipeline/planet_fast.py
@@ -60,7 +60,7 @@ def _get_data(self, fpath: Path):
 
         optical = load_planet_scene(fpath)
         arcticdem = load_arcticdem_tile(
-            optical.odc.geobox, self.arcticdem_dir, resolution=10, buffer=ceil(self.tpi_outer_radius / 10 * sqrt(2))
+            optical.odc.geobox, self.arcticdem_dir, resolution=2, buffer=ceil(self.tpi_outer_radius / 2 * sqrt(2))
         )
         tcvis = load_tcvis(optical.odc.geobox, self.tcvis_dir)
         data_masks = load_planet_masks(fpath)
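The ArcticDEM is now loaded at 2 m resolution instead of 10 m, and the pixel buffer is rescaled to match: it converts the TPI outer radius from meters to pixels at the DEM resolution, with the sqrt(2) factor presumably padding for diagonal distances. A worked example with a hypothetical 100 m radius:

from math import ceil, sqrt

tpi_outer_radius = 100  # hypothetical radius in meters
resolution = 2          # meters per pixel, matching the new default

buffer = ceil(tpi_outer_radius / resolution * sqrt(2))
print(buffer)  # 71 pixels, versus 15 at the old 10 m resolution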
7 changes: 6 additions & 1 deletion darts/src/darts/utils/logging.py
Expand Up @@ -46,16 +46,21 @@ def add_logging_handlers(command: str, console: Console, log_dir: Path):
log_dir (Path): The directory to save the logs to.
"""
import distributed
import lightning as L # noqa: N812
import torch
import torch.utils.data
import xarray as xr

log_dir.mkdir(parents=True, exist_ok=True)
current_time = time.strftime("%Y-%m-%d_%H-%M-%S")

# Configure the rich console handler
rich_handler = RichHandler(
console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts, L, torch, torch.utils.data]
console=console,
rich_tracebacks=True,
tracebacks_suppress=[cyclopts, L, torch, torch.utils.data, xr, distributed],
tracebacks_show_locals=True,
)
rich_handler.setFormatter(
logging.Formatter(
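For reference, a minimal sketch of the rich handler options used here: tracebacks_suppress collapses frames from the listed modules in rendered tracebacks, and tracebacks_show_locals renders each frame's local variables (the suppressed module below is a stand-in):

import logging

from rich.console import Console
from rich.logging import RichHandler

handler = RichHandler(
    console=Console(),
    rich_tracebacks=True,
    tracebacks_suppress=[logging],  # frames from these modules are collapsed
    tracebacks_show_locals=True,    # show local variables in each frame
)
logging.basicConfig(level="INFO", handlers=[handler], format="%(message)s")
logging.getLogger(__name__).info("rich logging configured")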
1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
     "wandb>=0.18.7",
     "torchmetrics>=1.6.0",
     "seaborn>=0.13.2",
+    "distributed>=2024.12.0",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"
