Skip to content

Commit

Permalink
Use broadcast_like for 2d plot coordinates (#5099)
Browse files Browse the repository at this point in the history
* Use broadcast_like for 2d plot coordinates

Use broadcast_like if either `x` or `y` inputs are 2d to ensure that
both have dimensions in the same order as the DataArray being plotted.
Convert to numpy arrays after possibly using broadcast_like. Simplifies
code, and fixes #5097 (bug when dimensions have the same size).

* Update whats-new

* Test for issue 5097

* Fix typo in doc/whats-new.rst

Co-authored-by: Mathias Hauser <[email protected]>

* Update doc/whats-new.rst

Co-authored-by: Mathias Hauser <[email protected]>
  • Loading branch information
johnomotani and mathause authored Apr 22, 2021
1 parent d58a511 commit b2351cb
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is
2d (:issue:`5097`, :pull:`5099`). By `John Omotani <https://github.com/johnomotani>`_.
- Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years), can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls <https://github.com/znicholls>`_.
- Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`).
By `Victor Negîrneac <https://github.com/caenrigen>`_.
Expand Down
35 changes: 14 additions & 21 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,28 +671,21 @@ def newplotfunc(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
)

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
dims = darray[ylab].dims
if xval.shape[0] == yval.shape[0]:
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
else:
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)

elif yval.ndim < xval.ndim:
dims = darray[xlab].dims
if yval.shape[0] == xval.shape[0]:
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
else:
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)
elif xval.ndim == 2:
dims = darray[xlab].dims
xval = darray[xlab]
yval = darray[ylab]

if xval.ndim > 1 or yval.ndim > 1:
# Passing 2d coordinate values, need to ensure they are transposed the same
# way as darray
xval = xval.broadcast_like(darray)
yval = yval.broadcast_like(darray)
dims = darray.dims
else:
dims = (darray[ylab].dims[0], darray[xlab].dims[0])
dims = (yval.dims[0], xval.dims[0])

# better to pass the ndarrays directly to plotting functions
xval = xval.values
yval = yval.values

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
Expand Down
28 changes: 28 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,34 @@ def test2d_1d_2d_coordinates_contourf(self):
a.plot.contourf(x="time", y="depth")
a.plot.contourf(x="depth", y="time")

def test2d_1d_2d_coordinates_pcolormesh(self):
# Test with equal coordinates to catch bug from #5097
sz = 10
y2d, x2d = np.meshgrid(np.arange(sz), np.arange(sz))
a = DataArray(
easy_array((sz, sz)),
dims=["x", "y"],
coords={"x2d": (["x", "y"], x2d), "y2d": (["x", "y"], y2d)},
)

for x, y in [
("x", "y"),
("y", "x"),
("x2d", "y"),
("y", "x2d"),
("x", "y2d"),
("y2d", "x"),
("x2d", "y2d"),
("y2d", "x2d"),
]:
p = a.plot.pcolormesh(x=x, y=y)
v = p.get_paths()[0].vertices

# Check all vertices are different, except last vertex which should be the
# same as the first
_, unique_counts = np.unique(v[:-1], axis=0, return_counts=True)
assert np.all(unique_counts == 1)

def test_contourf_cmap_set(self):
a = DataArray(easy_array((4, 4)), dims=["z", "time"])

Expand Down

0 comments on commit b2351cb

Please sign in to comment.