Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 19, 2021
1 parent 04f64c5 commit abcc008
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 45 deletions.
54 changes: 28 additions & 26 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,32 +441,6 @@ def plotmethod(
return newplotfunc


@_dsplot
def quiver(ds, x, y, ax, u, v, **kwargs):
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")

x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v])

args = [x.values, y.values, u.values, v.values]
hue = kwargs.pop("hue")
if hue:
args.append(ds[hue].values)

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
cmap_params = kwargs.pop("cmap_params")
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

kwargs.pop("hue_style")
hdl = ax.quiver(*args, **kwargs, **cmap_params)
return hdl


@_dsplot
def scatter(ds, x, y, ax, u, v, **kwargs):
"""
Expand Down Expand Up @@ -520,3 +494,31 @@ def scatter(ds, x, y, ax, u, v, **kwargs):
)

return primitive


@_dsplot
def quiver(ds, x, y, ax, u, v, **kwargs):
import matplotlib as mpl

if x is None or y is None or u is None or v is None:
raise ValueError("Must specify x, y, u, v for quiver plots.")

x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v])

args = [x.values, y.values, u.values, v.values]
hue = kwargs.pop("hue")
cmap_params = kwargs.pop("cmap_params")

if hue:
args.append(ds[hue].values)

# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
if not cmap_params["norm"]:
cmap_params["norm"] = mpl.colors.Normalize(
cmap_params.pop("vmin"), cmap_params.pop("vmax")
)

kwargs.pop("hue_style")
kwargs.setdefault("pivot", "middle")
hdl = ax.quiver(*args, **kwargs, **cmap_params)
return hdl
39 changes: 21 additions & 18 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,22 @@ def map_dataset(
kwargs["_is_facetgrid"] = True

if func.__name__ == "quiver" and "scale" not in kwargs:
if "scale_units" in kwargs and kwargs["scale_units"] is not None:
raise NotImplementedError("Can't pass only scale_units.")
# autoscaling
ax = self.axes[0, 0]
magnitude = _get_nice_quiver_magnitude(
self.data[kwargs["u"]], self.data[kwargs["v"]]
)
# matplotlib autoscaling algorithm
span = ax.get_transform().inverted().transform_bbox(ax.bbox).width
npts = self.data.sizes[x] * self.data.sizes[y]
# scale is typical arrow length as a multiple of the arrow width
scale = 1.8 * magnitude * max(10, np.sqrt(npts)) / span
kwargs["scale"] = 1 / scale # TODO: why?
raise ValueError("Please provide scale.")
# TODO: come up with an algorithm for reasonable scale choice
# if "scale_units" in kwargs and kwargs["scale_units"] is not None:
# raise NotImplementedError("Can't pass only scale_units.")
# # autoscaling
# ax = self.axes[0, 0]
# magnitude = _get_nice_quiver_magnitude(
# self.data[kwargs["u"]], self.data[kwargs["v"]]
# )
# # matplotlib autoscaling algorithm
# span = ax.get_transform().inverted().transform_bbox(ax.bbox).width
# npts = self.data.sizes[x] * self.data.sizes[y]
# # scale is typical arrow length as a multiple of the arrow width
# print(magnitude, np.sqrt(npts), span)
# kwargs["scale"] = 1.8 * magnitude * min(10, np.sqrt(npts)) / span
# print(kwargs["scale"])

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
Expand Down Expand Up @@ -457,17 +460,17 @@ def add_quiverkey(self, u, v, **kwargs):
units = self.data[u].attrs.get("units", "")
self.quiverkey = self.axes.flat[-1].quiverkey(
self._mappables[-1],
X=0.85,
Y=1.03,
X=0.8,
Y=0.9,
U=magnitude,
label=f"{magnitude}\n{units}",
labelpos="E",
coordinates="axes",
coordinates="figure",
)

# TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0
# self._adjust_fig_for_guide(self.quiverkey)

# https://github.com/matplotlib/matplotlib/issues/18530
# self._adjust_fig_for_guide(self.quiverkey.text)
return self

def set_axis_labels(self, x_var=None, y_var=None):
Expand Down
2 changes: 1 addition & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,6 @@ def _get_nice_quiver_magnitude(u, v):
import matplotlib as mpl

ticker = mpl.ticker.MaxNLocator(3)
median = np.median(np.hypot(u.values, v.values))
median = np.mean(np.hypot(u.values, v.values))
magnitude = ticker.tick_values(0, median)[-2]
return magnitude

0 comments on commit abcc008

Please sign in to comment.