From 4a074d9871f88e3cd6186c4bde764a4f5fdd927e Mon Sep 17 00:00:00 2001 From: Wei Ji Date: Thu, 13 Jun 2019 11:51:17 +0200 Subject: [PATCH] :ok_hand: Return of the gapfilled REMA training tiles Neural network wasn't training properly, and I tracked it down to the REMA input rasters having low NaN-like values... Found out the proper way to get dask DataArray masks using the dask.array.ma module, and so we can reintroduce gapfilling in data_prep.selective_tile, this time using dask/xarray to vectorize the operations. The gapfilled raster is also interpolated better along the edges as in 7fd33450cef6638534524101c360ccf5db98e6d9 which might help with the neural network training later. Quilt hash updated from 9c8cb530df6340e257e18008b59b9d7b5f701fd9e5cef2c8436984ae49cff237 to b0b090ca35271d41ea1cf5e6afa0e6c6a3da34193c00444963dde7ad20eb7331. Not passing in a gapfill_raster_filepath (when it is needed) now prints a warning along with nicer debugging plots that have EPSG:3031 projected coordinates on the axes! --- data_prep.ipynb | 74 ++++++++++++++++++++++++++++++++++--------------- data_prep.py | 53 +++++++++++++++++++++++++++-------- 2 files changed, 92 insertions(+), 35 deletions(-) diff --git a/data_prep.ipynb b/data_prep.ipynb index 4763f87..8306fb2 100644 --- a/data_prep.ipynb +++ b/data_prep.ipynb @@ -55,7 +55,6 @@ "import salem\n", "\n", "import dask\n", - "import dask.diagnostics\n", "import geopandas as gpd\n", "import pygmt as gmt\n", "import IPython.display\n", @@ -1211,8 +1210,9 @@ " for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax\n", " ]\n", "\n", + " # Retrieve tiles from the main raster\n", " with xr.open_rasterio(\n", - " filepath, chunks=None if out_shape is None else {}, cache=False\n", + " filepath, chunks=None if out_shape is None else {}\n", " ) as dataset:\n", " print(f\"Tiling: {filepath} ... 
\", end=\"\")\n", "\n", @@ -1231,21 +1231,50 @@ " )\n", " for da in daarray_list\n", " ]\n", - " daarray_stack = dask.array.stack(seq=daarray_list)\n", + " daarray_stack = dask.array.ma.masked_values(\n", + " x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals\n", + " )\n", "\n", " assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width)\n", " assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel)\n", " assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window)\n", " print(\"done!\")\n", "\n", - " with dask.diagnostics.ProgressBar(minimum=5.0):\n", - " try:\n", - " out_tiles = daarray_stack.compute().astype(dtype=np.float32)\n", - " assert not np.isnan(out_tiles).any() # check that there are no NAN values\n", - " except AssertionError:\n", - " raise NotImplementedError(\"gapfilling on dask xarray not yet implemented\")\n", - " finally:\n", - " return out_tiles" + " out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32)\n", + " mask = dask.array.ma.getmaskarray(daarray_stack).compute()\n", + "\n", + " # Gapfill main raster if there are blank spaces\n", + " if mask.any(): # check that there are no NAN values\n", + " nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel()\n", + "\n", + " # Replace pixels from another raster if available, else raise error\n", + " if gapfill_raster_filepath is not None:\n", + " with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2:\n", + " daarray_list2 = [\n", + " dataset2.interp_like(daarray_list[idx].squeeze(), method=\"linear\")\n", + " for idx in nan_grid_indexes\n", + " ]\n", + " daarray_stack2 = dask.array.ma.masked_values(\n", + " x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals\n", + " )\n", + "\n", + " fill_tiles = (\n", + " dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32)\n", + " )\n", + " mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute()\n", + 
"\n", + " for i, array2 in enumerate(fill_tiles):\n", + " idx = nan_grid_indexes[i]\n", + " np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx])\n", + " assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill\n", + "\n", + " else:\n", + " for i in nan_grid_indexes:\n", + " daarray_list[i].plot()\n", + " plt.show()\n", + " print(f\"WARN: Tiles have missing data, try pass in gapfill_raster_filepath\")\n", + "\n", + " return out_tiles" ] }, { @@ -1353,7 +1382,7 @@ " filepath=\"misc/REMA_100m_dem.tif\",\n", " window_bounds=window_bounds_concat,\n", " padding=1000,\n", - " # gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n", + " gapfill_raster_filepath=\"misc/REMA_200m_dem_filled.tif\",\n", ")\n", "print(rema.shape, rema.dtype)" ] @@ -1390,7 +1419,6 @@ "output_type": "stream", "text": [ "Tiling: misc/MEaSUREs_IceFlowSpeed_450m.tif ... done!\n", - "[########################################] | 100% Completed | 22.1s\n", "(2347, 1, 20, 20) float32\n" ] } @@ -1485,7 +1513,7 @@ "name": "stdin", "output_type": "stream", "text": [ - "Enter the code from the webpage: eyJjb2RlIjogIjg4ODljZTY0LTA1ODMtNGIxYS04YjE2LTQ0MjFjZDViMTQxNCIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n" + "Enter the code from the webpage: eyJjb2RlIjogIjg0OTA5ODJlLTM0NWYtNDljNC04Y2Q0LTUwY2FlMjhiOWNlZSIsICJpZCI6ICIyOWI4YzUyNS1lZmM1LTQ5NTItOGQ4Yy03NzQyYTg1YmI1MmEifQ==\n" ] } ], @@ -1557,32 +1585,32 @@ "name": "stderr", "output_type": "stream", "text": [ - " 96%|█████████▌| 6.47G/6.74G [00:01<04:17, 1.04MB/s] " + " 94%|█████████▍| 6.35G/6.74G [00:01<01:40, 3.91MB/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n", - "Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n", - "Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n", - 
"Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n", + "Fragment 2b994ae9d13f6c01ce00c426f52c6dce0c4681f8c8aaf8a96608fd3d62f3a269 already uploaded; skipping.\n", "Fragment 28e2ca7656d61b0bc7f8f8c1db41914023e0cab1634e0ee645f38a87d894b416 already uploaded; skipping.\n", + "Fragment 1f66fe557ce079c063597f0b04d15862f67af2c9dd4f286801851e0c71f0e869 already uploaded; skipping.\n", + "Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n", "Fragment 6ef3a2439a508de0919bd33a713976b5aa4895929a9d7981c09f722ce702e16a already uploaded; skipping.\n", "Fragment 80c9fa41ccc69be1d2cd4a367d56168321d1079e7260a1996089810db25172f6 already uploaded; skipping.\n", "Fragment ca9c41a8dd56097e40865d2e65c65d299c22fc17608ddb6c604c532a69936307 already uploaded; skipping.\n", - "Fragment 04a52d9a52901d8f7f74fd9ef6fc9fc215d6c9d787540511f68630f5cca16094 already uploaded; skipping.\n", - "Fragment f1f660d1287225c30b8b2cbf2a727283d807a1ee443153519cbf407a08937965 already uploaded; skipping.\n", "Fragment f750893861a1a268c8ffe0ba7db36c933223bbf5fcbb786ecef3f052b20f9b8a already uploaded; skipping.\n", - "Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n" + "Fragment c665815f043b87cfe94d51caabd1b57d8f6f6773d632503de6db0725f20d391c already uploaded; skipping.\n", + "Fragment 16ed97cce049cd2859a379964a8fa7575d9b871ec126d33c824542b126eab177 already uploaded; skipping.\n", + "Fragment e6b139801bf4541f1e4989a8aa8b26ab37eca81bb5eaffa8028b744782455db0 already uploaded; skipping.\n", + "Fragment 4a4efc3a84204c3d67887e8d7fa1186467b51e696451f2832ebbea3ca491c8a8 already uploaded; skipping.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 6.74G/6.74G [00:03<00:00, 1.77GB/s]\n" + "100%|██████████| 6.74G/6.74G [00:09<00:00, 688MB/s] \n" ] }, { diff --git a/data_prep.py b/data_prep.py index 20034c9..cb9bdf8 100644 --- 
a/data_prep.py +++ b/data_prep.py @@ -39,7 +39,6 @@ import salem import dask -import dask.diagnostics import geopandas as gpd import pygmt as gmt import IPython.display @@ -621,8 +620,9 @@ def selective_tile( for x0, y0, x1, y1 in window_bounds # xmin, ymin, xmax, ymax ] + # Retrieve tiles from the main raster with xr.open_rasterio( - filepath, chunks=None if out_shape is None else {}, cache=False + filepath, chunks=None if out_shape is None else {} ) as dataset: print(f"Tiling: {filepath} ... ", end="") @@ -641,21 +641,50 @@ def selective_tile( ) for da in daarray_list ] - daarray_stack = dask.array.stack(seq=daarray_list) + daarray_stack = dask.array.ma.masked_values( + x=dask.array.stack(seq=daarray_list), value=dataset.nodatavals + ) assert daarray_stack.ndim == 4 # check that shape is like (m, 1, height, width) assert daarray_stack.shape[1] == 1 # channel-first (assuming only 1 channel) assert not 0 in daarray_stack.shape # ensure no empty dimensions (bad window) print("done!") - with dask.diagnostics.ProgressBar(minimum=5.0): - try: - out_tiles = daarray_stack.compute().astype(dtype=np.float32) - assert not np.isnan(out_tiles).any() # check that there are no NAN values - except AssertionError: - raise NotImplementedError("gapfilling on dask xarray not yet implemented") - finally: - return out_tiles + out_tiles = dask.array.ma.getdata(daarray_stack).compute().astype(dtype=np.float32) + mask = dask.array.ma.getmaskarray(daarray_stack).compute() + + # Gapfill main raster if there are blank spaces + if mask.any(): # check that there are no NAN values + nan_grid_indexes = np.argwhere(mask.any(axis=(-3, -2, -1))).ravel() + + # Replace pixels from another raster if available, else raise error + if gapfill_raster_filepath is not None: + with xr.open_rasterio(gapfill_raster_filepath, chunks={}) as dataset2: + daarray_list2 = [ + dataset2.interp_like(daarray_list[idx].squeeze(), method="linear") + for idx in nan_grid_indexes + ] + daarray_stack2 = 
dask.array.ma.masked_values( + x=dask.array.stack(seq=daarray_list2), value=dataset2.nodatavals + ) + + fill_tiles = ( + dask.array.ma.getdata(daarray_stack2).compute().astype(dtype=np.float32) + ) + mask2 = dask.array.ma.getmaskarray(daarray_stack2).compute() + + for i, array2 in enumerate(fill_tiles): + idx = nan_grid_indexes[i] + np.copyto(dst=out_tiles[idx], src=array2, where=mask[idx]) + assert not (mask[idx] & mask2[i]).any() # Ensure no NANs after gapfill + + else: + for i in nan_grid_indexes: + daarray_list[i].plot() + plt.show() + print(f"WARN: Tiles have missing data, try pass in gapfill_raster_filepath") + + return out_tiles # %% @@ -695,7 +724,7 @@ def selective_tile( filepath="misc/REMA_100m_dem.tif", window_bounds=window_bounds_concat, padding=1000, - # gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif", + gapfill_raster_filepath="misc/REMA_200m_dem_filled.tif", ) print(rema.shape, rema.dtype)