diff --git a/satpy/tests/writer_tests/test_geotiff.py b/satpy/tests/writer_tests/test_geotiff.py index 8ca886e74d..7d8570c5bb 100644 --- a/satpy/tests/writer_tests/test_geotiff.py +++ b/satpy/tests/writer_tests/test_geotiff.py @@ -100,7 +100,7 @@ def test_float_write(self): datasets = self._get_test_datasets() w = GeoTIFFWriter() w.save_datasets(datasets, - floating_point=True, + dtype=np.float32, enhancement_config=False, base_dir=self.base_dir) diff --git a/satpy/writers/geotiff.py b/satpy/writers/geotiff.py index 70e84a9e16..cc805190ec 100644 --- a/satpy/writers/geotiff.py +++ b/satpy/writers/geotiff.py @@ -35,6 +35,20 @@ LOG = logging.getLogger(__name__) +# Map numpy data types to GDAL data types +NP2GDAL = { + np.float32: gdal.GDT_Float32, + np.float64: gdal.GDT_Float64, + np.uint8: gdal.GDT_Byte, + np.uint16: gdal.GDT_UInt16, + np.uint32: gdal.GDT_UInt32, + np.int16: gdal.GDT_Int16, + np.int32: gdal.GDT_Int32, + np.complex64: gdal.GDT_CFloat32, + np.complex128: gdal.GDT_CFloat64, +} + + class GeoTIFFWriter(ImageWriter): """Writer to save GeoTIFF images. @@ -45,8 +59,8 @@ class GeoTIFFWriter(ImageWriter): Un-enhanced float geotiff with NaN for fill values: - scn.save_datasets(writer='geotiff', floating_point=True, - enhancement_config=False, fill_value=np.nan) + scn.save_datasets(writer='geotiff', dtype=np.float32, + enhancement_config=False) """ @@ -73,14 +87,12 @@ class GeoTIFFWriter(ImageWriter): "pixeltype", "copy_src_overviews", ) - def __init__(self, floating_point=False, tags=None, **kwargs): + def __init__(self, dtype=None, tags=None, **kwargs): ImageWriter.__init__(self, default_config_filename="writers/geotiff.yaml", **kwargs) - self.floating_point = bool(self.info.get( - "floating_point", None) if floating_point is None else - floating_point) + self.dtype = self.info.get("dtype") if dtype is None else dtype self.tags = self.info.get("tags", None) if tags is None else tags if self.tags is None: @@ -100,97 +112,77 @@ def separate_init_kwargs(cls, kwargs): # FUTURE: Don't pass Scene.save_datasets kwargs to init and here init_kwargs, kwargs = super(GeoTIFFWriter, cls).separate_init_kwargs( kwargs) - for kw in ['floating_point', 'tags']: + for kw in ['dtype', 'tags']: if kw in kwargs: init_kwargs[kw] = kwargs.pop(kw) return init_kwargs, kwargs - def _gdal_write_datasets(self, dst_ds, datasets, opacity): - """Write *datasets* in a gdal raster structure *dts_ds*, using - *opacity* as alpha value for valid data, and *fill_value*. - """ - def _write_array(bnd, chn): - bnd.WriteArray(chn.values) - - # queue up data writes so we don't waste computation time - delayed = [] + def _gdal_write_datasets(self, dst_ds, datasets): + """Write datasets in a gdal raster structure dts_ds""" for i, band in enumerate(datasets['bands']): chn = datasets.sel(bands=band) bnd = dst_ds.GetRasterBand(i + 1) bnd.SetNoDataValue(0) - delay = dask.delayed(_write_array)(bnd, chn) - delayed.append(delay) - dask.compute(*delayed) - - def _create_file(self, filename, img, gformat, g_opts, opacity, - datasets, mode): - raster = gdal.GetDriverByName("GTiff") - - if mode == "L": - dst_ds = raster.Create(filename, img.width, img.height, 1, - gformat, g_opts) - self._gdal_write_datasets(dst_ds, datasets, opacity) - elif mode == "LA": - g_opts.append("ALPHA=YES") - dst_ds = raster.Create(filename, img.width, img.height, 2, gformat, - g_opts) - self._gdal_write_datasets(dst_ds, datasets, datasets) - elif mode == "RGB": - dst_ds = raster.Create(filename, img.width, img.height, 3, - gformat, g_opts) - self._gdal_write_datasets(dst_ds, datasets, datasets) - - elif mode == "RGBA": - g_opts.append("ALPHA=YES") - dst_ds = raster.Create(filename, img.width, img.height, 4, gformat, - g_opts) - - self._gdal_write_datasets(dst_ds, datasets, datasets) - else: - raise NotImplementedError( - "Saving to GeoTIFF using image mode %s is not implemented." % - mode) - - # Create raster GeoTransform based on upper left corner and pixel - # resolution ... if not overwritten by argument geotransform. - if "area" not in img.data.attrs: - LOG.warning("No 'area' metadata found in image") - else: - area = img.data.attrs["area"] + bnd.WriteArray(chn.values) + + def _gdal_write_geo(self, dst_ds, area): + try: + geotransform = [area.area_extent[0], area.pixel_size_x, 0, + area.area_extent[3], 0, -area.pixel_size_y] + dst_ds.SetGeoTransform(geotransform) + srs = osr.SpatialReference() + + srs.ImportFromProj4(area.proj4_string) + srs.SetProjCS(area.proj_id) + try: + srs.SetWellKnownGeogCS(area.proj_dict['ellps']) + except KeyError: + pass try: - geotransform = [area.area_extent[0], area.pixel_size_x, 0, - area.area_extent[3], 0, -area.pixel_size_y] - dst_ds.SetGeoTransform(geotransform) - srs = osr.SpatialReference() - - srs.ImportFromProj4(area.proj4_string) - srs.SetProjCS(area.proj_id) - try: - srs.SetWellKnownGeogCS(area.proj_dict['ellps']) - except KeyError: - pass - try: - # Check for epsg code. - srs.ImportFromEPSG(int( - area.proj_dict['init'].lower().split('epsg:')[1])) - except (KeyError, IndexError): - pass - srs = srs.ExportToWkt() - dst_ds.SetProjection(srs) - except AttributeError: - LOG.warning( - "Can't save geographic information to geotiff, unsupported area type") - - tags = self.tags.copy() - if "start_time" in img.data.attrs: - tags.update({'TIFFTAG_DATETIME': img.data.attrs["start_time"].strftime( - "%Y:%m:%d %H:%M:%S")}) - - dst_ds.SetMetadata(tags, '') - - def save_image(self, img, filename=None, floating_point=None, - compute=True, **kwargs): + # Check for epsg code. + srs.ImportFromEPSG(int( + area.proj_dict['init'].lower().split('epsg:')[1])) + except (KeyError, IndexError): + pass + srs = srs.ExportToWkt() + dst_ds.SetProjection(srs) + except AttributeError: + LOG.warning( + "Can't save geographic information to geotiff, unsupported area type") + + def _create_file(self, filename, img, gformat, g_opts, datasets, mode): + num_bands = len(mode) + if mode[-1] == 'A': + g_opts.append("ALPHA=YES") + + def _delayed_create(create_opts, datasets, area, start_time, tags): + raster = gdal.GetDriverByName("GTiff") + dst_ds = raster.Create(*create_opts) + self._gdal_write_datasets(dst_ds, datasets) + + # Create raster GeoTransform based on upper left corner and pixel + # resolution ... if not overwritten by argument geotransform. + if "area" is None: + LOG.warning("No 'area' metadata found in image") + else: + self._gdal_write_geo(dst_ds, area) + + if start_time is not None: + tags.update({'TIFFTAG_DATETIME': start_time.strftime( + "%Y:%m:%d %H:%M:%S")}) + + dst_ds.SetMetadata(tags, '') + + create_opts = (filename, img.width, img.height, num_bands, gformat, g_opts) + delayed = dask.delayed(_delayed_create)( + create_opts, datasets, img.data.attrs.get('area'), + img.data.attrs.get('start_time'), + self.tags.copy()) + return delayed + + def save_image(self, img, filename=None, dtype=None, fill_value=None, + floating_point=None, compute=True, **kwargs): """Save the image to the given *filename* in geotiff_ format. `floating_point` allows the saving of 'L' mode images in floating point format if set to True. @@ -205,43 +197,51 @@ def save_image(self, img, filename=None, floating_point=None, if k in self.GDAL_OPTIONS: gdal_options[k] = kwargs[k] - floating_point = floating_point if floating_point is not None else self.floating_point + if floating_point is not None: + import warnings + warnings.warn("'floating_point' is deprecated, use" + "'dtype=np.float64' instead.", + DeprecationWarning) + dtype = np.float64 + dtype = dtype if dtype is not None else self.dtype + if dtype is None: + dtype = np.uint8 if "alpha" in kwargs: raise ValueError( - "Keyword 'alpha' is automatically set and should not be specified") - if floating_point: + "Keyword 'alpha' is automatically set based on 'fill_value' " + "and should not be specified") + if np.issubdtype(dtype, np.floating): if img.mode != "L": raise ValueError("Image must be in 'L' mode for floating " "point geotiff saving") - fill_value = np.nan - datasets, mode = img._finalize(fill_value=fill_value, - dtype=np.float64) - gformat = gdal.GDT_Float64 - opacity = 0 - else: - nbits = int(gdal_options.get("nbits", "8")) - if nbits > 16: - dtype = np.uint32 - gformat = gdal.GDT_UInt32 - elif nbits > 8: - dtype = np.uint16 - gformat = gdal.GDT_UInt16 - else: - dtype = np.uint8 - gformat = gdal.GDT_Byte - opacity = np.iinfo(dtype).max - datasets, mode = img._finalize(dtype=dtype) - - LOG.debug("Saving to GeoTiff: %s", filename) - - g_opts = ["{0}={1}".format(k.upper(), str(v)) - for k, v in gdal_options.items()] - - ensure_dir(filename) - delayed = dask.delayed(self._create_file)(filename, img, gformat, - g_opts, opacity, datasets, - mode) - if compute: - return delayed.compute() - return delayed + if fill_value is None: + LOG.debug("Alpha band not supported for float geotiffs, " + "setting fill value to 'NaN'") + fill_value = np.nan + + try: + import rasterio # noqa + # we can use the faster rasterio-based save + return img.save(filename, fformat='tif', fill_value=fill_value, + dtype=dtype, compute=compute, **gdal_options) + except ImportError: + LOG.warning("Using legacy/slower geotiff save method, install " + "'rasterio' for faster saving.") + # force to numpy dtype object + dtype = np.dtype(dtype) + gformat = NP2GDAL[dtype.type] + + gdal_options['nbits'] = int(gdal_options.get('nbits', + dtype.itemsize * 8)) + datasets, mode = img._finalize(fill_value=fill_value, dtype=dtype) + LOG.debug("Saving to GeoTiff: %s", filename) + g_opts = ["{0}={1}".format(k.upper(), str(v)) + for k, v in gdal_options.items()] + + ensure_dir(filename) + delayed = self._create_file(filename, img, gformat, g_opts, + datasets, mode) + if compute: + return delayed.compute() + return delayed