Skip to content

Commit

Permalink
Merge pull request #252 from pytroll/feature-geotiff-refactor
Browse files Browse the repository at this point in the history
Use rasterio to save geotiffs when available
  • Loading branch information
djhoese authored Apr 16, 2018
2 parents 1fe7e8c + 4e7a699 commit 2d1b12c
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 122 deletions.
2 changes: 1 addition & 1 deletion satpy/tests/writer_tests/test_geotiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_float_write(self):
datasets = self._get_test_datasets()
w = GeoTIFFWriter()
w.save_datasets(datasets,
floating_point=True,
dtype=np.float32,
enhancement_config=False,
base_dir=self.base_dir)

Expand Down
242 changes: 121 additions & 121 deletions satpy/writers/geotiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
LOG = logging.getLogger(__name__)


# Map numpy data types to GDAL data types
NP2GDAL = {
np.float32: gdal.GDT_Float32,
np.float64: gdal.GDT_Float64,
np.uint8: gdal.GDT_Byte,
np.uint16: gdal.GDT_UInt16,
np.uint32: gdal.GDT_UInt32,
np.int16: gdal.GDT_Int16,
np.int32: gdal.GDT_Int32,
np.complex64: gdal.GDT_CFloat32,
np.complex128: gdal.GDT_CFloat64,
}


class GeoTIFFWriter(ImageWriter):

"""Writer to save GeoTIFF images.
Expand All @@ -45,8 +59,8 @@ class GeoTIFFWriter(ImageWriter):
Un-enhanced float geotiff with NaN for fill values:
scn.save_datasets(writer='geotiff', floating_point=True,
enhancement_config=False, fill_value=np.nan)
scn.save_datasets(writer='geotiff', dtype=np.float32,
enhancement_config=False)
"""

Expand All @@ -73,14 +87,12 @@ class GeoTIFFWriter(ImageWriter):
"pixeltype",
"copy_src_overviews", )

def __init__(self, floating_point=False, tags=None, **kwargs):
def __init__(self, dtype=None, tags=None, **kwargs):
ImageWriter.__init__(self,
default_config_filename="writers/geotiff.yaml",
**kwargs)

self.floating_point = bool(self.info.get(
"floating_point", None) if floating_point is None else
floating_point)
self.dtype = self.info.get("dtype") if dtype is None else dtype
self.tags = self.info.get("tags",
None) if tags is None else tags
if self.tags is None:
Expand All @@ -100,97 +112,77 @@ def separate_init_kwargs(cls, kwargs):
# FUTURE: Don't pass Scene.save_datasets kwargs to init and here
init_kwargs, kwargs = super(GeoTIFFWriter, cls).separate_init_kwargs(
kwargs)
for kw in ['floating_point', 'tags']:
for kw in ['dtype', 'tags']:
if kw in kwargs:
init_kwargs[kw] = kwargs.pop(kw)

return init_kwargs, kwargs

def _gdal_write_datasets(self, dst_ds, datasets, opacity):
"""Write *datasets* in a gdal raster structure *dts_ds*, using
*opacity* as alpha value for valid data, and *fill_value*.
"""
def _write_array(bnd, chn):
bnd.WriteArray(chn.values)

# queue up data writes so we don't waste computation time
delayed = []
def _gdal_write_datasets(self, dst_ds, datasets):
"""Write datasets in a gdal raster structure dts_ds"""
for i, band in enumerate(datasets['bands']):
chn = datasets.sel(bands=band)
bnd = dst_ds.GetRasterBand(i + 1)
bnd.SetNoDataValue(0)
delay = dask.delayed(_write_array)(bnd, chn)
delayed.append(delay)
dask.compute(*delayed)

def _create_file(self, filename, img, gformat, g_opts, opacity,
datasets, mode):
raster = gdal.GetDriverByName("GTiff")

if mode == "L":
dst_ds = raster.Create(filename, img.width, img.height, 1,
gformat, g_opts)
self._gdal_write_datasets(dst_ds, datasets, opacity)
elif mode == "LA":
g_opts.append("ALPHA=YES")
dst_ds = raster.Create(filename, img.width, img.height, 2, gformat,
g_opts)
self._gdal_write_datasets(dst_ds, datasets, datasets)
elif mode == "RGB":
dst_ds = raster.Create(filename, img.width, img.height, 3,
gformat, g_opts)
self._gdal_write_datasets(dst_ds, datasets, datasets)

elif mode == "RGBA":
g_opts.append("ALPHA=YES")
dst_ds = raster.Create(filename, img.width, img.height, 4, gformat,
g_opts)

self._gdal_write_datasets(dst_ds, datasets, datasets)
else:
raise NotImplementedError(
"Saving to GeoTIFF using image mode %s is not implemented." %
mode)

# Create raster GeoTransform based on upper left corner and pixel
# resolution ... if not overwritten by argument geotransform.
if "area" not in img.data.attrs:
LOG.warning("No 'area' metadata found in image")
else:
area = img.data.attrs["area"]
bnd.WriteArray(chn.values)

def _gdal_write_geo(self, dst_ds, area):
try:
geotransform = [area.area_extent[0], area.pixel_size_x, 0,
area.area_extent[3], 0, -area.pixel_size_y]
dst_ds.SetGeoTransform(geotransform)
srs = osr.SpatialReference()

srs.ImportFromProj4(area.proj4_string)
srs.SetProjCS(area.proj_id)
try:
srs.SetWellKnownGeogCS(area.proj_dict['ellps'])
except KeyError:
pass
try:
geotransform = [area.area_extent[0], area.pixel_size_x, 0,
area.area_extent[3], 0, -area.pixel_size_y]
dst_ds.SetGeoTransform(geotransform)
srs = osr.SpatialReference()

srs.ImportFromProj4(area.proj4_string)
srs.SetProjCS(area.proj_id)
try:
srs.SetWellKnownGeogCS(area.proj_dict['ellps'])
except KeyError:
pass
try:
# Check for epsg code.
srs.ImportFromEPSG(int(
area.proj_dict['init'].lower().split('epsg:')[1]))
except (KeyError, IndexError):
pass
srs = srs.ExportToWkt()
dst_ds.SetProjection(srs)
except AttributeError:
LOG.warning(
"Can't save geographic information to geotiff, unsupported area type")

tags = self.tags.copy()
if "start_time" in img.data.attrs:
tags.update({'TIFFTAG_DATETIME': img.data.attrs["start_time"].strftime(
"%Y:%m:%d %H:%M:%S")})

dst_ds.SetMetadata(tags, '')

def save_image(self, img, filename=None, floating_point=None,
compute=True, **kwargs):
# Check for epsg code.
srs.ImportFromEPSG(int(
area.proj_dict['init'].lower().split('epsg:')[1]))
except (KeyError, IndexError):
pass
srs = srs.ExportToWkt()
dst_ds.SetProjection(srs)
except AttributeError:
LOG.warning(
"Can't save geographic information to geotiff, unsupported area type")

def _create_file(self, filename, img, gformat, g_opts, datasets, mode):
num_bands = len(mode)
if mode[-1] == 'A':
g_opts.append("ALPHA=YES")

def _delayed_create(create_opts, datasets, area, start_time, tags):
raster = gdal.GetDriverByName("GTiff")
dst_ds = raster.Create(*create_opts)
self._gdal_write_datasets(dst_ds, datasets)

# Create raster GeoTransform based on upper left corner and pixel
# resolution ... if not overwritten by argument geotransform.
if "area" is None:
LOG.warning("No 'area' metadata found in image")
else:
self._gdal_write_geo(dst_ds, area)

if start_time is not None:
tags.update({'TIFFTAG_DATETIME': start_time.strftime(
"%Y:%m:%d %H:%M:%S")})

dst_ds.SetMetadata(tags, '')

create_opts = (filename, img.width, img.height, num_bands, gformat, g_opts)
delayed = dask.delayed(_delayed_create)(
create_opts, datasets, img.data.attrs.get('area'),
img.data.attrs.get('start_time'),
self.tags.copy())
return delayed

def save_image(self, img, filename=None, dtype=None, fill_value=None,
floating_point=None, compute=True, **kwargs):
"""Save the image to the given *filename* in geotiff_ format.
`floating_point` allows the saving of
'L' mode images in floating point format if set to True.
Expand All @@ -205,43 +197,51 @@ def save_image(self, img, filename=None, floating_point=None,
if k in self.GDAL_OPTIONS:
gdal_options[k] = kwargs[k]

floating_point = floating_point if floating_point is not None else self.floating_point
if floating_point is not None:
import warnings
warnings.warn("'floating_point' is deprecated, use"
"'dtype=np.float64' instead.",
DeprecationWarning)
dtype = np.float64
dtype = dtype if dtype is not None else self.dtype
if dtype is None:
dtype = np.uint8

if "alpha" in kwargs:
raise ValueError(
"Keyword 'alpha' is automatically set and should not be specified")
if floating_point:
"Keyword 'alpha' is automatically set based on 'fill_value' "
"and should not be specified")
if np.issubdtype(dtype, np.floating):
if img.mode != "L":
raise ValueError("Image must be in 'L' mode for floating "
"point geotiff saving")
fill_value = np.nan
datasets, mode = img._finalize(fill_value=fill_value,
dtype=np.float64)
gformat = gdal.GDT_Float64
opacity = 0
else:
nbits = int(gdal_options.get("nbits", "8"))
if nbits > 16:
dtype = np.uint32
gformat = gdal.GDT_UInt32
elif nbits > 8:
dtype = np.uint16
gformat = gdal.GDT_UInt16
else:
dtype = np.uint8
gformat = gdal.GDT_Byte
opacity = np.iinfo(dtype).max
datasets, mode = img._finalize(dtype=dtype)

LOG.debug("Saving to GeoTiff: %s", filename)

g_opts = ["{0}={1}".format(k.upper(), str(v))
for k, v in gdal_options.items()]

ensure_dir(filename)
delayed = dask.delayed(self._create_file)(filename, img, gformat,
g_opts, opacity, datasets,
mode)
if compute:
return delayed.compute()
return delayed
if fill_value is None:
LOG.debug("Alpha band not supported for float geotiffs, "
"setting fill value to 'NaN'")
fill_value = np.nan

try:
import rasterio # noqa
# we can use the faster rasterio-based save
return img.save(filename, fformat='tif', fill_value=fill_value,
dtype=dtype, compute=compute, **gdal_options)
except ImportError:
LOG.warning("Using legacy/slower geotiff save method, install "
"'rasterio' for faster saving.")
# force to numpy dtype object
dtype = np.dtype(dtype)
gformat = NP2GDAL[dtype.type]

gdal_options['nbits'] = int(gdal_options.get('nbits',
dtype.itemsize * 8))
datasets, mode = img._finalize(fill_value=fill_value, dtype=dtype)
LOG.debug("Saving to GeoTiff: %s", filename)
g_opts = ["{0}={1}".format(k.upper(), str(v))
for k, v in gdal_options.items()]

ensure_dir(filename)
delayed = self._create_file(filename, img, gformat, g_opts,
datasets, mode)
if compute:
return delayed.compute()
return delayed

0 comments on commit 2d1b12c

Please sign in to comment.