diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 9d2481eed3c..3f936506234 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -597,6 +597,7 @@ plot.imshow plot.pcolormesh plot.scatter + plot.surface plot.FacetGrid.map_dataarray plot.FacetGrid.set_titles diff --git a/doc/api.rst b/doc/api.rst index 85a0d75f56a..da78307a349 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -588,6 +588,7 @@ Plotting DataArray.plot.line DataArray.plot.pcolormesh DataArray.plot.step + DataArray.plot.surface .. _api.ufuncs: diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 098c63d0e40..f1c76b21488 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -411,6 +411,37 @@ produce plots with nonuniform coordinates. @savefig plotting_nonuniform_coords.png width=4in b.plot() +==================== + Other types of plot +==================== + +There are several other options for plotting 2D data. + +Contour plot using :py:meth:`DataArray.plot.contour()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contour.png width=4in + air2d.plot.contour() + +Filled contour plot using :py:meth:`DataArray.plot.contourf()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contourf.png width=4in + air2d.plot.contourf() + +Surface plot using :py:meth:`DataArray.plot.surface()` + +.. ipython:: python + :okwarning: + + @savefig plotting_surface.png width=4in + # transpose just to make the example look a bit nicer + air2d.T.plot.surface() + ==================== Calling Matplotlib ==================== diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2cffe076792..0081d18efb3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,8 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make + surface plots (:issue:`#2235` :issue:`#5084` :pull:`5101`). - Allow passing multiple arrays to :py:meth:`Dataset.__setitem__` (:pull:`5216`). By `Giacomo Caria `_. - Add 'cumulative' option to :py:meth:`Dataset.integrate` and diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 86a09506824..28ae0cf32e7 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,6 +1,6 @@ from .dataset_plot import scatter from .facetgrid import FacetGrid -from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step +from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface __all__ = [ "plot", @@ -13,4 +13,5 @@ "pcolormesh", "FacetGrid", "scatter", + "surface", ] diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 8a52129ecf8..ab6d524aee4 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -263,7 +263,9 @@ def map_dataarray(self, func, x, y, **kwargs): if k not in {"cmap", "colors", "cbar_kwargs", "levels"} } func_kwargs.update(cmap_params) - func_kwargs.update({"add_colorbar": False, "add_labels": False}) + func_kwargs["add_colorbar"] = False + if func.__name__ != "surface": + func_kwargs["add_labels"] = False # Get x, y labels for the first subplot x, y = _infer_xy_labels( diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index b22a7017934..e6eb7ecbe0b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -633,7 +633,11 @@ def newplotfunc( # Decide on a default for the colorbar before facetgrids if add_colorbar is None: - add_colorbar = plotfunc.__name__ != "contour" + add_colorbar = True + if plotfunc.__name__ == "contour" or ( + plotfunc.__name__ == "surface" and cmap is None + ): + add_colorbar = False imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 3 + (row is not None) + (col is not None) ) @@ -646,6 +650,25 @@ def newplotfunc( darray = _rescale_imshow_rgb(darray, vmin, vmax, robust) vmin, vmax, robust = None, None, False + if subplot_kws is None: + subplot_kws = dict() + + if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): + if ax is None: + # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 + + # delete so it does not end up in locals() + del Axes3D + + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + + # In facet grids, shared axis labels don't make sense for surface plots + sharex = False + sharey = False + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -658,6 +681,19 @@ def newplotfunc( plt = import_matplotlib_pyplot() + if ( + plotfunc.__name__ == "surface" + and not kwargs.get("_is_facetgrid", False) + and ax is not None + ): + import mpl_toolkits # type: ignore + + if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): + raise ValueError( + "If ax is passed to surface(), it must be created with " + 'projection="3d"' + ) + rgb = kwargs.pop("rgb", None) if rgb is not None and plotfunc.__name__ != "imshow": raise ValueError('The "rgb" keyword is only valid for imshow()') @@ -674,9 +710,10 @@ def newplotfunc( xval = darray[xlab] yval = darray[ylab] - if xval.ndim > 1 or yval.ndim > 1: + if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": # Passing 2d coordinate values, need to ensure they are transposed the same - # way as darray + # way as darray. + # Also surface plots always need 2d coordinates xval = xval.broadcast_like(darray) yval = yval.broadcast_like(darray) dims = darray.dims @@ -734,8 +771,6 @@ def newplotfunc( # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") - if subplot_kws is None: - subplot_kws = dict() ax = get_axis(figsize, size, aspect, ax, **subplot_kws) primitive = plotfunc( @@ -755,6 +790,8 @@ def newplotfunc( ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) + if plotfunc.__name__ == "surface": + ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: if add_labels and "label" not in cbar_kwargs: @@ -987,3 +1024,14 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): ax.set_ylim(y[0], y[-1]) return primitive + + +@_plot2d +def surface(x, y, z, ax, **kwargs): + """ + Surface plot of 2d DataArray + + Wraps :func:`matplotlib:mpl_toolkits.mplot3d.axes3d.plot_surface` + """ + primitive = ax.plot_surface(x, y, z, **kwargs) + return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a83bc28e273..325ea799f28 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -804,6 +804,14 @@ def _process_cmap_cbar_kwargs( cmap_params cbar_kwargs """ + if func.__name__ == "surface": + # Leave user to specify cmap settings for surface plots + kwargs["cmap"] = cmap + return { + k: kwargs.get(k, None) + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + }, {} + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) if "contour" in func.__name__ and levels is None: diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index f99d678d35e..e414ff0ed0e 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -59,6 +59,9 @@ def LooseVersion(vstring): has_matplotlib, requires_matplotlib = _importorskip("matplotlib") +has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip( + "matplotlib", minversion="3.3.0" +) has_scipy, requires_scipy = _importorskip("scipy") has_pydap, requires_pydap = _importorskip("pydap.client") has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index ce8e4bcb65d..e71bcaa359c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,6 +2,7 @@ import inspect from copy import copy from datetime import datetime +from typing import Any, Dict, Union import numpy as np import pandas as pd @@ -27,6 +28,7 @@ requires_cartopy, requires_cftime, requires_matplotlib, + requires_matplotlib_3_3_0, requires_nc_time_axis, requires_seaborn, ) @@ -35,6 +37,7 @@ try: import matplotlib as mpl import matplotlib.pyplot as plt + import mpl_toolkits # type: ignore except ImportError: pass @@ -131,8 +134,8 @@ def setup(self): # Remove all matplotlib figures plt.close("all") - def pass_in_axis(self, plotmethod): - fig, axes = plt.subplots(ncols=2) + def pass_in_axis(self, plotmethod, subplot_kw=None): + fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw) plotmethod(ax=axes[0]) assert axes[0].has_data() @@ -1106,6 +1109,9 @@ class Common2dMixin: Should have the same name as the method. """ + # Needs to be overridden in TestSurface for facet grid plots + subplot_kws: Union[Dict[Any, Any], None] = None + @pytest.fixture(autouse=True) def setUp(self): da = DataArray( @@ -1421,7 +1427,7 @@ def test_colorbar_kwargs(self): def test_verbose_facetgrid(self): a = easy_array((10, 15, 3)) d = DataArray(a, dims=["y", "x", "z"]) - g = xplt.FacetGrid(d, col="z") + g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws) g.map_dataarray(self.plotfunc, "x", "y") for ax in g.axes.flat: assert ax.has_data() @@ -1821,6 +1827,95 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 +class TestSurface(Common2dMixin, PlotTestCase): + + plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} + + def test_primitive_artist_returned(self): + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + + @pytest.mark.slow + def test_2d_coord_names(self): + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() + + def test_xyincrease_false_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_xyincrease_true_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") + + def test_can_pass_in_axis(self): + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + + def test_default_cmap(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_diverging_color_limits(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_colorbar_kwargs(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_cmap_and_color_both(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_seaborn_palette_as_cmap(self): + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() + + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self): + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + # Infering labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + @requires_matplotlib_3_3_0 + def test_viridis_cmap(self): + return super().test_viridis_cmap() + + @requires_matplotlib_3_3_0 + def test_can_change_default_cmap(self): + return super().test_can_change_default_cmap() + + @requires_matplotlib_3_3_0 + def test_colorbar_default_label(self): + return super().test_colorbar_default_label() + + @requires_matplotlib_3_3_0 + def test_facetgrid_map_only_appends_mappables(self): + return super().test_facetgrid_map_only_appends_mappables() + + class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self):