From ccbb84de60f1e03fd4b3374760668ec49e287910 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 16 Mar 2023 19:55:30 +0100 Subject: [PATCH] Fix `pcolormesh` with str coords (#7612) * pcolormesh with str coords * add whats-new --- doc/whats-new.rst | 2 ++ xarray/plot/dataarray_plot.py | 24 ++++++++++-------------- xarray/tests/test_plot.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1907a916dbc..cdbb3335372 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -59,6 +59,8 @@ Bug fixes By `Jimmy Westling `_. - Improved performance in ``open_dataset`` for datasets with large object arrays (:issue:`7484`, :pull:`7494`). By `Alex Goodman `_ and `Deepak Cherian `_. +- Fix :py:meth:`DataArray.plot.pcolormesh` which now works if one of the coordinates has str dtype (:issue:`6775`, :pull:`7612`). + By `Michael Niklas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 4c77539b5bb..a80db91562c 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -2293,27 +2293,23 @@ def pcolormesh( else: infer_intervals = True - if ( - infer_intervals - and not np.issubdtype(x.dtype, str) - and ( - (np.shape(x)[0] == np.shape(z)[1]) - or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) - ) + if any(np.issubdtype(k.dtype, str) for k in (x, y)): + # do not infer intervals if any axis contains str ticks, see #6775 + infer_intervals = False + + if infer_intervals and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) ): - if len(x.shape) == 1: + if x.ndim == 1: x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) else: # we have to infer the intervals on both axes x = _infer_interval_breaks(x, axis=1, scale=xscale) x = _infer_interval_breaks(x, axis=0, scale=xscale) - if ( - infer_intervals - and not np.issubdtype(y.dtype, str) - and (np.shape(y)[0] == np.shape(z)[0]) - ): - if len(y.shape) == 1: + if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]): + if y.ndim == 1: y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) else: # we have to infer the intervals on both axes diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index f736a851d7e..b7b5f005f0c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -427,6 +427,16 @@ def test2d_1d_2d_coordinates_pcolormesh(self) -> None: _, unique_counts = np.unique(v[:-1], axis=0, return_counts=True) assert np.all(unique_counts == 1) + def test_str_coordinates_pcolormesh(self) -> None: + # test for #6775 + x = DataArray( + [[1, 2, 3], [4, 5, 6]], + dims=("a", "b"), + coords={"a": [1, 2], "b": ["a", "b", "c"]}, + ) + x.plot.pcolormesh() + x.T.plot.pcolormesh() + def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"])