Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow multiindex levels in plots #3938

Merged
merged 18 commits into from
May 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ labels can also be used to easily create informative plots.
xarray's plotting capabilities are centered around
:py:class:`DataArray` objects.
To plot :py:class:`Dataset` objects
simply access the relevant DataArrays, ie ``dset['var1']``.
simply access the relevant DataArrays, i.e. ``dset['var1']``.
Dataset specific plotting routines are also available (see :ref:`plot-dataset`).
Here we focus mostly on arrays 2d or larger. If your data fits
nicely into a pandas DataFrame then you're better off using one of the more
Expand Down Expand Up @@ -209,6 +209,44 @@ entire figure (as for matplotlib's ``figsize`` argument).

.. _plotting.multiplelines:

=========================
Determine x-axis values
=========================

Per default dimension coordinates are used for the x-axis (here the time coordinates).
However, you can also use non-dimension coordinates, MultiIndex levels, and dimensions
without coordinates along the x-axis. To illustrate this, let's calculate a 'decimal day' (epoch)
from the time and assign it as a non-dimension coordinate:

.. ipython:: python

decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta('1d')
air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day))
air1d_multi

To use ``'decimal_day'`` as x coordinate it must be explicitly specified:

.. ipython:: python

air1d_multi.plot(x="decimal_day")

Creating a new MultiIndex named ``'date'`` from ``'time'`` and ``'decimal_day'``,
it is also possible to use a MultiIndex level as x-axis:

.. ipython:: python

air1d_multi = air1d_multi.set_index(date=("time", "decimal_day"))
air1d_multi.plot(x="decimal_day")

Finally, if a dataset does not have any coordinates it enumerates all data points:

.. ipython:: python

air1d_multi = air1d_multi.drop("date")
air1d_multi.plot()

The same applies to 2D plots below.

====================================================
Multiple lines showing variation along a dimension
====================================================
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ New Features
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Allow plotting of boolean arrays. (:pull:`3766`)
By `Marek Jacob <https://github.com/MeraX>`_
- Enable using MultiIndex levels as cordinates in 1D and 2D plots (:issue:`3927`).
By `Mathias Hauser <https://github.com/mathause>`_.
- A ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex`, analogous to
the ``days_in_month`` accessor for a :py:class:`pandas.DatetimeIndex`, which
returns the days in the month each datetime in the index. Now days in month
Expand Down
19 changes: 9 additions & 10 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .facetgrid import _easy_facetgrid
from .utils import (
_add_colorbar,
_assert_valid_xy,
_ensure_plottable,
_infer_interval_breaks,
_infer_xy_labels,
Expand All @@ -29,19 +30,17 @@


def _infer_line_data(darray, x, y, hue):
error_msg = "must be either None or one of ({:s})".format(
", ".join(repr(dd) for dd in darray.dims)
)

ndims = len(darray.dims)

if x is not None and x not in darray.dims and x not in darray.coords:
raise ValueError("x " + error_msg)
if x is not None and y is not None:
raise ValueError("Cannot specify both x and y kwargs for line plots.")

if y is not None and y not in darray.dims and y not in darray.coords:
raise ValueError("y " + error_msg)
if x is not None:
_assert_valid_xy(darray, x, "x")

if x is not None and y is not None:
raise ValueError("You cannot specify both x and y kwargs" "for line plots.")
if y is not None:
_assert_valid_xy(darray, y, "y")

if ndims == 1:
huename = None
Expand Down Expand Up @@ -252,7 +251,7 @@ def line(
Dimension or coordinate for which you want multiple lines plotted.
If plotting against a 2D coordinate, ``hue`` must be a dimension.
x, y : string, optional
Dimensions or coordinates for x, y axis.
Dimension, coordinate or MultiIndex level for x, y axis.
Only one of these may be specified.
The other coordinate plots values from the DataArray on which this
plot method is called.
Expand Down
39 changes: 32 additions & 7 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):

darray must be a 2 dimensional data array, or 3d for imshow only.
"""
assert x is None or x != y
if (x is not None) and (x == y):
raise ValueError("x and y cannot be equal.")

if imshow and darray.ndim == 3:
return _infer_xy_labels_3d(darray, x, y, rgb)

Expand All @@ -369,18 +371,41 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
raise ValueError("DataArray must be 2d")
y, x = darray.dims
elif x is None:
if y not in darray.dims and y not in darray.coords:
raise ValueError("y must be a dimension name if x is not supplied")
_assert_valid_xy(darray, y, "y")
x = darray.dims[0] if y == darray.dims[1] else darray.dims[1]
elif y is None:
if x not in darray.dims and x not in darray.coords:
raise ValueError("x must be a dimension name if y is not supplied")
_assert_valid_xy(darray, x, "x")
y = darray.dims[0] if x == darray.dims[1] else darray.dims[1]
elif any(k not in darray.coords and k not in darray.dims for k in (x, y)):
raise ValueError("x and y must be coordinate variables")
else:
_assert_valid_xy(darray, x, "x")
_assert_valid_xy(darray, y, "y")

if (
all(k in darray._level_coords for k in (x, y))
and darray._level_coords[x] == darray._level_coords[y]
):
raise ValueError("x and y cannot be levels of the same MultiIndex")

return x, y


def _assert_valid_xy(darray, xy, name):
"""
make sure x and y passed to plotting functions are valid
"""

# MultiIndex cannot be plotted; no point in allowing them here
multiindex = set([darray._level_coords[lc] for lc in darray._level_coords])

valid_xy = (
set(darray.dims) | set(darray.coords) | set(darray._level_coords)
) - multiindex

if xy not in valid_xy:
valid_xy_str = "', '".join(sorted(valid_xy))
raise ValueError(f"{name} must be one of None, '{valid_xy_str}'")


def get_axis(figsize, size, aspect, ax):
import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down
77 changes: 63 additions & 14 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_label_from_attrs(self):
def test1d(self):
self.darray[:, 0, 0].plot()

with raises_regex(ValueError, "None"):
with raises_regex(ValueError, "x must be one of None, 'dim_0'"):
self.darray[:, 0, 0].plot(x="dim_1")

with raises_regex(TypeError, "complex128"):
Expand All @@ -155,14 +155,31 @@ def test_1d_x_y_kw(self):
for aa, (x, y) in enumerate(xy):
da.plot(x=x, y=y, ax=ax.flat[aa])

with raises_regex(ValueError, "cannot"):
with raises_regex(ValueError, "Cannot specify both"):
da.plot(x="z", y="z")

with raises_regex(ValueError, "None"):
da.plot(x="f", y="z")
error_msg = "must be one of None, 'z'"
with raises_regex(ValueError, f"x {error_msg}"):
da.plot(x="f")

with raises_regex(ValueError, "None"):
da.plot(x="z", y="f")
with raises_regex(ValueError, f"y {error_msg}"):
da.plot(y="f")

def test_multiindex_level_as_coord(self):
da = xr.DataArray(
np.arange(5),
dims="x",
coords=dict(a=("x", np.arange(5)), b=("x", np.arange(5, 10))),
)
da = da.set_index(x=["a", "b"])

for x in ["a", "b"]:
h = da.plot(x=x)[0]
assert_array_equal(h.get_xdata(), da[x].values)

for y in ["a", "b"]:
h = da.plot(y=y)[0]
assert_array_equal(h.get_ydata(), da[y].values)

# Test for bug in GH issue #2725
def test_infer_line_data(self):
Expand Down Expand Up @@ -211,7 +228,7 @@ def test_2d_line(self):
self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1")
self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1")

with raises_regex(ValueError, "cannot"):
with raises_regex(ValueError, "Cannot"):
self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1")

def test_2d_line_accepts_legend_kw(self):
Expand Down Expand Up @@ -1032,6 +1049,16 @@ def test_nonnumeric_index_raises_typeerror(self):
with raises_regex(TypeError, r"[Pp]lot"):
self.plotfunc(a)

def test_multiindex_raises_typeerror(self):
a = DataArray(
easy_array((3, 2)),
dims=("x", "y"),
coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
)
a = a.set_index(y=("a", "b"))
with raises_regex(TypeError, r"[Pp]lot"):
self.plotfunc(a)

def test_can_pass_in_axis(self):
self.pass_in_axis(self.plotmethod)

Expand Down Expand Up @@ -1140,15 +1167,16 @@ def test_positional_coord_string(self):
assert "y_long_name [y_units]" == ax.get_ylabel()

def test_bad_x_string_exception(self):
with raises_regex(ValueError, "x and y must be coordinate variables"):

with raises_regex(ValueError, "x and y cannot be equal."):
self.plotmethod(x="y", y="y")

error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'"
with raises_regex(ValueError, f"x {error_msg}"):
self.plotmethod("not_a_real_dim", "y")
with raises_regex(
ValueError, "x must be a dimension name if y is not supplied"
):
with raises_regex(ValueError, f"x {error_msg}"):
self.plotmethod(x="not_a_real_dim")
with raises_regex(
ValueError, "y must be a dimension name if x is not supplied"
):
with raises_regex(ValueError, f"y {error_msg}"):
self.plotmethod(y="not_a_real_dim")
self.darray.coords["z"] = 100

Expand Down Expand Up @@ -1183,6 +1211,27 @@ def test_non_linked_coords_transpose(self):
# simply ensure that these high coords were passed over
assert np.min(ax.get_xlim()) > 100.0

def test_multiindex_level_as_coord(self):
da = DataArray(
easy_array((3, 2)),
dims=("x", "y"),
coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
)
da = da.set_index(y=["a", "b"])

for x, y in (("a", "x"), ("b", "x"), ("x", "a"), ("x", "b")):
self.plotfunc(da, x=x, y=y)

ax = plt.gca()
assert x == ax.get_xlabel()
assert y == ax.get_ylabel()

with raises_regex(ValueError, "levels of the same MultiIndex"):
self.plotfunc(da, x="a", y="b")

with raises_regex(ValueError, "y must be one of None, 'a', 'b', 'x'"):
self.plotfunc(da, x="a", y="y")

def test_default_title(self):
a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"])
a.coords["c"] = [0, 1]
Expand Down