Skip to content

Commit

Permalink
Automatically detect grid registration from xarray data source
Browse files Browse the repository at this point in the history
For xarray grids read from disk, there is an 'encoding' dictionary, with a 'source' key that gives the path to the file. Running `grdinfo` on that file can give us the grid registration. This works on the earth_relief grids. For grids that are not read from disk, we still default to assuming gridline registration ("GMT_GRID_NODE_REG").
  • Loading branch information
weiji14 committed Jun 23, 2020
1 parent 90b22a4 commit 6d4ece4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
31 changes: 27 additions & 4 deletions pygmt/base_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use_alias,
kwargs_to_strings,
)
from .modules import grdinfo


class BasePlotting:
Expand Down Expand Up @@ -56,6 +57,22 @@ def _preprocess(self, **kwargs): # pylint: disable=no-self-use
"""
return kwargs

def autodetect_registration(self, grid):
"""
Function to automatically detect whether the NetCDF source of an
xarray.DataArray grid uses gridline or pixel registration. Defaults to
gridline registration if grdinfo cannot find a source file.
"""
registration = "GMT_GRID_NODE_REG" # default to gridline registration

try:
if "Pixel node registration used" in grdinfo(grid.encoding["source"]):
registration = "GMT_GRID_PIXEL_REG"
except KeyError:
pass

return registration

@fmt_docstring
@use_alias(
R="region",
Expand Down Expand Up @@ -282,7 +299,8 @@ def grdcontour(self, grid, **kwargs):
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
registration = self.autodetect_registration(grid)
file_context = lib.virtualfile_from_grid(grid, registration)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
with file_context as fname:
Expand Down Expand Up @@ -314,7 +332,8 @@ def grdimage(self, grid, **kwargs):
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
registration = self.autodetect_registration(grid)
file_context = lib.virtualfile_from_grid(grid, registration)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
with file_context as fname:
Expand Down Expand Up @@ -410,7 +429,8 @@ def grdview(self, grid, **kwargs):
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
registration = self.autodetect_registration(grid)
file_context = lib.virtualfile_from_grid(grid, registration)
else:
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(grid)}")

Expand All @@ -420,7 +440,10 @@ def grdview(self, grid, **kwargs):
drapegrid = kwargs["G"]
if data_kind(drapegrid) in ("file", "grid"):
if data_kind(drapegrid) == "grid":
drape_context = lib.virtualfile_from_grid(drapegrid)
registration = self.autodetect_registration(grid)
drape_context = lib.virtualfile_from_grid(
drapegrid, registration
)
drapefile = stack.enter_context(drape_context)
kwargs["G"] = drapefile
else:
Expand Down
10 changes: 8 additions & 2 deletions pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..exceptions import GMTInvalidInput


def dataarray_to_matrix(grid):
def dataarray_to_matrix(grid, registration="GMT_GRID_NODE_REG"):
"""
Transform an xarray.DataArray into a data 2D array and metadata.
Expand All @@ -27,6 +27,9 @@ def dataarray_to_matrix(grid):
grid : xarray.DataArray
The input grid as a DataArray instance. Information is retrieved from
the coordinate arrays, not from headers.
registration : str
Either one of "GMT_GRID_PIXEL_REG" for pixel registration, or
"GMT_GRID_NODE_REG" for gridline registration [Default].
Returns
-------
Expand Down Expand Up @@ -102,7 +105,10 @@ def dataarray_to_matrix(grid):
dim
)
)
region.extend([coord.min(), coord.max()])
if registration == "GMT_GRID_PIXEL_REG":
region.extend([coord.min() - coord_inc / 2, coord.max() + coord_inc / 2])
elif registration == "GMT_GRID_NODE_REG":
region.extend([coord.min(), coord.max()])
inc.append(coord_inc)

if any([i < 0 for i in inc]): # Sort grid when there are negative increments
Expand Down
2 changes: 1 addition & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ def virtualfile_from_grid(self, grid, registration="GMT_GRID_NODE_REG"):
# collected and the memory freed. Creating it in this context manager
# guarantees that the copy will be around until the virtual file is
# closed. The conversion is implicit in dataarray_to_matrix.
matrix, region, inc = dataarray_to_matrix(grid)
matrix, region, inc = dataarray_to_matrix(grid, registration)
family = "GMT_IS_GRID|GMT_VIA_MATRIX"
geometry = "GMT_IS_SURFACE"
gmt_grid = self.create_data(
Expand Down

0 comments on commit 6d4ece4

Please sign in to comment.