diff --git a/data_prep.ipynb b/data_prep.ipynb index 4763f87..8306fb2 100644 --- a/data_prep.ipynb +++ b/data_prep.ipynb @@ -55,7 +55,6 @@ "import salem\n", "\n", "import dask\n", - "import dask.diagnostics\n", "import geopandas as gpd\n", "import pygmt as gmt\n", "import IPython.display\n", @@ -1211,8 +1210,9 @@ " for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax\n", " ]\n", "\n", + " # Retrieve tiles from the main raster\n", " with xr.open_rasterio(\n", - " filepath, chunks=None if out_shape is None else {}, cache=False\n", + " filepath, chunks=None if out_shape is None else {}\n", " ) as dataset:\n", " print(f\"Tiling: {filepath} ... \", end=\"\")\n", "\n", @@ -1231,21 +1231,50 @@ " )\n", " for da in daarray_list\n", " ]\n", - " daarray_stack = dask.array.stack(seq=daarray_list)\n", + " daarray_stack = dask.array.ma.masked_values(\n", + " x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals\n", + " )\n", "\n", " assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n", " assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n", " assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n", " print(\"done!\")\n", "\n", - " with dask.diagnostics.ProgressBar(minimum=5.0):\n", - " try:\n", - " out_tiles = daarray_stack.compute().astype(dtype=np.float32)\n", - " assert not np.isnan(out_tiles).any() # check that there are no NAN values\n", - " except AssertionError:\n", - " raise NotImplementedError(\"gapfilling on dask xarray not yet implemented\")\n", - " finally:\n", - " return out_tiles" + " out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)\n", + " mask = dask.array.ma.getmaskarray(daarray_stack).compute()\n", + "\n", + " # Gapfill main raster if there are blank spaces\n", + " if mask.any(): # check that there are no NAN values\n", + " nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel()\n", + "\n", + " # Replace pixels from another raster if available, else raise error\n", + " if gapfill_raster_filepath is not None:\n", + " with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:\n", + " daarray_list2 = [\n", + " dataset2.interp_like(daarray_list[idx].squeeze(), method=\"linear\")\n", + " for idx in nan_grid_indexes\n", + " ]\n", + " daarray_stack2 = dask.array.ma.masked_values(\n", + " x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals\n", + " )\n", + "\n", + " fill_tiles = (\n", + " dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32)\n", + " )\n", + " mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute()\n", + "\n", + " for i, array2 in enumerate(fill_tiles):\n", + " idx = nan_grid_indexes[i]\n", + " np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])\n", + " assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill\n", + "\n", + " else:\n", + " for i in nan_grid_indexes:\n", + " daarray_list[i].plot()\n", + " plt.show()\n", + " print(f\"WARN: Tiles have missing data, try pass in gapfill_raster_filepath\")\n", + "\n", + " return out_tiles" ] }, { @@ -1353,7 +1382,7 @@ " filepath=\"misc/REMA_100m_dem.tif\",\n", " window_bounds=window_bounds_concat,\n", " padding=1000,\n", - " # gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n", + " gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n", ")\n", "print(rema.shape, rema.dtype)" ] @@ -1390,7 +1419,6 @@ "output_type": "stream", "text": [ "Tiling: misc/MEaSUREs_IceFlowSpeed_450m.tif ... done!\n", - "[########################################] | 100% Completed | 22.1s\n", "(2347, 1, 20, 20) float32\n" ] } @@ -1485,7 +1513,7 @@ "name": "stdin", "output_type": "stream", "text": [ - "Enter the code from the webpage: eyJjb2RlIjogIjg4ODljZTY0LTA1ODMtNGIxYS04YjE2LTQ0MjFjZDViMTQxNCIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n" + "Enter the code from the webpage: eyJjb2RlIjogIjg0OTA5ODJlLTM0NWYtNDljNC04Y2Q0LTUwY2FlMjhiOWNlZSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n" ] } ], @@ -1557,32 +1585,32 @@ "name": "stderr", "output_type": "stream", "text": [ - " 96%|█████████▌| 6.47G/6.74G [00:01<04:17, 1.04MB/s] " + " 94%|█████████▍| 6.35G/6.74G [00:01<01:40, 3.91MB/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n", - "Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n", - "Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n", - "Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n", + "Fragment 2b994ae9d13f6c01ce00c426f52c6dce0c4681f8c8aaf8a96608fd3d62f3a269 already uploaded; skipping.\n", "Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n", + "Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n", + "Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n", "Fragment 6ef3a2439a508de0919bd33a713976b5aa4895929a9d7981c09f722ce702e16a already uploaded; skipping.\n", "Fragment 80c9fa41ccc69be1d2cd4a367d56168321d1079e7260a1996089810db25172f6 already uploaded; skipping.\n", "Fragment ca9c41a8dd56097e40865d2e65c65d299c22fc17608ddb6c604c532a69936307 already uploaded; skipping.\n", - "Fragment 04a52d9a52901d8f7f74fd9ef6fc9fc215d6c9d787540511f68630f5cca16094 already uploaded; skipping.\n", - "Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n", "Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n", - "Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n" + "Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n", + "Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n", + "Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n", + "Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 6.74G/6.74G [00:03<00:00, 1.77GB/s]\n" + "100%|██████████| 6.74G/6.74G [00:09<00:00, 688MB/s] \n" ] }, { diff --git a/data_prep.py b/data_prep.py index 20034c9..cb9bdf8 100644 --- a/data_prep.py +++ b/data_prep.py @@ -39,7 +39,6 @@ import salem import dask -import dask.diagnostics import geopandas as gpd import pygmt as gmt import IPython.display @@ -621,8 +620,9 @@ def selective_tile( for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax ] + # Retrieve tiles from the main raster with xr.open_rasterio( - filepath, chunks=None if out_shape is None else {}, cache=False + filepath, chunks=None if out_shape is None else {} ) as dataset: print(f"Tiling: {filepath} ... ", end="") @@ -641,21 +641,50 @@ def selective_tile( ) for da in daarray_list ] - daarray_stack = dask.array.stack(seq=daarray_list) + daarray_stack = dask.array.ma.masked_values( + x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals + ) assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width) assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel) assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window) print("done!") - with dask.diagnostics.ProgressBar(minimum=5.0): - try: - out_tiles = daarray_stack.compute().astype(dtype=np.float32) - assert not np.isnan(out_tiles).any() # check that there are no NAN values - except AssertionError: - raise NotImplementedError("gapfilling on dask xarray not yet implemented") - finally: - return out_tiles + out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32) + mask = dask.array.ma.getmaskarray(daarray_stack).compute() + + # Gapfill main raster if there are blank spaces + if mask.any(): # check that there are no NAN values + nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel() + + # Replace pixels from another raster if available, else raise error + if gapfill_raster_filepath is not None: + with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2: + daarray_list2 = [ + dataset2.interp_like(daarray_list[idx].squeeze(), method="linear") + for idx in nan_grid_indexes + ] + daarray_stack2 = dask.array.ma.masked_values( + x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals + ) + + fill_tiles = ( + dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32) + ) + mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute() + + for i, array2 in enumerate(fill_tiles): + idx = nan_grid_indexes[i] + np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx]) + assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill + + else: + for i in nan_grid_indexes: + daarray_list[i].plot() + plt.show() + print(f"WARN: Tiles have missing data, try pass in gapfill_raster_filepath") + + return out_tiles # %% @@ -695,7 +724,7 @@ def selective_tile( filepath="misc/REMA_100m_dem.tif", window_bounds=window_bounds_concat, padding=1000, - # gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif", + gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif", ) print(rema.shape, rema.dtype)