Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add var and std to weighted computations #5870

Merged
merged 6 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,18 @@ Weighted objects

core.weighted.DataArrayWeighted
core.weighted.DataArrayWeighted.mean
core.weighted.DataArrayWeighted.std
core.weighted.DataArrayWeighted.sum
core.weighted.DataArrayWeighted.sum_of_squares
core.weighted.DataArrayWeighted.sum_of_weights
core.weighted.DataArrayWeighted.var
core.weighted.DatasetWeighted
core.weighted.DatasetWeighted.mean
core.weighted.DatasetWeighted.std
core.weighted.DatasetWeighted.sum
core.weighted.DatasetWeighted.sum_of_squares
core.weighted.DatasetWeighted.sum_of_weights
core.weighted.DatasetWeighted.var


Coarsen objects
Expand Down
20 changes: 17 additions & 3 deletions doc/user-guide/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ Weighted array reductions

:py:class:`DataArray` and :py:class:`Dataset` objects include :py:meth:`DataArray.weighted`
and :py:meth:`Dataset.weighted` array reduction methods. They currently
support weighted ``sum`` and weighted ``mean``.
support weighted ``sum``, ``mean``, ``std`` and ``var``.

.. ipython:: python

Expand Down Expand Up @@ -298,13 +298,27 @@ The weighted sum corresponds to:
weighted_sum = (prec * weights).sum()
weighted_sum

and the weighted mean to:
the weighted mean to:

.. ipython:: python

weighted_mean = weighted_sum / weights.sum()
weighted_mean

the weighted variance to:

.. ipython:: python

weighted_var = weighted.sum_of_squares() / weights.sum()
weighted_var

and the weighted standard deviation to:

.. ipython:: python

weighted_std = np.sqrt(weighted_var)
weighted_std

However, the functions also take missing values in the data into account:

.. ipython:: python
Expand All @@ -327,7 +341,7 @@ If the weights add up to to 0, ``sum`` returns 0:

data.weighted(weights).sum()

and ``mean`` returns ``NaN``:
and ``mean``, ``std`` and ``var`` return ``NaN``:

.. ipython:: python

Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ v0.19.1 (unreleased)

New Features
~~~~~~~~~~~~
- Add :py:meth:`var`, :py:meth:`std` and :py:meth:`sum_of_squares` to :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted`.
By `Christian Jauvin <https://github.com/cjauvin>`_.
- Added a :py:func:`get_options` method to xarray's root namespace (:issue:`5698`, :pull:`5716`)
By `Pushkar Kopparla <https://github.com/pkopparla>`_.
- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).
Expand Down
94 changes: 92 additions & 2 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union

import numpy as np

from . import duck_array_ops
from .computation import dot
from .pycompat import is_duck_dask_array
Expand Down Expand Up @@ -35,7 +37,7 @@
"""

_SUM_OF_WEIGHTS_DOCSTRING = """
Calculate the sum of weights, accounting for missing values in the data
Calculate the sum of weights, accounting for missing values in the data.

Parameters
----------
Expand Down Expand Up @@ -177,13 +179,26 @@ def _sum_of_weights(

return sum_of_weights.where(valid_weights)

def _sum_of_squares(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""

demeaned = da - da.weighted(self.weights).mean(dim=dim)

return self._reduce((demeaned ** 2), self.weights, dim=dim, skipna=skipna)


def _weighted_sum(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a by a weighted ``sum`` along some dimension(s)."""
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

return self._reduce(da, self.weights, dim=dim, skipna=skipna)

Expand All @@ -201,6 +216,32 @@ def _weighted_mean(

return weighted_sum / sum_of_weights

def _weighted_var(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""

sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)

sum_of_weights = self._sum_of_weights(da, dim=dim)

return sum_of_squares / sum_of_weights


def _weighted_std(
self,
da: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
) -> "DataArray":
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""

return np.sqrt(self._weighted_var(da, dim, skipna))


def _implementation(self, func, dim, **kwargs):

raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
Expand All @@ -215,6 +256,17 @@ def sum_of_weights(
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
)

def sum_of_squares(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def sum(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
Expand All @@ -237,6 +289,28 @@ def mean(
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def var(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def std(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> T_Xarray:

return self._implementation(
self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)

def __repr__(self):
"""provide a nice str repr of our Weighted object"""

Expand Down Expand Up @@ -275,6 +349,22 @@ def _inject_docstring(cls, cls_name):
cls=cls_name, fcn="mean", on_zero="NaN"
)

cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="mean", on_zero="NaN"
)

cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="sum_of_squares", on_zero="0"
)

cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="var", on_zero="NaN"
)

cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls=cls_name, fcn="std", on_zero="NaN"
)


_inject_docstring(DataArrayWeighted, "DataArray")
_inject_docstring(DatasetWeighted, "Dataset")
Loading