diff --git a/doc/plotting.rst b/doc/plotting.rst index 40c0ca1a496..f98f47f2567 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -13,7 +13,7 @@ labels can also be used to easily create informative plots. xarray's plotting capabilities are centered around :py:class:`DataArray` objects. To plot :py:class:`Dataset` objects -simply access the relevant DataArrays, ie ``dset['var1']``. +simply access the relevant DataArrays, i.e. ``dset['var1']``. Dataset specific plotting routines are also available (see :ref:`plot-dataset`). Here we focus mostly on arrays 2d or larger. If your data fits nicely into a pandas DataFrame then you're better off using one of the more @@ -209,6 +209,44 @@ entire figure (as for matplotlib's ``figsize`` argument). .. _plotting.multiplelines: +========================= + Determine x-axis values +========================= + +Per default dimension coordinates are used for the x-axis (here the time coordinates). +However, you can also use non-dimension coordinates, MultiIndex levels, and dimensions +without coordinates along the x-axis. To illustrate this, let's calculate a 'decimal day' (epoch) +from the time and assign it as a non-dimension coordinate: + +.. ipython:: python + + decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta('1d') + air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day)) + air1d_multi + +To use ``'decimal_day'`` as x coordinate it must be explicitly specified: + +.. ipython:: python + + air1d_multi.plot(x="decimal_day") + +Creating a new MultiIndex named ``'date'`` from ``'time'`` and ``'decimal_day'``, +it is also possible to use a MultiIndex level as x-axis: + +.. ipython:: python + + air1d_multi = air1d_multi.set_index(date=("time", "decimal_day")) + air1d_multi.plot(x="decimal_day") + +Finally, if a dataset does not have any coordinates it enumerates all data points: + +.. ipython:: python + + air1d_multi = air1d_multi.drop("date") + air1d_multi.plot() + +The same applies to 2D plots below. + ==================================================== Multiple lines showing variation along a dimension ==================================================== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a4602c1edad..0be988da690 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -63,6 +63,8 @@ New Features By `Stephan Hoyer `_. - Allow plotting of boolean arrays. (:pull:`3766`) By `Marek Jacob `_ +- Enable using MultiIndex levels as cordinates in 1D and 2D plots (:issue:`3927`). + By `Mathias Hauser `_. - A ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex`, analogous to the ``days_in_month`` accessor for a :py:class:`pandas.DatetimeIndex`, which returns the days in the month each datetime in the index. Now days in month diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 4d6033bf00d..19a3f1e63e3 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -14,6 +14,7 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, @@ -29,19 +30,17 @@ def _infer_line_data(darray, x, y, hue): - error_msg = "must be either None or one of ({:s})".format( - ", ".join(repr(dd) for dd in darray.dims) - ) + ndims = len(darray.dims) - if x is not None and x not in darray.dims and x not in darray.coords: - raise ValueError("x " + error_msg) + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for line plots.") - if y is not None and y not in darray.dims and y not in darray.coords: - raise ValueError("y " + error_msg) + if x is not None: + _assert_valid_xy(darray, x, "x") - if x is not None and y is not None: - raise ValueError("You cannot specify both x and y kwargs" "for line plots.") + if y is not None: + _assert_valid_xy(darray, y, "y") if ndims == 1: huename = None @@ -252,7 +251,7 @@ def line( Dimension or coordinate for which you want multiple lines plotted. If plotting against a 2D coordinate, ``hue`` must be a dimension. x, y : string, optional - Dimensions or coordinates for x, y axis. + Dimension, coordinate or MultiIndex level for x, y axis. Only one of these may be specified. The other coordinate plots values from the DataArray on which this plot method is called. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index cb993c192d9..e5c1fa89333 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -360,7 +360,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): darray must be a 2 dimensional data array, or 3d for imshow only. """ - assert x is None or x != y + if (x is not None) and (x == y): + raise ValueError("x and y cannot be equal.") + if imshow and darray.ndim == 3: return _infer_xy_labels_3d(darray, x, y, rgb) @@ -369,18 +371,41 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): raise ValueError("DataArray must be 2d") y, x = darray.dims elif x is None: - if y not in darray.dims and y not in darray.coords: - raise ValueError("y must be a dimension name if x is not supplied") + _assert_valid_xy(darray, y, "y") x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] elif y is None: - if x not in darray.dims and x not in darray.coords: - raise ValueError("x must be a dimension name if y is not supplied") + _assert_valid_xy(darray, x, "x") y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] - elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): - raise ValueError("x and y must be coordinate variables") + else: + _assert_valid_xy(darray, x, "x") + _assert_valid_xy(darray, y, "y") + + if ( + all(k in darray._level_coords for k in (x, y)) + and darray._level_coords[x] == darray._level_coords[y] + ): + raise ValueError("x and y cannot be levels of the same MultiIndex") + return x, y +def _assert_valid_xy(darray, xy, name): + """ + make sure x and y passed to plotting functions are valid + """ + + # MultiIndex cannot be plotted; no point in allowing them here + multiindex = set([darray._level_coords[lc] for lc in darray._level_coords]) + + valid_xy = ( + set(darray.dims) | set(darray.coords) | set(darray._level_coords) + ) - multiindex + + if xy not in valid_xy: + valid_xy_str = "', '".join(sorted(valid_xy)) + raise ValueError(f"{name} must be one of None, '{valid_xy_str}'") + + def get_axis(figsize, size, aspect, ax): import matplotlib as mpl import matplotlib.pyplot as plt diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index af7c686bf60..6497987e813 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -136,7 +136,7 @@ def test_label_from_attrs(self): def test1d(self): self.darray[:, 0, 0].plot() - with raises_regex(ValueError, "None"): + with raises_regex(ValueError, "x must be one of None, 'dim_0'"): self.darray[:, 0, 0].plot(x="dim_1") with raises_regex(TypeError, "complex128"): @@ -155,14 +155,31 @@ def test_1d_x_y_kw(self): for aa, (x, y) in enumerate(xy): da.plot(x=x, y=y, ax=ax.flat[aa]) - with raises_regex(ValueError, "cannot"): + with raises_regex(ValueError, "Cannot specify both"): da.plot(x="z", y="z") - with raises_regex(ValueError, "None"): - da.plot(x="f", y="z") + error_msg = "must be one of None, 'z'" + with raises_regex(ValueError, f"x {error_msg}"): + da.plot(x="f") - with raises_regex(ValueError, "None"): - da.plot(x="z", y="f") + with raises_regex(ValueError, f"y {error_msg}"): + da.plot(y="f") + + def test_multiindex_level_as_coord(self): + da = xr.DataArray( + np.arange(5), + dims="x", + coords=dict(a=("x", np.arange(5)), b=("x", np.arange(5, 10))), + ) + da = da.set_index(x=["a", "b"]) + + for x in ["a", "b"]: + h = da.plot(x=x)[0] + assert_array_equal(h.get_xdata(), da[x].values) + + for y in ["a", "b"]: + h = da.plot(y=y)[0] + assert_array_equal(h.get_ydata(), da[y].values) # Test for bug in GH issue #2725 def test_infer_line_data(self): @@ -211,7 +228,7 @@ def test_2d_line(self): self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1") self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1") - with raises_regex(ValueError, "cannot"): + with raises_regex(ValueError, "Cannot"): self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1") def test_2d_line_accepts_legend_kw(self): @@ -1032,6 +1049,16 @@ def test_nonnumeric_index_raises_typeerror(self): with raises_regex(TypeError, r"[Pp]lot"): self.plotfunc(a) + def test_multiindex_raises_typeerror(self): + a = DataArray( + easy_array((3, 2)), + dims=("x", "y"), + coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])), + ) + a = a.set_index(y=("a", "b")) + with raises_regex(TypeError, r"[Pp]lot"): + self.plotfunc(a) + def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod) @@ -1140,15 +1167,16 @@ def test_positional_coord_string(self): assert "y_long_name [y_units]" == ax.get_ylabel() def test_bad_x_string_exception(self): - with raises_regex(ValueError, "x and y must be coordinate variables"): + + with raises_regex(ValueError, "x and y cannot be equal."): + self.plotmethod(x="y", y="y") + + error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'" + with raises_regex(ValueError, f"x {error_msg}"): self.plotmethod("not_a_real_dim", "y") - with raises_regex( - ValueError, "x must be a dimension name if y is not supplied" - ): + with raises_regex(ValueError, f"x {error_msg}"): self.plotmethod(x="not_a_real_dim") - with raises_regex( - ValueError, "y must be a dimension name if x is not supplied" - ): + with raises_regex(ValueError, f"y {error_msg}"): self.plotmethod(y="not_a_real_dim") self.darray.coords["z"] = 100 @@ -1183,6 +1211,27 @@ def test_non_linked_coords_transpose(self): # simply ensure that these high coords were passed over assert np.min(ax.get_xlim()) > 100.0 + def test_multiindex_level_as_coord(self): + da = DataArray( + easy_array((3, 2)), + dims=("x", "y"), + coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])), + ) + da = da.set_index(y=["a", "b"]) + + for x, y in (("a", "x"), ("b", "x"), ("x", "a"), ("x", "b")): + self.plotfunc(da, x=x, y=y) + + ax = plt.gca() + assert x == ax.get_xlabel() + assert y == ax.get_ylabel() + + with raises_regex(ValueError, "levels of the same MultiIndex"): + self.plotfunc(da, x="a", y="b") + + with raises_regex(ValueError, "y must be one of None, 'a', 'b', 'x'"): + self.plotfunc(da, x="a", y="y") + def test_default_title(self): a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"]) a.coords["c"] = [0, 1]