Skip to content

Commit

Permalink
Fix pcolormesh with str coords (#7612)
Browse files Browse the repository at this point in the history
* pcolormesh with str coords

* add whats-new
  • Loading branch information
headtr1ck authored Mar 16, 2023
1 parent e7b4930 commit ccbb84d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Bug fixes
By `Jimmy Westling <https://github.com/illviljan>`_.
- Improved performance in ``open_dataset`` for datasets with large object arrays (:issue:`7484`, :pull:`7494`).
By `Alex Goodman <https://github.com/agoodm>`_ and `Deepak Cherian <https://github.com/dcherian>`_.
- Fix :py:meth:`DataArray.plot.pcolormesh` which now works if one of the coordinates has str dtype (:issue:`6775`, :pull:`7612`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
24 changes: 10 additions & 14 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2293,27 +2293,23 @@ def pcolormesh(
else:
infer_intervals = True

if (
infer_intervals
and not np.issubdtype(x.dtype, str)
and (
(np.shape(x)[0] == np.shape(z)[1])
or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1]))
)
if any(np.issubdtype(k.dtype, str) for k in (x, y)):
# do not infer intervals if any axis contains str ticks, see #6775
infer_intervals = False

if infer_intervals and (
(np.shape(x)[0] == np.shape(z)[1])
or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1]))
):
if len(x.shape) == 1:
if x.ndim == 1:
x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale)
else:
# we have to infer the intervals on both axes
x = _infer_interval_breaks(x, axis=1, scale=xscale)
x = _infer_interval_breaks(x, axis=0, scale=xscale)

if (
infer_intervals
and not np.issubdtype(y.dtype, str)
and (np.shape(y)[0] == np.shape(z)[0])
):
if len(y.shape) == 1:
if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]):
if y.ndim == 1:
y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale)
else:
# we have to infer the intervals on both axes
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,16 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None:
_, unique_counts = np.unique(v[:-1], axis=0, return_counts=True)
assert np.all(unique_counts == 1)

def test_str_coordinates_pcolormesh(self) -> None:
# test for #6775
x = DataArray(
[[1, 2, 3], [4, 5, 6]],
dims=("a", "b"),
coords={"a": [1, 2], "b": ["a", "b", "c"]},
)
x.plot.pcolormesh()
x.T.plot.pcolormesh()

def test_contourf_cmap_set(self) -> None:
a = DataArray(easy_array((4, 4)), dims=["z", "time"])

Expand Down

0 comments on commit ccbb84d

Please sign in to comment.