diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml new file mode 100644 index 00000000000..de506546ac9 --- /dev/null +++ b/.github/workflows/benchmarks.yml @@ -0,0 +1,74 @@ +name: Benchmark + +on: + pull_request: + types: [opened, reopened, synchronize, labeled] + workflow_dispatch: + +jobs: + benchmark: + if: ${{ contains( github.event.pull_request.labels.*.name, 'run-benchmark') && github.event_name == 'pull_request' || github.event_name == 'workflow_dispatch' }} + name: Linux + runs-on: ubuntu-20.04 + env: + ASV_DIR: "./asv_bench" + + steps: + # We need the full repo to avoid this issue + # https://github.com/actions/checkout/issues/23 + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + # installer-url: https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh + installer-url: https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh + + - name: Setup some dependencies + shell: bash -l {0} + run: | + pip install asv + sudo apt-get update -y + + - name: Run benchmarks + shell: bash -l {0} + id: benchmark + env: + OPENBLAS_NUM_THREADS: 1 + MKL_NUM_THREADS: 1 + OMP_NUM_THREADS: 1 + ASV_FACTOR: 1.5 + ASV_SKIP_SLOW: 1 + run: | + set -x + # ID this runner + asv machine --yes + echo "Baseline: ${{ github.event.pull_request.base.sha }} (${{ github.event.pull_request.base.label }})" + echo "Contender: ${GITHUB_SHA} (${{ github.event.pull_request.head.label }})" + # Use mamba for env creation + # export CONDA_EXE=$(which mamba) + export CONDA_EXE=$(which conda) + # Run benchmarks for current commit against base + ASV_OPTIONS="--split --show-stderr --factor $ASV_FACTOR" + asv continuous $ASV_OPTIONS ${{ github.event.pull_request.base.sha }} ${GITHUB_SHA} \ + | sed "/Traceback \|failed$\|PERFORMANCE DECREASED/ s/^/::error::/" \ + | tee benchmarks.log + # Report and export results for subsequent steps + if grep "Traceback \|failed\|PERFORMANCE DECREASED" benchmarks.log > /dev/null ; then + exit 1 + fi + working-directory: ${{ env.ASV_DIR }} + + - name: Add instructions to artifact + if: always() + run: | + cp benchmarks/README_CI.md benchmarks.log .asv/results/ + working-directory: ${{ env.ASV_DIR }} + + - uses: actions/upload-artifact@v2 + if: always() + with: + name: asv-benchmark-results-${{ runner.os }} + path: ${{ env.ASV_DIR }}/.asv/results diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b1b4127a1b..27f93c8e578 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,6 +16,7 @@ repos: rev: 21.9b0 hooks: - id: black + - id: black-jupyter - repo: https://github.com/keewis/blackdoc rev: v0.3.4 hooks: @@ -30,20 +31,21 @@ repos: # - id: velin # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v0.910-1 hooks: - id: mypy - # Copied from setup.cfg - exclude: "properties|asv_bench" + # `properies` & `asv_bench` are copied from setup.cfg. + # `_typed_ops.py` is added since otherwise mypy will complain (but notably only in pre-commit) + exclude: "properties|asv_bench|_typed_ops.py" additional_dependencies: [ # Type stubs types-python-dateutil, types-pkg_resources, types-PyYAML, types-pytz, + typing-extensions==3.10.0.0, # Dependencies that are typed numpy, - typing-extensions==3.10.0.0, ] # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 # - repo: https://github.com/asottile/pyupgrade diff --git a/asv_bench/benchmarks/README_CI.md b/asv_bench/benchmarks/README_CI.md new file mode 100644 index 00000000000..9d86cc257ef --- /dev/null +++ b/asv_bench/benchmarks/README_CI.md @@ -0,0 +1,122 @@ +# Benchmark CI + + + + + +## How it works + +The `asv` suite can be run for any PR on GitHub Actions (check workflow `.github/workflows/benchmarks.yml`) by adding a `run-benchmark` label to said PR. This will trigger a job that will run the benchmarking suite for the current PR head (merged commit) against the PR base (usually `main`). + +We use `asv continuous` to run the job, which runs a relative performance measurement. This means that there's no state to be saved and that regressions are only caught in terms of performance ratio (absolute numbers are available but they are not useful since we do not use stable hardware over time). `asv continuous` will: + +* Compile `scikit-image` for _both_ commits. We use `ccache` to speed up the process, and `mamba` is used to create the build environments. +* Run the benchmark suite for both commits, _twice_ (since `processes=2` by default). +* Generate a report table with performance ratios: + * `ratio=1.0` -> performance didn't change. + * `ratio<1.0` -> PR made it slower. + * `ratio>1.0` -> PR made it faster. + +Due to the sensitivity of the test, we cannot guarantee that false positives are not produced. In practice, values between `(0.7, 1.5)` are to be considered part of the measurement noise. When in doubt, running the benchmark suite one more time will provide more information about the test being a false positive or not. + +## Running the benchmarks on GitHub Actions + +1. On a PR, add the label `run-benchmark`. +2. The CI job will be started. Checks will appear in the usual dashboard panel above the comment box. +3. If more commits are added, the label checks will be grouped with the last commit checks _before_ you added the label. +4. Alternatively, you can always go to the `Actions` tab in the repo and [filter for `workflow:Benchmark`](https://github.com/scikit-image/scikit-image/actions?query=workflow%3ABenchmark). Your username will be assigned to the `actor` field, so you can also filter the results with that if you need it. + +## The artifacts + +The CI job will also generate an artifact. This is the `.asv/results` directory compressed in a zip file. Its contents include: + +* `fv-xxxxx-xx/`. A directory for the machine that ran the suite. It contains three files: + * `.json`, `.json`: the benchmark results for each commit, with stats. + * `machine.json`: details about the hardware. +* `benchmarks.json`: metadata about the current benchmark suite. +* `benchmarks.log`: the CI logs for this run. +* This README. + +## Re-running the analysis + +Although the CI logs should be enough to get an idea of what happened (check the table at the end), one can use `asv` to run the analysis routines again. + +1. Uncompress the artifact contents in the repo, under `.asv/results`. This is, you should see `.asv/results/benchmarks.log`, not `.asv/results/something_else/benchmarks.log`. Write down the machine directory name for later. +2. Run `asv show` to see your available results. You will see something like this: + +``` +$> asv show + +Commits with results: + +Machine : Jaimes-MBP +Environment: conda-py3.9-cython-numpy1.20-scipy + + 00875e67 + +Machine : fv-az95-499 +Environment: conda-py3.7-cython-numpy1.17-pooch-scipy + + 8db28f02 + 3a305096 +``` + +3. We are interested in the commits for `fv-az95-499` (the CI machine for this run). We can compare them with `asv compare` and some extra options. `--sort ratio` will show largest ratios first, instead of alphabetical order. `--split` will produce three tables: improved, worsened, no changes. `--factor 1.5` tells `asv` to only complain if deviations are above a 1.5 ratio. `-m` is used to indicate the machine ID (use the one you wrote down in step 1). Finally, specify your commit hashes: baseline first, then contender! + +``` +$> asv compare --sort ratio --split --factor 1.5 -m fv-az95-499 8db28f02 3a305096 + +Benchmarks that have stayed the same: + + before after ratio + [8db28f02] [3a305096] + + n/a n/a n/a benchmark_restoration.RollingBall.time_rollingball_ndim + 1.23±0.04ms 1.37±0.1ms 1.12 benchmark_transform_warp.WarpSuite.time_to_float64(, 128, 3) + 5.07±0.1μs 5.59±0.4μs 1.10 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (192, 192, 192), (192, 192, 192)) + 1.23±0.02ms 1.33±0.1ms 1.08 benchmark_transform_warp.WarpSuite.time_same_type(, 128, 3) + 9.45±0.2ms 10.1±0.5ms 1.07 benchmark_rank.Rank3DSuite.time_3d_filters('majority', (32, 32, 32)) + 23.0±0.9ms 24.6±1ms 1.07 benchmark_interpolation.InterpolationResize.time_resize((80, 80, 80), 0, 'symmetric', , True) + 38.7±1ms 41.1±1ms 1.06 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (2048, 2048), (192, 192, 192)) + 4.97±0.2μs 5.24±0.2μs 1.05 benchmark_transform_warp.ResizeLocalMeanSuite.time_resize_local_mean(, (2048, 2048), (2048, 2048)) + 4.21±0.2ms 4.42±0.3ms 1.05 benchmark_rank.Rank3DSuite.time_3d_filters('gradient', (32, 32, 32)) + +... +``` + +If you want more details on a specific test, you can use `asv show`. Use `-b pattern` to filter which tests to show, and then specify a commit hash to inspect: + +``` +$> asv show -b time_to_float64 8db28f02 + +Commit: 8db28f02 + +benchmark_transform_warp.WarpSuite.time_to_float64 [fv-az95-499/conda-py3.7-cython-numpy1.17-pooch-scipy] + ok + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + -- N / order + --------------- -------------------------------------------------------------------------------------------------------------- + dtype_in 128 / 0 128 / 1 128 / 3 1024 / 0 1024 / 1 1024 / 3 4096 / 0 4096 / 1 4096 / 3 + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + numpy.uint8 2.56±0.09ms 523±30μs 1.28±0.05ms 130±3ms 28.7±2ms 81.9±3ms 2.42±0.01s 659±5ms 1.48±0.01s + numpy.uint16 2.48±0.03ms 530±10μs 1.28±0.02ms 130±1ms 30.4±0.7ms 81.1±2ms 2.44±0s 653±3ms 1.47±0.02s + numpy.float32 2.59±0.1ms 518±20μs 1.27±0.01ms 127±3ms 26.6±1ms 74.8±2ms 2.50±0.01s 546±10ms 1.33±0.02s + numpy.float64 2.48±0.04ms 513±50μs 1.23±0.04ms 134±3ms 30.7±2ms 85.4±2ms 2.55±0.01s 632±4ms 1.45±0.01s + =============== ============= ========== ============= ========== ============ ========== ============ ========== ============ + started: 2021-07-06 06:14:36, duration: 1.99m +``` + +## Other details + +### Skipping slow or demanding tests + +To minimize the time required to run the full suite, we trimmed the parameter matrix in some cases and, in others, directly skipped tests that ran for too long or require too much memory. Unlike `pytest`, `asv` does not have a notion of marks. However, you can `raise NotImplementedError` in the setup step to skip a test. In that vein, a new private function is defined at `benchmarks.__init__`: `_skip_slow`. This will check if the `ASV_SKIP_SLOW` environment variable has been defined. If set to `1`, it will raise `NotImplementedError` and skip the test. To implement this behavior in other tests, you can add the following attribute: + +```python +from . import _skip_slow # this function is defined in benchmarks.__init__ + +def time_something_slow(): + pass + +time_something.setup = _skip_slow +``` diff --git a/asv_bench/benchmarks/__init__.py b/asv_bench/benchmarks/__init__.py index b0adb2feafd..02c3896e236 100644 --- a/asv_bench/benchmarks/__init__.py +++ b/asv_bench/benchmarks/__init__.py @@ -1,4 +1,5 @@ import itertools +import os import numpy as np @@ -46,3 +47,21 @@ def randint(low, high=None, size=None, frac_minus=None, seed=0): x.flat[inds] = -1 return x + + +def _skip_slow(): + """ + Use this function to skip slow or highly demanding tests. + + Use it as a `Class.setup` method or a `function.setup` attribute. + + Examples + -------- + >>> from . import _skip_slow + >>> def time_something_slow(): + ... pass + ... + >>> time_something.setup = _skip_slow + """ + if os.environ.get("ASV_SKIP_SLOW", "0") == "1": + raise NotImplementedError("Skipping this test...") diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py index 308ca2afda4..a4f8db2786b 100644 --- a/asv_bench/benchmarks/combine.py +++ b/asv_bench/benchmarks/combine.py @@ -9,7 +9,7 @@ class Combine: def setup(self): """Create 4 datasets with two different variables""" - t_size, x_size, y_size = 100, 900, 800 + t_size, x_size, y_size = 50, 450, 400 t = np.arange(t_size) data = np.random.randn(t_size, x_size, y_size) diff --git a/asv_bench/benchmarks/dataarray_missing.py b/asv_bench/benchmarks/dataarray_missing.py index d79d2558b35..f89fe7f8eb9 100644 --- a/asv_bench/benchmarks/dataarray_missing.py +++ b/asv_bench/benchmarks/dataarray_missing.py @@ -2,12 +2,7 @@ import xarray as xr -from . import randn, requires_dask - -try: - import dask # noqa: F401 -except ImportError: - pass +from . import parameterized, randn, requires_dask def make_bench_data(shape, frac_nan, chunks): @@ -21,54 +16,65 @@ def make_bench_data(shape, frac_nan, chunks): return da -def time_interpolate_na(shape, chunks, method, limit): - if chunks is not None: - requires_dask() - da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.interpolate_na(dim="time", method="linear", limit=limit) - - if chunks is not None: - actual = actual.compute() - - -time_interpolate_na.param_names = ["shape", "chunks", "method", "limit"] -time_interpolate_na.params = ( - [(3650, 200, 400), (100, 25, 25)], - [None, {"x": 25, "y": 25}], - ["linear", "spline", "quadratic", "cubic"], - [None, 3], -) - - -def time_ffill(shape, chunks, limit): - - da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.ffill(dim="time", limit=limit) - - if chunks is not None: - actual = actual.compute() - - -time_ffill.param_names = ["shape", "chunks", "limit"] -time_ffill.params = ( - [(3650, 200, 400), (100, 25, 25)], - [None, {"x": 25, "y": 25}], - [None, 3], -) - - -def time_bfill(shape, chunks, limit): - - da = make_bench_data(shape, 0.1, chunks=chunks) - actual = da.bfill(dim="time", limit=limit) - - if chunks is not None: - actual = actual.compute() - - -time_bfill.param_names = ["shape", "chunks", "limit"] -time_bfill.params = ( - [(3650, 200, 400), (100, 25, 25)], - [None, {"x": 25, "y": 25}], - [None, 3], -) +def requires_bottleneck(): + try: + import bottleneck # noqa: F401 + except ImportError: + raise NotImplementedError() + + +class DataArrayMissingInterpolateNA: + def setup(self, shape, chunks, limit): + if chunks is not None: + requires_dask() + self.da = make_bench_data(shape, 0.1, chunks) + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_interpolate_na(self, shape, chunks, limit): + actual = self.da.interpolate_na(dim="time", method="linear", limit=limit) + + if chunks is not None: + actual = actual.compute() + + +class DataArrayMissingBottleneck: + def setup(self, shape, chunks, limit): + requires_bottleneck() + if chunks is not None: + requires_dask() + self.da = make_bench_data(shape, 0.1, chunks) + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_ffill(self, shape, chunks, limit): + actual = self.da.ffill(dim="time", limit=limit) + + if chunks is not None: + actual = actual.compute() + + @parameterized( + ["shape", "chunks", "limit"], + ( + [(365, 75, 75)], + [None, {"x": 25, "y": 25}], + [None, 3], + ), + ) + def time_bfill(self, shape, chunks, limit): + actual = self.da.ffill(dim="time", limit=limit) + + if chunks is not None: + actual = actual.compute() diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index e99911d752c..6c2e15c54e9 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -5,7 +5,7 @@ import xarray as xr -from . import randint, randn, requires_dask +from . import _skip_slow, randint, randn, requires_dask try: import dask @@ -28,6 +28,9 @@ class IOSingleNetCDF: number = 5 def make_ds(self): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() # single Dataset self.ds = xr.Dataset() @@ -227,6 +230,9 @@ class IOMultipleNetCDF: number = 5 def make_ds(self, nfiles=10): + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() # multiple Dataset self.ds = xr.Dataset() @@ -429,6 +435,10 @@ def time_open_dataset_scipy_with_time_chunks(self): def create_delayed_write(): import dask.array as da + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + vals = da.random.random(300, chunks=(1,)) ds = xr.Dataset({"vals": (["a"], vals)}) return ds.to_netcdf("file.nc", engine="netcdf4", compute=False) @@ -453,6 +463,11 @@ def setup(self): import distributed except ImportError: raise NotImplementedError() + + # TODO: Lazily skipped in CI as it is very demanding and slow. + # Improve times and remove errors. + _skip_slow() + self.client = distributed.Client() self.write = create_delayed_write() diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py new file mode 100644 index 00000000000..94652e3b82a --- /dev/null +++ b/asv_bench/benchmarks/import_xarray.py @@ -0,0 +1,9 @@ +class ImportXarray: + def setup(self, *args, **kwargs): + def import_xr(): + import xarray # noqa: F401 + + self._import_xr = import_xr + + def time_import_xarray(self): + self._import_xr() diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index 859c41c913d..15212ec0c61 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -5,11 +5,11 @@ import xarray as xr -from . import randint, randn, requires_dask +from . import parameterized, randint, randn, requires_dask -nx = 3000 -ny = 2000 -nt = 1000 +nx = 2000 +ny = 1000 +nt = 500 basic_indexes = { "1slice": {"x": slice(0, 3)}, @@ -21,7 +21,7 @@ "1slice": xr.DataArray(randn((3, ny), frac_nan=0.1), dims=["x", "y"]), "1slice-1scalar": xr.DataArray(randn(int(ny / 3) + 1, frac_nan=0.1), dims=["y"]), "2slicess-1scalar": xr.DataArray( - randn(int((nx - 6) / 3), frac_nan=0.1), dims=["x"] + randn(np.empty(nx)[slice(3, -3, 3)].size, frac_nan=0.1), dims=["x"] ), } @@ -51,7 +51,7 @@ } vectorized_assignment_values = { - "1-1d": xr.DataArray(randn((400, 2000)), dims=["a", "y"], coords={"a": randn(400)}), + "1-1d": xr.DataArray(randn((400, ny)), dims=["a", "y"], coords={"a": randn(400)}), "2-1d": xr.DataArray(randn(400), dims=["a"], coords={"a": randn(400)}), "3-2d": xr.DataArray( randn((4, 100)), dims=["a", "b"], coords={"a": randn(4), "b": randn(100)} @@ -77,50 +77,38 @@ def setup(self, key): class Indexing(Base): + @parameterized(["key"], [list(basic_indexes.keys())]) def time_indexing_basic(self, key): self.ds.isel(**basic_indexes[key]).load() - time_indexing_basic.param_names = ["key"] - time_indexing_basic.params = [list(basic_indexes.keys())] - + @parameterized(["key"], [list(outer_indexes.keys())]) def time_indexing_outer(self, key): self.ds.isel(**outer_indexes[key]).load() - time_indexing_outer.param_names = ["key"] - time_indexing_outer.params = [list(outer_indexes.keys())] - + @parameterized(["key"], [list(vectorized_indexes.keys())]) def time_indexing_vectorized(self, key): self.ds.isel(**vectorized_indexes[key]).load() - time_indexing_vectorized.param_names = ["key"] - time_indexing_vectorized.params = [list(vectorized_indexes.keys())] - class Assignment(Base): + @parameterized(["key"], [list(basic_indexes.keys())]) def time_assignment_basic(self, key): ind = basic_indexes[key] val = basic_assignment_values[key] self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_basic.param_names = ["key"] - time_assignment_basic.params = [list(basic_indexes.keys())] - + @parameterized(["key"], [list(outer_indexes.keys())]) def time_assignment_outer(self, key): ind = outer_indexes[key] val = outer_assignment_values[key] self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_outer.param_names = ["key"] - time_assignment_outer.params = [list(outer_indexes.keys())] - + @parameterized(["key"], [list(vectorized_indexes.keys())]) def time_assignment_vectorized(self, key): ind = vectorized_indexes[key] val = vectorized_assignment_values[key] self.ds["var1"][ind.get("x", slice(None)), ind.get("y", slice(None))] = val - time_assignment_vectorized.param_names = ["key"] - time_assignment_vectorized.params = [list(vectorized_indexes.keys())] - class IndexingDask(Indexing): def setup(self, key): diff --git a/asv_bench/benchmarks/interp.py b/asv_bench/benchmarks/interp.py index cded900ebbc..4b6691bcc0a 100644 --- a/asv_bench/benchmarks/interp.py +++ b/asv_bench/benchmarks/interp.py @@ -5,21 +5,17 @@ from . import parameterized, randn, requires_dask -nx = 3000 -long_nx = 30000000 -ny = 2000 -nt = 1000 -window = 20 +nx = 1500 +ny = 1000 +nt = 500 randn_xy = randn((nx, ny), frac_nan=0.1) randn_xt = randn((nx, nt)) randn_t = randn((nt,)) -randn_long = randn((long_nx,), frac_nan=0.1) - new_x_short = np.linspace(0.3 * nx, 0.7 * nx, 100) -new_x_long = np.linspace(0.3 * nx, 0.7 * nx, 1000) -new_y_long = np.linspace(0.1, 0.9, 1000) +new_x_long = np.linspace(0.3 * nx, 0.7 * nx, 500) +new_y_long = np.linspace(0.1, 0.9, 500) class Interpolation: diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py index 42ef18ac0c2..8aaa515d417 100644 --- a/asv_bench/benchmarks/pandas.py +++ b/asv_bench/benchmarks/pandas.py @@ -1,6 +1,8 @@ import numpy as np import pandas as pd +import xarray as xr + from . import parameterized @@ -20,5 +22,5 @@ def setup(self, dtype, subset): self.series = series @parameterized(["dtype", "subset"], ([int, float], [True, False])) - def time_to_xarray(self, dtype, subset): - self.series.to_xarray() + def time_from_series(self, dtype, subset): + xr.DataArray.from_series(self.series) diff --git a/asv_bench/benchmarks/reindexing.py b/asv_bench/benchmarks/reindexing.py index fe4fa500c09..9d0767fc3b3 100644 --- a/asv_bench/benchmarks/reindexing.py +++ b/asv_bench/benchmarks/reindexing.py @@ -4,38 +4,42 @@ from . import requires_dask +ntime = 500 +nx = 50 +ny = 50 + class Reindex: def setup(self): - data = np.random.RandomState(0).randn(1000, 100, 100) + data = np.random.RandomState(0).randn(ntime, nx, ny) self.ds = xr.Dataset( {"temperature": (("time", "x", "y"), data)}, - coords={"time": np.arange(1000), "x": np.arange(100), "y": np.arange(100)}, + coords={"time": np.arange(ntime), "x": np.arange(nx), "y": np.arange(ny)}, ) def time_1d_coarse(self): - self.ds.reindex(time=np.arange(0, 1000, 5)).load() + self.ds.reindex(time=np.arange(0, ntime, 5)).load() def time_1d_fine_all_found(self): - self.ds.reindex(time=np.arange(0, 1000, 0.5), method="nearest").load() + self.ds.reindex(time=np.arange(0, ntime, 0.5), method="nearest").load() def time_1d_fine_some_missing(self): self.ds.reindex( - time=np.arange(0, 1000, 0.5), method="nearest", tolerance=0.1 + time=np.arange(0, ntime, 0.5), method="nearest", tolerance=0.1 ).load() def time_2d_coarse(self): - self.ds.reindex(x=np.arange(0, 100, 2), y=np.arange(0, 100, 2)).load() + self.ds.reindex(x=np.arange(0, nx, 2), y=np.arange(0, ny, 2)).load() def time_2d_fine_all_found(self): self.ds.reindex( - x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5), method="nearest" + x=np.arange(0, nx, 0.5), y=np.arange(0, ny, 0.5), method="nearest" ).load() def time_2d_fine_some_missing(self): self.ds.reindex( - x=np.arange(0, 100, 0.5), - y=np.arange(0, 100, 0.5), + x=np.arange(0, nx, 0.5), + y=np.arange(0, ny, 0.5), method="nearest", tolerance=0.1, ).load() diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py index 405f6cd0530..4bf2ace352d 100644 --- a/asv_bench/benchmarks/repr.py +++ b/asv_bench/benchmarks/repr.py @@ -28,9 +28,9 @@ def time_repr_html(self): class ReprMultiIndex: def setup(self): index = pd.MultiIndex.from_product( - [range(10000), range(10000)], names=("level_0", "level_1") + [range(1000), range(1000)], names=("level_0", "level_1") ) - series = pd.Series(range(100000000), index=index) + series = pd.Series(range(1000 * 1000), index=index) self.da = xr.DataArray(series) def time_repr(self): diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 93c3c6aed4e..f0e18bf2153 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -5,10 +5,10 @@ from . import parameterized, randn, requires_dask -nx = 3000 -long_nx = 30000000 -ny = 2000 -nt = 1000 +nx = 300 +long_nx = 30000 +ny = 200 +nt = 100 window = 20 randn_xy = randn((nx, ny), frac_nan=0.1) @@ -44,21 +44,21 @@ def time_rolling(self, func, center): def time_rolling_long(self, func, pandas): if pandas: se = self.da_long.to_series() - getattr(se.rolling(window=window), func)() + getattr(se.rolling(window=window, min_periods=window), func)() else: - getattr(self.da_long.rolling(x=window), func)().load() + getattr(self.da_long.rolling(x=window, min_periods=window), func)().load() - @parameterized(["window_", "min_periods"], ([20, 40], [5, None])) + @parameterized(["window_", "min_periods"], ([20, 40], [5, 5])) def time_rolling_np(self, window_, min_periods): self.ds.rolling(x=window_, center=False, min_periods=min_periods).reduce( - getattr(np, "nanmean") + getattr(np, "nansum") ).load() - @parameterized(["center", "stride"], ([True, False], [1, 200])) + @parameterized(["center", "stride"], ([True, False], [1, 1])) def time_rolling_construct(self, center, stride): self.ds.rolling(x=window, center=center).construct( "window_dim", stride=stride - ).mean(dim="window_dim").load() + ).sum(dim="window_dim").load() class RollingDask(Rolling): diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py index 8d0c3932870..2c5b7ca7821 100644 --- a/asv_bench/benchmarks/unstacking.py +++ b/asv_bench/benchmarks/unstacking.py @@ -7,7 +7,7 @@ class Unstacking: def setup(self): - data = np.random.RandomState(0).randn(500, 1000) + data = np.random.RandomState(0).randn(250, 500) self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...]) self.da_missing = self.da_full[:-1] self.df_missing = self.da_missing.to_pandas() @@ -26,4 +26,4 @@ class UnstackingDask(Unstacking): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.da_full = self.da_full.chunk({"flat_dim": 50}) + self.da_full = self.da_full.chunk({"flat_dim": 25}) diff --git a/doc/conf.py b/doc/conf.py index 0a6d1504161..77387dfd965 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -260,8 +260,8 @@ ogp_image = "https://xarray.pydata.org/en/stable/_static/dataset-diagram-logo.png" ogp_custom_meta_tags = [ '', - '', + '', + '', ] # Redirects for pages that were moved to new locations diff --git a/doc/examples/ERA5-GRIB-example.ipynb b/doc/examples/ERA5-GRIB-example.ipynb index caa702ebe53..5d09f1a7431 100644 --- a/doc/examples/ERA5-GRIB-example.ipynb +++ b/doc/examples/ERA5-GRIB-example.ipynb @@ -37,7 +37,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds = xr.tutorial.load_dataset('era5-2mt-2019-03-uk.grib', engine='cfgrib')" + "ds = xr.tutorial.load_dataset(\"era5-2mt-2019-03-uk.grib\", engine=\"cfgrib\")" ] }, { @@ -72,11 +72,14 @@ "source": [ "import cartopy.crs as ccrs\n", "import cartopy\n", - "fig = plt.figure(figsize=(10,10))\n", + "\n", + "fig = plt.figure(figsize=(10, 10))\n", "ax = plt.axes(projection=ccrs.Robinson())\n", - "ax.coastlines(resolution='10m')\n", - "plot = ds.t2m[0].plot(cmap=plt.cm.coolwarm, transform=ccrs.PlateCarree(), cbar_kwargs={'shrink':0.6})\n", - "plt.title('ERA5 - 2m temperature British Isles March 2019')" + "ax.coastlines(resolution=\"10m\")\n", + "plot = ds.t2m[0].plot(\n", + " cmap=plt.cm.coolwarm, transform=ccrs.PlateCarree(), cbar_kwargs={\"shrink\": 0.6}\n", + ")\n", + "plt.title(\"ERA5 - 2m temperature British Isles March 2019\")" ] }, { @@ -92,8 +95,8 @@ "metadata": {}, "outputs": [], "source": [ - "ds.t2m.sel(longitude=0,latitude=51.5).plot()\n", - "plt.title('ERA5 - London 2m temperature March 2019')" + "ds.t2m.sel(longitude=0, latitude=51.5).plot()\n", + "plt.title(\"ERA5 - London 2m temperature March 2019\")" ] } ], diff --git a/doc/examples/ROMS_ocean_model.ipynb b/doc/examples/ROMS_ocean_model.ipynb index b699c4d5ba9..82d7a8d58af 100644 --- a/doc/examples/ROMS_ocean_model.ipynb +++ b/doc/examples/ROMS_ocean_model.ipynb @@ -26,6 +26,7 @@ "import cartopy.crs as ccrs\n", "import cartopy.feature as cfeature\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "import xarray as xr" @@ -73,9 +74,9 @@ "outputs": [], "source": [ "# load in the file\n", - "ds = xr.tutorial.open_dataset('ROMS_example.nc', chunks={'ocean_time': 1})\n", + "ds = xr.tutorial.open_dataset(\"ROMS_example.nc\", chunks={\"ocean_time\": 1})\n", "\n", - "# This is a way to turn on chunking and lazy evaluation. Opening with mfdataset, or \n", + "# This is a way to turn on chunking and lazy evaluation. Opening with mfdataset, or\n", "# setting the chunking in the open_dataset would also achive this.\n", "ds" ] @@ -105,12 +106,12 @@ "source": [ "if ds.Vtransform == 1:\n", " Zo_rho = ds.hc * (ds.s_rho - ds.Cs_r) + ds.Cs_r * ds.h\n", - " z_rho = Zo_rho + ds.zeta * (1 + Zo_rho/ds.h)\n", + " z_rho = Zo_rho + ds.zeta * (1 + Zo_rho / ds.h)\n", "elif ds.Vtransform == 2:\n", " Zo_rho = (ds.hc * ds.s_rho + ds.Cs_r * ds.h) / (ds.hc + ds.h)\n", " z_rho = ds.zeta + (ds.zeta + ds.h) * Zo_rho\n", "\n", - "ds.coords['z_rho'] = z_rho.transpose() # needing transpose seems to be an xarray bug\n", + "ds.coords[\"z_rho\"] = z_rho.transpose() # needing transpose seems to be an xarray bug\n", "ds.salt" ] }, @@ -148,7 +149,7 @@ "outputs": [], "source": [ "section = ds.salt.isel(xi_rho=50, eta_rho=slice(0, 167), ocean_time=0)\n", - "section.plot(x='lon_rho', y='z_rho', figsize=(15, 6), clim=(25, 35))\n", + "section.plot(x=\"lon_rho\", y=\"z_rho\", figsize=(15, 6), clim=(25, 35))\n", "plt.ylim([-100, 1]);" ] }, @@ -167,7 +168,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds.salt.isel(s_rho=-1, ocean_time=0).plot(x='lon_rho', y='lat_rho')" + "ds.salt.isel(s_rho=-1, ocean_time=0).plot(x=\"lon_rho\", y=\"lat_rho\")" ] }, { @@ -186,11 +187,13 @@ "proj = ccrs.LambertConformal(central_longitude=-92, central_latitude=29)\n", "fig = plt.figure(figsize=(15, 5))\n", "ax = plt.axes(projection=proj)\n", - "ds.salt.isel(s_rho=-1, ocean_time=0).plot(x='lon_rho', y='lat_rho', \n", - " transform=ccrs.PlateCarree())\n", + "ds.salt.isel(s_rho=-1, ocean_time=0).plot(\n", + " x=\"lon_rho\", y=\"lat_rho\", transform=ccrs.PlateCarree()\n", + ")\n", "\n", - "coast_10m = cfeature.NaturalEarthFeature('physical', 'land', '10m',\n", - " edgecolor='k', facecolor='0.8')\n", + "coast_10m = cfeature.NaturalEarthFeature(\n", + " \"physical\", \"land\", \"10m\", edgecolor=\"k\", facecolor=\"0.8\"\n", + ")\n", "ax.add_feature(coast_10m)" ] }, diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb index e9a48d70173..d1d6a52919c 100644 --- a/doc/examples/apply_ufunc_vectorize_1d.ipynb +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -674,7 +674,9 @@ " exclude_dims=set((dim,)), # dimensions allowed to change size. Must be a set!\n", " # vectorize=True, # not needed since numba takes care of vectorizing\n", " dask=\"parallelized\",\n", - " output_dtypes=[data.dtype], # one per output; could also be float or np.dtype(\"float64\")\n", + " output_dtypes=[\n", + " data.dtype\n", + " ], # one per output; could also be float or np.dtype(\"float64\")\n", " ).rename({\"__newdim__\": dim})\n", " interped[dim] = newdim # need to add this manually\n", "\n", diff --git a/doc/examples/monthly-means.ipynb b/doc/examples/monthly-means.ipynb index 3490fc9a4fe..fd31e21a872 100644 --- a/doc/examples/monthly-means.ipynb +++ b/doc/examples/monthly-means.ipynb @@ -29,7 +29,7 @@ "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", - "import matplotlib.pyplot as plt " + "import matplotlib.pyplot as plt" ] }, { @@ -50,7 +50,7 @@ }, "outputs": [], "source": [ - "ds = xr.tutorial.open_dataset('rasm').load()\n", + "ds = xr.tutorial.open_dataset(\"rasm\").load()\n", "ds" ] }, @@ -88,13 +88,15 @@ "outputs": [], "source": [ "# Calculate the weights by grouping by 'time.season'.\n", - "weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()\n", + "weights = (\n", + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", + ")\n", "\n", "# Test that the sum of the weights for each season is 1.0\n", - "np.testing.assert_allclose(weights.groupby('time.season').sum().values, np.ones(4))\n", + "np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", "\n", "# Calculate the weighted average\n", - "ds_weighted = (ds * weights).groupby('time.season').sum(dim='time')" + "ds_weighted = (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" ] }, { @@ -123,7 +125,7 @@ "outputs": [], "source": [ "# only used for comparisons\n", - "ds_unweighted = ds.groupby('time.season').mean('time')\n", + "ds_unweighted = ds.groupby(\"time.season\").mean(\"time\")\n", "ds_diff = ds_weighted - ds_unweighted" ] }, @@ -139,39 +141,54 @@ "outputs": [], "source": [ "# Quick plot to show the results\n", - "notnull = pd.notnull(ds_unweighted['Tair'][0])\n", - "\n", - "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14,12))\n", - "for i, season in enumerate(('DJF', 'MAM', 'JJA', 'SON')):\n", - " ds_weighted['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", - " ax=axes[i, 0], vmin=-30, vmax=30, cmap='Spectral_r', \n", - " add_colorbar=True, extend='both')\n", - " \n", - " ds_unweighted['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", - " ax=axes[i, 1], vmin=-30, vmax=30, cmap='Spectral_r', \n", - " add_colorbar=True, extend='both')\n", - "\n", - " ds_diff['Tair'].sel(season=season).where(notnull).plot.pcolormesh(\n", - " ax=axes[i, 2], vmin=-0.1, vmax=.1, cmap='RdBu_r',\n", - " add_colorbar=True, extend='both')\n", + "notnull = pd.notnull(ds_unweighted[\"Tair\"][0])\n", + "\n", + "fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(14, 12))\n", + "for i, season in enumerate((\"DJF\", \"MAM\", \"JJA\", \"SON\")):\n", + " ds_weighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 0],\n", + " vmin=-30,\n", + " vmax=30,\n", + " cmap=\"Spectral_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", + "\n", + " ds_unweighted[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 1],\n", + " vmin=-30,\n", + " vmax=30,\n", + " cmap=\"Spectral_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", + "\n", + " ds_diff[\"Tair\"].sel(season=season).where(notnull).plot.pcolormesh(\n", + " ax=axes[i, 2],\n", + " vmin=-0.1,\n", + " vmax=0.1,\n", + " cmap=\"RdBu_r\",\n", + " add_colorbar=True,\n", + " extend=\"both\",\n", + " )\n", "\n", " axes[i, 0].set_ylabel(season)\n", - " axes[i, 1].set_ylabel('')\n", - " axes[i, 2].set_ylabel('')\n", + " axes[i, 1].set_ylabel(\"\")\n", + " axes[i, 2].set_ylabel(\"\")\n", "\n", "for ax in axes.flat:\n", " ax.axes.get_xaxis().set_ticklabels([])\n", " ax.axes.get_yaxis().set_ticklabels([])\n", - " ax.axes.axis('tight')\n", - " ax.set_xlabel('')\n", - " \n", - "axes[0, 0].set_title('Weighted by DPM')\n", - "axes[0, 1].set_title('Equal Weighting')\n", - "axes[0, 2].set_title('Difference')\n", - " \n", + " ax.axes.axis(\"tight\")\n", + " ax.set_xlabel(\"\")\n", + "\n", + "axes[0, 0].set_title(\"Weighted by DPM\")\n", + "axes[0, 1].set_title(\"Equal Weighting\")\n", + "axes[0, 2].set_title(\"Difference\")\n", + "\n", "plt.tight_layout()\n", "\n", - "fig.suptitle('Seasonal Surface Air Temperature', fontsize=16, y=1.02)" + "fig.suptitle(\"Seasonal Surface Air Temperature\", fontsize=16, y=1.02)" ] }, { @@ -186,18 +203,20 @@ "outputs": [], "source": [ "# Wrap it into a simple function\n", - "def season_mean(ds, calendar='standard'):\n", + "def season_mean(ds, calendar=\"standard\"):\n", " # Make a DataArray with the number of days in each month, size = len(time)\n", " month_length = ds.time.dt.days_in_month\n", "\n", " # Calculate the weights by grouping by 'time.season'\n", - " weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()\n", + " weights = (\n", + " month_length.groupby(\"time.season\") / month_length.groupby(\"time.season\").sum()\n", + " )\n", "\n", " # Test that the sum of the weights for each season is 1.0\n", - " np.testing.assert_allclose(weights.groupby('time.season').sum().values, np.ones(4))\n", + " np.testing.assert_allclose(weights.groupby(\"time.season\").sum().values, np.ones(4))\n", "\n", " # Calculate the weighted average\n", - " return (ds * weights).groupby('time.season').sum(dim='time')" + " return (ds * weights).groupby(\"time.season\").sum(dim=\"time\")" ] } ], diff --git a/doc/examples/multidimensional-coords.ipynb b/doc/examples/multidimensional-coords.ipynb index 3327192e324..f095d1137de 100644 --- a/doc/examples/multidimensional-coords.ipynb +++ b/doc/examples/multidimensional-coords.ipynb @@ -48,7 +48,7 @@ }, "outputs": [], "source": [ - "ds = xr.tutorial.open_dataset('rasm').load()\n", + "ds = xr.tutorial.open_dataset(\"rasm\").load()\n", "ds" ] }, @@ -94,7 +94,7 @@ }, "outputs": [], "source": [ - "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(14,4))\n", + "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(14, 4))\n", "ds.xc.plot(ax=ax1)\n", "ds.yc.plot(ax=ax2)" ] @@ -140,12 +140,14 @@ }, "outputs": [], "source": [ - "plt.figure(figsize=(14,6))\n", + "plt.figure(figsize=(14, 6))\n", "ax = plt.axes(projection=ccrs.PlateCarree())\n", "ax.set_global()\n", - "ds.Tair[0].plot.pcolormesh(ax=ax, transform=ccrs.PlateCarree(), x='xc', y='yc', add_colorbar=False)\n", + "ds.Tair[0].plot.pcolormesh(\n", + " ax=ax, transform=ccrs.PlateCarree(), x=\"xc\", y=\"yc\", add_colorbar=False\n", + ")\n", "ax.coastlines()\n", - "ax.set_ylim([0,90]);" + "ax.set_ylim([0, 90]);" ] }, { @@ -169,11 +171,13 @@ "outputs": [], "source": [ "# define two-degree wide latitude bins\n", - "lat_bins = np.arange(0,91,2)\n", + "lat_bins = np.arange(0, 91, 2)\n", "# define a label for each bin corresponding to the central latitude\n", - "lat_center = np.arange(1,90,2)\n", + "lat_center = np.arange(1, 90, 2)\n", "# group according to those bins and take the mean\n", - "Tair_lat_mean = ds.Tair.groupby_bins('xc', lat_bins, labels=lat_center).mean(dim=xr.ALL_DIMS)\n", + "Tair_lat_mean = ds.Tair.groupby_bins(\"xc\", lat_bins, labels=lat_center).mean(\n", + " dim=xr.ALL_DIMS\n", + ")\n", "# plot the result\n", "Tair_lat_mean.plot()" ] diff --git a/doc/examples/visualization_gallery.ipynb b/doc/examples/visualization_gallery.ipynb index 831f162d998..e6fa564db0d 100644 --- a/doc/examples/visualization_gallery.ipynb +++ b/doc/examples/visualization_gallery.ipynb @@ -18,6 +18,7 @@ "import cartopy.crs as ccrs\n", "import matplotlib.pyplot as plt\n", "import xarray as xr\n", + "\n", "%matplotlib inline" ] }, @@ -34,7 +35,7 @@ "metadata": {}, "outputs": [], "source": [ - "ds = xr.tutorial.load_dataset('air_temperature')" + "ds = xr.tutorial.load_dataset(\"air_temperature\")" ] }, { @@ -62,10 +63,13 @@ "# This is the map projection we want to plot *onto*\n", "map_proj = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)\n", "\n", - "p = air.plot(transform=ccrs.PlateCarree(), # the data's projection\n", - " col='time', col_wrap=1, # multiplot settings\n", - " aspect=ds.dims['lon'] / ds.dims['lat'], # for a sensible figsize\n", - " subplot_kws={'projection': map_proj}) # the plot's projection\n", + "p = air.plot(\n", + " transform=ccrs.PlateCarree(), # the data's projection\n", + " col=\"time\",\n", + " col_wrap=1, # multiplot settings\n", + " aspect=ds.dims[\"lon\"] / ds.dims[\"lat\"], # for a sensible figsize\n", + " subplot_kws={\"projection\": map_proj},\n", + ") # the plot's projection\n", "\n", "# We have to set the map's options on all axes\n", "for ax in p.axes.flat:\n", @@ -93,25 +97,25 @@ "f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 6))\n", "\n", "# The first plot (in kelvins) chooses \"viridis\" and uses the data's min/max\n", - "air.plot(ax=ax1, cbar_kwargs={'label': 'K'})\n", - "ax1.set_title('Kelvins: default')\n", - "ax2.set_xlabel('')\n", + "air.plot(ax=ax1, cbar_kwargs={\"label\": \"K\"})\n", + "ax1.set_title(\"Kelvins: default\")\n", + "ax2.set_xlabel(\"\")\n", "\n", "# The second plot (in celsius) now chooses \"BuRd\" and centers min/max around 0\n", "airc = air - 273.15\n", - "airc.plot(ax=ax2, cbar_kwargs={'label': '°C'})\n", - "ax2.set_title('Celsius: default')\n", - "ax2.set_xlabel('')\n", - "ax2.set_ylabel('')\n", + "airc.plot(ax=ax2, cbar_kwargs={\"label\": \"°C\"})\n", + "ax2.set_title(\"Celsius: default\")\n", + "ax2.set_xlabel(\"\")\n", + "ax2.set_ylabel(\"\")\n", "\n", "# The center doesn't have to be 0\n", - "air.plot(ax=ax3, center=273.15, cbar_kwargs={'label': 'K'})\n", - "ax3.set_title('Kelvins: center=273.15')\n", + "air.plot(ax=ax3, center=273.15, cbar_kwargs={\"label\": \"K\"})\n", + "ax3.set_title(\"Kelvins: center=273.15\")\n", "\n", "# Or it can be ignored\n", - "airc.plot(ax=ax4, center=False, cbar_kwargs={'label': '°C'})\n", - "ax4.set_title('Celsius: center=False')\n", - "ax4.set_ylabel('')\n", + "airc.plot(ax=ax4, center=False, cbar_kwargs={\"label\": \"°C\"})\n", + "ax4.set_title(\"Celsius: center=False\")\n", + "ax4.set_ylabel(\"\")\n", "\n", "# Make it nice\n", "plt.tight_layout()" @@ -143,9 +147,10 @@ "\n", "# Plot data\n", "air2d.plot(ax=ax1, levels=levels)\n", - "air2d.plot(ax=ax2, levels=levels, cbar_kwargs={'ticks': levels})\n", - "air2d.plot(ax=ax3, levels=levels, cbar_kwargs={'ticks': levels,\n", - " 'spacing': 'proportional'})\n", + "air2d.plot(ax=ax2, levels=levels, cbar_kwargs={\"ticks\": levels})\n", + "air2d.plot(\n", + " ax=ax3, levels=levels, cbar_kwargs={\"ticks\": levels, \"spacing\": \"proportional\"}\n", + ")\n", "\n", "# Show plots\n", "plt.tight_layout()" @@ -178,12 +183,12 @@ "isel_lats = [10, 15, 20]\n", "\n", "# Temperature vs longitude plot - illustrates the \"hue\" kwarg\n", - "air.isel(time=0, lat=isel_lats).plot.line(ax=ax1, hue='lat')\n", - "ax1.set_ylabel('°C')\n", + "air.isel(time=0, lat=isel_lats).plot.line(ax=ax1, hue=\"lat\")\n", + "ax1.set_ylabel(\"°C\")\n", "\n", "# Temperature vs time plot - illustrates the \"x\" and \"add_legend\" kwargs\n", - "air.isel(lon=30, lat=isel_lats).plot.line(ax=ax2, x='time', add_legend=False)\n", - "ax2.set_ylabel('')\n", + "air.isel(lon=30, lat=isel_lats).plot.line(ax=ax2, x=\"time\", add_legend=False)\n", + "ax2.set_ylabel(\"\")\n", "\n", "# Show\n", "plt.tight_layout()" @@ -216,12 +221,12 @@ "\n", "# The data is in UTM projection. We have to set it manually until\n", "# https://github.com/SciTools/cartopy/issues/813 is implemented\n", - "crs = ccrs.UTM('18')\n", + "crs = ccrs.UTM(\"18\")\n", "\n", "# Plot on a map\n", "ax = plt.subplot(projection=crs)\n", - "da.plot.imshow(ax=ax, rgb='band', transform=crs)\n", - "ax.coastlines('10m', color='r')" + "da.plot.imshow(ax=ax, rgb=\"band\", transform=crs)\n", + "ax.coastlines(\"10m\", color=\"r\")" ] }, { @@ -250,20 +255,27 @@ "\n", "da = xr.tutorial.open_rasterio(\"RGB.byte\")\n", "\n", - "x, y = np.meshgrid(da['x'], da['y'])\n", + "x, y = np.meshgrid(da[\"x\"], da[\"y\"])\n", "transformer = Transformer.from_crs(da.crs, \"EPSG:4326\", always_xy=True)\n", "lon, lat = transformer.transform(x, y)\n", - "da.coords['lon'] = (('y', 'x'), lon)\n", - "da.coords['lat'] = (('y', 'x'), lat)\n", + "da.coords[\"lon\"] = ((\"y\", \"x\"), lon)\n", + "da.coords[\"lat\"] = ((\"y\", \"x\"), lat)\n", "\n", "# Compute a greyscale out of the rgb image\n", - "greyscale = da.mean(dim='band')\n", + "greyscale = da.mean(dim=\"band\")\n", "\n", "# Plot on a map\n", "ax = plt.subplot(projection=ccrs.PlateCarree())\n", - "greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(),\n", - " cmap='Greys_r', shading=\"auto\",add_colorbar=False)\n", - "ax.coastlines('10m', color='r')" + "greyscale.plot(\n", + " ax=ax,\n", + " x=\"lon\",\n", + " y=\"lat\",\n", + " transform=ccrs.PlateCarree(),\n", + " cmap=\"Greys_r\",\n", + " shading=\"auto\",\n", + " add_colorbar=False,\n", + ")\n", + "ax.coastlines(\"10m\", color=\"r\")" ] } ], diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index dc5c4915f3e..6908c6ff535 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -742,6 +742,12 @@ GeoTIFFs and other gridded raster datasets can be opened using `rasterio`_, if rasterio is installed. Here is an example of how to use :py:func:`open_rasterio` to read one of rasterio's `test files`_: +.. deprecated:: 0.19.1 + + Deprecated in favor of rioxarray. + For information about transitioning, see: + https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html + .. ipython:: :verbatim: @@ -769,12 +775,6 @@ coordinates defined in the file's projection provided by the ``crs`` attribute. See :ref:`/examples/visualization_gallery.ipynb#Parsing-rasterio-geocoordinates` for an example of how to convert these to longitudes and latitudes. -.. warning:: - - This feature has been added in xarray v0.9.6 and should still be - considered experimental. Please report any bugs you may find - on xarray's github repository. - Additionally, you can use `rioxarray`_ for reading in GeoTiff, netCDF or other GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIFF. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f5130e6ce6d..df2afdd9c4d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -19,6 +19,7 @@ What's New v0.19.1 (unreleased) --------------------- +.. TODO(by keewis): update deprecations if we decide to skip 0.19.1 New Features ~~~~~~~~~~~~ @@ -26,14 +27,15 @@ New Features By `Pushkar Kopparla `_. - Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`). By `Tomas Chor `_. -- Add a option to disable the use of ``bottleneck`` (:pull:`5560`) +- Add an option to disable the use of ``bottleneck`` (:pull:`5560`) By `Justus Magin `_. - Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`). By `Pushkar Kopparla `_. - Added ``storage_options`` argument to :py:meth:`to_zarr` (:issue:`5601`). By `Ray Bell `_, `Zachary Blackwood `_ and `Nathan Lis `_. - +- Histogram plots are set with a title displaying the scalar coords if any, similarly to the other plots (:issue:`5791`, :pull:`5792`). + By `Maxime Liquet `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -55,16 +57,40 @@ Breaking changes Deprecations ~~~~~~~~~~~~ +- Deprecate :py:func:`open_rasterio` (:issue:`4697`, :pull:`5808`). + By `Alan Snow `_. +- Set the default argument for `roll_coords` to `False` for :py:meth:`DataArray.roll` + and :py:meth:`Dataset.roll`. (:pull:`5653`) + By `Tom Nicholas `_. +- :py:meth:`xarray.open_mfdataset` will now error instead of warn when a value for ``concat_dim`` is + passed alongside ``combine='by_coords'``. + By `Tom Nicholas `_. Bug fixes ~~~~~~~~~ +- Fix ZeroDivisionError from saving dask array with empty dimension (:issue: `5741`). + By `Joseph K Aicher `_. +- Fixed performance bug where ``cftime`` import attempted within various core operations if ``cftime`` not + installed (:pull:`5640`). + By `Luke Sewell `_ +- When a custom engine was used in :py:func:`~xarray.open_dataset` the engine + wasn't initialized properly, causing missing argument errors or inconsistent + method signatures. (:pull:`5684`) + By `Jimmy Westling `_. - Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`). By `Maxime Liquet `_. +- ``open_mfdataset()`` now accepts a single ``pathlib.Path`` object (:issue: `5881`). + By `Panos Mavrogiorgos `_. Documentation ~~~~~~~~~~~~~ +- Users are instructed to try ``use_cftime=True`` if a ``TypeError`` occurs when combining datasets and one of the types involved is a subclass of ``cftime.datetime`` (:pull:`5776`). + By `Zeb Nicholls `_. +- A clearer error is now raised if a user attempts to assign a Dataset to a single key of + another Dataset. (:pull:`5839`) + By `Tom Nicholas `_. Internal Changes ~~~~~~~~~~~~~~~~ @@ -82,6 +108,15 @@ Internal Changes By `Jimmy Westling `_. - Use isort's `float_to_top` config. (:pull:`5695`). By `Maximilian Roos `_. +- Remove use of the deprecated ``kind`` argument in + :py:meth:`pandas.Index.get_slice_bound` inside :py:class:`xarray.CFTimeIndex` + tests (:pull:`5723`). By `Spencer Clark `_. +- Refactor `xarray.core.duck_array_ops` to no longer special-case dispatching to + dask versions of functions when acting on dask arrays, instead relying numpy + and dask's adherence to NEP-18 to dispatch automatically. (:pull:`5571`) + By `Tom Nicholas `_. +- Add an ASV benchmark CI and improve performance of the benchmarks (:pull:`5796`) + By `Jimmy Westling `_. .. _whats-new.0.19.0: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2c9b25f860f..0cde8ab7315 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,5 +1,4 @@ import os -import warnings from glob import glob from io import BytesIO from numbers import Number @@ -860,13 +859,14 @@ def open_mfdataset( paths = [fs.get_mapper(path) for path in paths] elif is_remote_uri(paths): raise ValueError( - "cannot do wild-card matching for paths that are remote URLs: " - "{!r}. Instead, supply paths as an explicit list of strings.".format( - paths - ) + "cannot do wild-card matching for paths that are remote URLs " + f"unless engine='zarr' is specified. Got paths: {paths}. " + "Instead, supply paths as an explicit list of strings." ) else: paths = sorted(glob(_normalize_path(paths))) + elif isinstance(paths, os.PathLike): + paths = [os.fspath(paths)] else: paths = [str(p) if isinstance(p, Path) else p for p in paths] @@ -885,15 +885,11 @@ def open_mfdataset( list(combined_ids_paths.keys()), list(combined_ids_paths.values()), ) - - # TODO raise an error instead of a warning after v0.19 elif combine == "by_coords" and concat_dim is not None: - warnings.warn( + raise ValueError( "When combine='by_coords', passing a value for `concat_dim` has no " - "effect. This combination will raise an error in future. To manually " - "combine along a specific dimension you should instead specify " - "combine='nested' along with a value for `concat_dim`.", - DeprecationWarning, + "effect. To manually combine along a specific dimension you should " + "instead specify combine='nested' along with a value for `concat_dim`.", ) open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs) @@ -1327,6 +1323,11 @@ def to_zarr( See `Dataset.to_zarr` for full API docs. """ + # Load empty arrays to avoid bug saving zero length dimensions (Issue #5741) + for v in dataset.variables.values(): + if v.size == 0: + v.load() + # expand str and Path arguments store = _normalize_path(store) chunk_store = _normalize_path(chunk_store) diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index 08c1bec8325..57795865821 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -158,7 +158,7 @@ def get_backend(engine): ) backend = engines[engine] elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint): - backend = engine + backend = engine() else: raise TypeError( ( diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 1891fac8668..f34240e5e35 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -170,7 +170,13 @@ def open_rasterio( lock=None, **kwargs, ): - """Open a file with rasterio (experimental). + """Open a file with rasterio. + + .. deprecated:: 0.19.1 + + Deprecated in favor of rioxarray. + For information about transitioning, see: + https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html This should work with any file that rasterio can open (most often: geoTIFF). The x and y coordinates are generated automatically from the @@ -252,6 +258,13 @@ def open_rasterio( data : DataArray The newly created DataArray. """ + warnings.warn( + "open_rasterio is Deprecated in favor of rioxarray. " + "For information about transitioning, see: " + "https://corteva.github.io/rioxarray/stable/getting_started/getting_started.html", + DeprecationWarning, + stacklevel=2, + ) import rasterio from rasterio.vrt import WarpedVRT diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index c031bffb2cd..c080f19ef73 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -52,12 +52,15 @@ from .cftimeindex import CFTimeIndex, _parse_iso8601_with_reso from .times import format_cftime_datetime +try: + import cftime +except ImportError: + cftime = None + def get_date_type(calendar): """Return the cftime date type for a given calendar name.""" - try: - import cftime - except ImportError: + if cftime is None: raise ImportError("cftime is required for dates with non-standard calendars") else: calendars = { @@ -99,7 +102,8 @@ def __add__(self, other): return self.__apply__(other) def __sub__(self, other): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract a cftime.datetime from a time offset.") @@ -221,7 +225,8 @@ def _adjust_n_years(other, n, month, reference_day): def _shift_month(date, months, day_option="start"): """Shift the date to a month start or end a given number of months away.""" - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") delta_year = (date.month + months) // 12 month = (date.month + months) % 12 @@ -378,7 +383,8 @@ def onOffset(self, date): return mod_month == 0 and date.day == self._get_offset_day(date) def __sub__(self, other): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract cftime.datetime from offset.") @@ -463,7 +469,8 @@ def __apply__(self, other): return _shift_month(other, months, self._day_option) def __sub__(self, other): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract cftime.datetime from offset.") @@ -688,7 +695,8 @@ def to_offset(freq): def to_cftime_datetime(date_str_or_date, calendar=None): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if isinstance(date_str_or_date, str): if calendar is None: @@ -724,7 +732,8 @@ def _maybe_normalize_date(date, normalize): def _generate_linear_range(start, end, periods): """Generate an equally-spaced sequence of cftime.datetime objects between and including two dates (whose length equals the number of periods).""" - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") total_seconds = (end - start).total_seconds() values = np.linspace(0.0, total_seconds, periods, endpoint=True) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 783fe8d04d9..c0750069c23 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -54,6 +54,12 @@ from ..core.options import OPTIONS from .times import _STANDARD_CALENDARS, cftime_to_nptime, infer_calendar_name +try: + import cftime +except ImportError: + cftime = None + + # constants for cftimeindex.repr CFTIME_REPR_LENGTH = 19 ITEMS_IN_REPR_MAX_ELSE_ELLIPSIS = 100 @@ -114,7 +120,8 @@ def parse_iso8601_like(datetime_string): def _parse_iso8601_with_reso(date_type, timestr): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") default = date_type(1, 1, 1) result = parse_iso8601_like(timestr) @@ -189,7 +196,8 @@ def _field_accessor(name, docstring=None, min_cftime_version="0.0"): """Adapted from pandas.tseries.index._field_accessor""" def f(self, min_cftime_version=min_cftime_version): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") version = cftime.__version__ @@ -215,7 +223,8 @@ def get_date_type(self): def assert_all_valid_date_type(data): - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if len(data) > 0: sample = data[0] diff --git a/xarray/coding/times.py b/xarray/coding/times.py index f62a3961207..2b2d25f1666 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -22,6 +22,11 @@ unpack_for_encoding, ) +try: + import cftime +except ImportError: + cftime = None + # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} @@ -164,8 +169,8 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): def _decode_datetime_with_cftime(num_dates, units, calendar): - import cftime - + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") return np.asarray( cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True) ) @@ -414,7 +419,8 @@ def _encode_datetime_with_cftime(dates, units, calendar): This method is more flexible than xarray's parsing using datetime64[ns] arrays but also slower because it loops over each element. """ - import cftime + if cftime is None: + raise ModuleNotFoundError("No module named 'cftime'") if np.issubdtype(dates.dtype, np.datetime64): # numpy's broken datetime conversion only works for us precision diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 7e1565e50de..56956a57e02 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,6 +1,7 @@ import itertools import warnings from collections import Counter +from typing import Iterable, Sequence, Union import pandas as pd @@ -50,11 +51,26 @@ def _ensure_same_types(series, dim): if series.dtype == object: types = set(series.map(type)) if len(types) > 1: + try: + import cftime + + cftimes = any(issubclass(t, cftime.datetime) for t in types) + except ImportError: + cftimes = False + types = ", ".join(t.__name__ for t in types) - raise TypeError( + + error_msg = ( f"Cannot combine along dimension '{dim}' with mixed types." f" Found: {types}." ) + if cftimes: + error_msg = ( + f"{error_msg} If importing data directly from a file then " + f"setting `use_cftime=True` may fix this issue." + ) + + raise TypeError(error_msg) def _infer_concat_order_from_coords(datasets): @@ -354,16 +370,23 @@ def _nested_combine( return combined +# Define type for arbitrarily-nested list of lists recursively +# Currently mypy cannot handle this but other linters can (https://stackoverflow.com/a/53845083/3154101) +DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] # type: ignore + + def combine_nested( - datasets, - concat_dim, - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - combine_attrs="drop", -): + datasets: DATASET_HYPERCUBE, + concat_dim: Union[ + str, DataArray, None, Sequence[Union[str, "DataArray", pd.Index, None]] + ], + compat: str = "no_conflicts", + data_vars: str = "all", + coords: str = "different", + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "drop", +) -> Dataset: """ Explicitly combine an N-dimensional grid of datasets into one by using a succession of concat and merge operations along each dimension of the grid. @@ -636,16 +659,17 @@ def _combine_single_variable_hypercube( # TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( - data_objects=[], - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - combine_attrs="no_conflicts", - datasets=None, -): + data_objects: Sequence[Union[Dataset, DataArray]] = [], + compat: str = "no_conflicts", + data_vars: str = "all", + coords: str = "different", + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "no_conflicts", + datasets: Sequence[Dataset] = None, +) -> Union[Dataset, DataArray]: """ + Attempt to auto-magically combine the given datasets (or data arrays) into one by using dimension coordinates. @@ -740,7 +764,7 @@ def combine_by_coords( Returns ------- - combined : xarray.Dataset + combined : xarray.Dataset or xarray.DataArray See also -------- diff --git a/xarray/core/common.py b/xarray/core/common.py index 0f2b58d594a..2c5d7900ef8 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -31,6 +31,11 @@ from .rolling_exp import RollingExp from .utils import Frozen, either_dict_or_kwargs, is_scalar +try: + import cftime +except ImportError: + cftime = None + # Used as a sentinel value to indicate a all dimensions ALL_DIMS = ... @@ -1820,9 +1825,7 @@ def is_np_timedelta_like(dtype: DTypeLike) -> bool: def _contains_cftime_datetimes(array) -> bool: """Check if an array contains cftime.datetime objects""" - try: - from cftime import datetime as cftime_datetime - except ImportError: + if cftime is None: return False else: if array.dtype == np.dtype("O") and array.size > 0: @@ -1831,7 +1834,7 @@ def _contains_cftime_datetimes(array) -> bool: sample = sample.compute() if isinstance(sample, np.ndarray): sample = sample.item() - return isinstance(sample, cftime_datetime) + return isinstance(sample, cftime.datetime) else: return False diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bbaae1f5b36..7f60da7e1b2 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1529,7 +1529,7 @@ def dot(*arrays, dims=None, **kwargs): join=join, dask="allowed", ) - return result.transpose(*[d for d in all_dims if d in result.dims]) + return result.transpose(*all_dims, missing_dims="ignore") def where(cond, x, y): diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 30fc478d26e..ed8b393628d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3169,11 +3169,14 @@ def shift( fill_value: Any = dtypes.NA, **shifts_kwargs: int, ) -> "DataArray": - """Shift this array by an offset along one or more dimensions. + """Shift this DataArray by an offset along one or more dimensions. - Only the data is moved; coordinates stay in place. Values shifted from - beyond array bounds are replaced by NaN. This is consistent with the - behavior of ``shift`` in pandas. + Only the data is moved; coordinates stay in place. This is consistent + with the behavior of ``shift`` in pandas. + + Values shifted from beyond array bounds will appear at one end of + each dimension, which are filled according to `fill_value`. For periodic + offsets instead see `roll`. Parameters ---------- @@ -3212,12 +3215,15 @@ def shift( def roll( self, - shifts: Mapping[Any, int] = None, - roll_coords: bool = None, + shifts: Mapping[Hashable, int] = None, + roll_coords: bool = False, **shifts_kwargs: int, ) -> "DataArray": """Roll this array by an offset along one or more dimensions. + Unlike shift, roll treats the given dimensions as periodic, so will not + create any missing values to be filled. + Unlike shift, roll may rotate all variables, including coordinates if specified. The direction of rotation is consistent with :py:func:`numpy.roll`. @@ -3228,12 +3234,9 @@ def roll( Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. - roll_coords : bool - Indicates whether to roll the coordinates by the offset - The current default of roll_coords (None, equivalent to True) is - deprecated and will change to False in a future version. - Explicitly pass roll_coords to silence the warning. - **shifts_kwargs + roll_coords : bool, default: False + Indicates whether to roll the coordinates by the offset too. + **shifts_kwargs : {dim: offset, ...}, optional The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5a00539346c..550c3587aa6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1097,7 +1097,7 @@ def _replace( coord_names: Set[Hashable] = None, dims: Dict[Any, int] = None, attrs: Union[Dict[Hashable, Any], None, Default] = _default, - indexes: Union[Dict[Any, Index], None, Default] = _default, + indexes: Union[Dict[Hashable, Index], None, Default] = _default, encoding: Union[dict, None, Default] = _default, inplace: bool = False, ) -> "Dataset": @@ -1557,6 +1557,11 @@ def __setitem__(self, key: Union[Hashable, List[Hashable], Mapping], value) -> N self.update(dict(zip(key, value))) else: + if isinstance(value, Dataset): + raise TypeError( + "Cannot assign a Dataset to a single key - only a DataArray or Variable object can be stored under" + "a single key." + ) self.update({key: value}) def _setitem_check(self, key, value): @@ -5866,12 +5871,22 @@ def diff(self, dim, n=1, label="upper"): else: return difference - def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): + def shift( + self, + shifts: Mapping[Hashable, int] = None, + fill_value: Any = dtypes.NA, + **shifts_kwargs: int, + ) -> "Dataset": + """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is consistent with the behavior of ``shift`` in pandas. + Values shifted from beyond array bounds will appear at one end of + each dimension, which are filled according to `fill_value`. For periodic + offsets instead see `roll`. + Parameters ---------- shifts : mapping of hashable to int @@ -5926,32 +5941,37 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): return self._replace(variables) - def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): + def roll( + self, + shifts: Mapping[Hashable, int] = None, + roll_coords: bool = False, + **shifts_kwargs: int, + ) -> "Dataset": """Roll this dataset by an offset along one or more dimensions. - Unlike shift, roll may rotate all variables, including coordinates + Unlike shift, roll treats the given dimensions as periodic, so will not + create any missing values to be filled. + + Also unlike shift, roll may rotate all variables, including coordinates if specified. The direction of rotation is consistent with :py:func:`numpy.roll`. Parameters ---------- - shifts : dict, optional + shifts : mapping of hashable to int, optional A dict with keys matching dimensions and values given by integers to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. - roll_coords : bool - Indicates whether to roll the coordinates by the offset - The current default of roll_coords (None, equivalent to True) is - deprecated and will change to False in a future version. - Explicitly pass roll_coords to silence the warning. + roll_coords : bool, default: False + Indicates whether to roll the coordinates by the offset too. **shifts_kwargs : {dim: offset, ...}, optional The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. + Returns ------- rolled : Dataset - Dataset with the same coordinates and attributes but rolled - variables. + Dataset with the same attributes but rolled data and coordinates. See Also -------- @@ -5959,47 +5979,49 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): Examples -------- - >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}) + >>> ds = xr.Dataset({"foo": ("x", list("abcde"))}, coords={"x": np.arange(5)}) >>> ds.roll(x=2) Dimensions: (x: 5) - Dimensions without coordinates: x + Coordinates: + * x (x) int64 0 1 2 3 4 + Data variables: + foo (x) >> ds.roll(x=2, roll_coords=True) + + Dimensions: (x: 5) + Coordinates: + * x (x) int64 3 4 0 1 2 Data variables: foo (x) ={requires_dask}") - else: - wrapped = getattr(eager_module, name) - return wrapped(*args, **kwargs) - else: - - def f(*args, **kwargs): - return getattr(eager_module, name)(*args, **kwargs) + def f(*args, **kwargs): + if any(is_duck_dask_array(a) for a in args): + wrapped = getattr(dask_module, name) + else: + wrapped = getattr(eager_module, name) + return wrapped(*args, **kwargs) return f @@ -72,16 +65,40 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): raise NotImplementedError(msg % func_name) -around = _dask_or_eager_func("around") -isclose = _dask_or_eager_func("isclose") - +# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18 +pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module=dask_array) -isnat = np.isnat -isnan = _dask_or_eager_func("isnan") -zeros_like = _dask_or_eager_func("zeros_like") - - -pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd) +# np.around has failing doctests, overwrite it so they pass: +# https://github.com/numpy/numpy/issues/19759 +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2.])", + "array([0., 2.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2.])", + "array([0., 2.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0.4, 1.6])", + "array([0.4, 1.6])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + "array([0., 2., 2., 4., 4.])", + "array([0., 2., 2., 4., 4.])", +) +around.__doc__ = str.replace( + around.__doc__ or "", + ( + ' .. [2] "How Futile are Mindless Assessments of\n' + ' Roundoff in Floating-Point Computation?", William Kahan,\n' + " https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n" + ), + "", +) def isnull(data): @@ -114,21 +131,10 @@ def notnull(data): return ~isnull(data) -transpose = _dask_or_eager_func("transpose") -_where = _dask_or_eager_func("where", array_args=slice(3)) -isin = _dask_or_eager_func("isin", array_args=slice(2)) -take = _dask_or_eager_func("take") -broadcast_to = _dask_or_eager_func("broadcast_to") -pad = _dask_or_eager_func("pad", dask_module=dask_array_compat) - -_concatenate = _dask_or_eager_func("concatenate", list_of_args=True) -_stack = _dask_or_eager_func("stack", list_of_args=True) - -array_all = _dask_or_eager_func("all") -array_any = _dask_or_eager_func("any") - -tensordot = _dask_or_eager_func("tensordot", array_args=slice(2)) -einsum = _dask_or_eager_func("einsum", array_args=slice(1, None)) +# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed +masked_invalid = _dask_or_eager_func( + "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None) +) def gradient(x, coord, axis, edge_order): @@ -166,11 +172,6 @@ def cumulative_trapezoid(y, x, axis): return cumsum(integrand, axis=axis, skipna=False) -masked_invalid = _dask_or_eager_func( - "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None) -) - - def astype(data, dtype, **kwargs): if ( isinstance(data, sparse_array_type) @@ -317,9 +318,7 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method( - name, dask_module=dask_array, coerce_strings=False, invariant_0d=False -): +def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False): from . import nanops def f(values, axis=None, skipna=None, **kwargs): @@ -344,7 +343,8 @@ def f(values, axis=None, skipna=None, **kwargs): else: if name in ["sum", "prod"]: kwargs.pop("min_count", None) - func = _dask_or_eager_func(name, dask_module=dask_module) + + func = getattr(np, name) try: with warnings.catch_warnings(): @@ -378,9 +378,7 @@ def f(values, axis=None, skipna=None, **kwargs): std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method( - "median", dask_module=dask_array_compat, invariant_0d=True -) +median = _create_nan_agg_method("median", invariant_0d=True) median.numeric_only = True prod = _create_nan_agg_method("prod", invariant_0d=True) prod.numeric_only = True @@ -389,7 +387,6 @@ def f(values, axis=None, skipna=None, **kwargs): cumprod_1d.numeric_only = True cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) cumsum_1d.numeric_only = True -unravel_index = _dask_or_eager_func("unravel_index") _mean = _create_nan_agg_method("mean", invariant_0d=True) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 95b6ccaad30..1ded35264f4 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -5,6 +5,7 @@ Dict, Hashable, Iterable, + Iterator, Mapping, Optional, Sequence, @@ -449,7 +450,7 @@ class Indexes(collections.abc.Mapping): __slots__ = ("_indexes",) - def __init__(self, indexes): + def __init__(self, indexes: Mapping[Any, Union[pd.Index, Index]]) -> None: """Not for public consumption. Parameters @@ -459,7 +460,7 @@ def __init__(self, indexes): """ self._indexes = indexes - def __iter__(self): + def __iter__(self) -> Iterator[pd.Index]: return iter(self._indexes) def __len__(self): @@ -468,7 +469,7 @@ def __len__(self): def __contains__(self, key): return key in self._indexes - def __getitem__(self, key): + def __getitem__(self, key) -> pd.Index: return self._indexes[key] def __repr__(self): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 48106bff289..c1a4d629f97 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -3,14 +3,7 @@ import numpy as np from . import dtypes, nputils, utils -from .duck_array_ops import ( - _dask_or_eager_func, - count, - fillna, - isnull, - where, - where_method, -) +from .duck_array_ops import count, fillna, isnull, where, where_method from .pycompat import dask_array_type try: @@ -53,7 +46,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): """ valid_count = count(value, axis=axis) value = fillna(value, fill_value) - data = _dask_or_eager_func(func)(value, axis=axis, **kwargs) + data = getattr(np, func)(value, axis=axis, **kwargs) # TODO This will evaluate dask arrays and might be costly. if (valid_count == 0).any(): @@ -111,7 +104,7 @@ def nanargmax(a, axis=None): def nansum(a, axis=None, dtype=None, out=None, min_count=None): a, mask = _replace_nan(a, 0) - result = _dask_or_eager_func("sum")(a, axis=axis, dtype=dtype) + result = np.sum(a, axis=axis, dtype=dtype) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: @@ -120,7 +113,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None): def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): """In house nanmean. ddof argument will be used in _nanvar method""" - from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method + from .duck_array_ops import count, fillna, where_method valid_count = count(value, axis=axis) value = fillna(value, 0) @@ -129,7 +122,7 @@ def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): if dtype is None and value.dtype.kind == "O": dtype = value.dtype if value.dtype.kind in ["cf"] else float - data = _dask_or_eager_func("sum")(value, axis=axis, dtype=dtype, **kwargs) + data = np.sum(value, axis=axis, dtype=dtype, **kwargs) data = data / (valid_count - ddof) return where_method(data, valid_count != 0) @@ -155,7 +148,7 @@ def nanmedian(a, axis=None, out=None): # possibly blow memory if axis is not None and len(np.atleast_1d(axis)) == a.ndim: axis = None - return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis) + return nputils.nanmedian(a, axis=axis) def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs): @@ -170,20 +163,16 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0): if a.dtype.kind == "O": return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof) - return _dask_or_eager_func("nanvar", eager_module=nputils)( - a, axis=axis, dtype=dtype, ddof=ddof - ) + return nputils.nanvar(a, axis=axis, dtype=dtype, ddof=ddof) def nanstd(a, axis=None, dtype=None, out=None, ddof=0): - return _dask_or_eager_func("nanstd", eager_module=nputils)( - a, axis=axis, dtype=dtype, ddof=ddof - ) + return nputils.nanstd(a, axis=axis, dtype=dtype, ddof=ddof) def nanprod(a, axis=None, dtype=None, out=None, min_count=None): a, mask = _replace_nan(a, 1) - result = _dask_or_eager_func("nanprod")(a, axis=axis, dtype=dtype, out=out) + result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: @@ -191,12 +180,8 @@ def nanprod(a, axis=None, dtype=None, out=None, min_count=None): def nancumsum(a, axis=None, dtype=None, out=None): - return _dask_or_eager_func("nancumsum", eager_module=nputils)( - a, axis=axis, dtype=dtype - ) + return nputils.nancumsum(a, axis=axis, dtype=dtype) def nancumprod(a, axis=None, dtype=None, out=None): - return _dask_or_eager_func("nancumprod", eager_module=nputils)( - a, axis=axis, dtype=dtype - ) + return nputils.nancumprod(a, axis=axis, dtype=dtype) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index c1aedd570bc..7288a368e47 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -12,6 +12,7 @@ _process_cmap_cbar_kwargs, get_axis, label_from_attrs, + plt, ) # copied from seaborn @@ -134,8 +135,7 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) # copied from seaborn def _parse_size(data, norm): - - import matplotlib as mpl + mpl = plt.matplotlib if data is None: return None @@ -544,8 +544,6 @@ def quiver(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") @@ -560,7 +558,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) @@ -576,8 +574,6 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for streamplot plots.") @@ -613,7 +609,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 28dd82e76f5..b384dea0571 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,8 +9,8 @@ _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, + plt, ) # Overrides axes.labelsize, xtick.major.size, ytick.major.size @@ -116,8 +116,6 @@ def __init__( """ - plt = import_matplotlib_pyplot() - # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique @@ -519,10 +517,8 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar self: FacetGrid object """ - import matplotlib as mpl - if size is None: - size = mpl.rcParams["axes.labelsize"] + size = plt.rcParams["axes.labelsize"] nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) @@ -619,8 +615,6 @@ def map(self, func, *args, **kwargs): self : FacetGrid object """ - plt = import_matplotlib_pyplot() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): if namedict is not None: data = self.data.loc[namedict] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index e20b6568e79..1e1e59e2f71 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,9 +29,9 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, legend_elements, + plt, ) # copied from seaborn @@ -83,8 +83,6 @@ def _parse_size(data, norm, width): If the data is categorical, normalize it to numbers. """ - plt = import_matplotlib_pyplot() - if data is None: return None @@ -556,7 +554,7 @@ def hist( primitive = ax.hist(no_nan, **kwargs) - ax.set_title("Histogram") + ax.set_title(darray._title_for_slice()) ax.set_xlabel(label_from_attrs(darray)) _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -682,8 +680,6 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ - plt = import_matplotlib_pyplot() - # Handle facetgrids first if row or col: allargs = locals().copy() @@ -1111,8 +1107,6 @@ def newplotfunc( allargs["plotfunc"] = globals()[plotfunc.__name__] return _easy_facetgrid(darray, kind="dataarray", **allargs) - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index af5859c1f14..6fbbe9d4bca 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -19,6 +19,12 @@ except ImportError: nc_time_axis_available = False + +try: + import cftime +except ImportError: + cftime = None + ROBUST_PERCENTILE = 2.0 @@ -41,6 +47,12 @@ def import_matplotlib_pyplot(): return plt +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + + def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -58,7 +70,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ - import matplotlib as mpl + mpl = plt.matplotlib if len(levels) == 1: levels = [levels[0], levels[0]] @@ -109,8 +121,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): - import matplotlib.pyplot as plt - from matplotlib.colors import ListedColormap + ListedColormap = plt.matplotlib.colors.ListedColormap colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): @@ -171,7 +182,7 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - import matplotlib as mpl + mpl = plt.matplotlib if isinstance(levels, Iterable): levels = sorted(levels) @@ -279,13 +290,13 @@ def _determine_cmap_params( levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks - ticker = mpl.ticker.MaxNLocator(levels - 1) + ticker = plt.MaxNLocator(levels - 1) levels = ticker.tick_values(vmin, vmax) vmin, vmax = levels[0], levels[-1] # GH3734 if vmin == vmax: - vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) + vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax) if extend is None: extend = _determine_extend(calc_data, vmin, vmax) @@ -415,10 +426,7 @@ def _assert_valid_xy(darray, xy, name): def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): - try: - import matplotlib as mpl - import matplotlib.pyplot as plt - except ImportError: + if plt is None: raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: @@ -431,7 +439,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = mpl.rcParams["figure.figsize"] + width, height = plt.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) @@ -448,9 +456,6 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): def _maybe_gca(**kwargs): - - import matplotlib.pyplot as plt - # can call gcf unconditionally: either it exists or would be created by plt.axes f = plt.gcf() @@ -628,13 +633,11 @@ def _ensure_plottable(*args): np.str_, ] other_types = [datetime] - try: - import cftime - - cftime_datetime = [cftime.datetime] - except ImportError: - cftime_datetime = [] - other_types = other_types + cftime_datetime + if cftime is not None: + cftime_datetime_types = [cftime.datetime] + other_types = other_types + cftime_datetime_types + else: + cftime_datetime_types = [] for x in args: if not ( _valid_numpy_subdtype(np.array(x), numpy_types) @@ -647,7 +650,7 @@ def _ensure_plottable(*args): f"pandas.Interval. Received data of type {np.array(x).dtype} instead." ) if ( - _valid_other_type(np.array(x), cftime_datetime) + _valid_other_type(np.array(x), cftime_datetime_types) and not nc_time_axis_available ): raise ImportError( @@ -908,9 +911,7 @@ def _process_cmap_cbar_kwargs( def _get_nice_quiver_magnitude(u, v): - import matplotlib as mpl - - ticker = mpl.ticker.MaxNLocator(3) + ticker = plt.MaxNLocator(3) mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude @@ -985,7 +986,7 @@ def legend_elements( """ import warnings - import matplotlib as mpl + mpl = plt.matplotlib mlines = mpl.lines @@ -1122,7 +1123,6 @@ def _legend_add_subtitle(handles, labels, text, func): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None) diff --git a/xarray/tests/data/example.ict b/xarray/tests/data/example.ict index 41bbfeb996c..a33e71a9a81 100644 --- a/xarray/tests/data/example.ict +++ b/xarray/tests/data/example.ict @@ -1,13 +1,13 @@ -27, 1001 +29, 1001 Henderson, Barron U.S. EPA Example file with artificial data JUST_A_TEST 1, 1 -2018, 04, 27, 2018, 04, 27 +2018, 04, 27 2018, 04, 27 0 Start_UTC -7 +5 1, 1, 1, 1, 1 -9999, -9999, -9999, -9999, -9999 lat, degrees_north @@ -16,7 +16,9 @@ elev, meters TEST_ppbv, ppbv TESTM_ppbv, ppbv 0 -8 +9 +INDEPENDENT_VARIABLE_DEFINITION: Start_UTC +INDEPENDENT_VARIABLE_UNITS: Start_UTC ULOD_FLAG: -7777 ULOD_VALUE: N/A LLOD_FLAG: -8888 diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index f42f7f530d4..7657e42ff66 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2186,6 +2186,16 @@ def test_to_zarr_append_compute_false_roundtrip(self): with self.open(store) as actual: assert_identical(xr.concat([ds, ds_to_append], dim="time"), actual) + @pytest.mark.parametrize("chunk", [False, True]) + def test_save_emptydim(self, chunk): + if chunk and not has_dask: + pytest.skip("requires dask") + ds = Dataset({"x": (("a", "b"), np.empty((5, 0))), "y": ("a", [1, 2, 5, 8, 9])}) + if chunk: + ds = ds.chunk({}) # chunk dataset to save dask array + with self.roundtrip(ds) as ds_reload: + assert_identical(ds, ds_reload) + @pytest.mark.parametrize("consolidated", [False, True]) @pytest.mark.parametrize("compute", [False, True]) @pytest.mark.parametrize("use_dask", [False, True]) @@ -3029,6 +3039,14 @@ def test_open_mfdataset_manyfiles( assert_identical(original, actual) +@requires_netCDF4 +@requires_dask +def test_open_mfdataset_can_open_path_objects(): + dataset = os.path.join(os.path.dirname(__file__), "data", "example_1.nc") + with open_mfdataset(Path(dataset)) as actual: + assert isinstance(actual, Dataset) + + @requires_netCDF4 @requires_dask def test_open_mfdataset_list_attr(): @@ -3497,7 +3515,6 @@ def test_open_mfdataset_auto_combine(self): with open_mfdataset([tmp2, tmp1], combine="by_coords") as actual: assert_identical(original, actual) - # TODO check for an error instead of a warning once deprecated def test_open_mfdataset_raise_on_bad_combine_args(self): # Regression test for unhelpful error shown in #5230 original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) @@ -3505,9 +3522,7 @@ def test_open_mfdataset_raise_on_bad_combine_args(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with pytest.warns( - DeprecationWarning, match="`concat_dim` has no effect" - ): + with pytest.raises(ValueError, match="`concat_dim` has no effect"): open_mfdataset([tmp1, tmp2], concat_dim="x") @pytest.mark.xfail(reason="mfdataset loses encoding currently.") @@ -3957,7 +3972,7 @@ def myatts(**attrs): "coords": {}, "attrs": { "fmt": "1001", - "n_header_lines": 27, + "n_header_lines": 29, "PI_NAME": "Henderson, Barron", "ORGANIZATION_NAME": "U.S. EPA", "SOURCE_DESCRIPTION": "Example file with artificial data", @@ -3966,7 +3981,9 @@ def myatts(**attrs): "SDATE": "2018, 04, 27", "WDATE": "2018, 04, 27", "TIME_INTERVAL": "0", + "INDEPENDENT_VARIABLE_DEFINITION": "Start_UTC", "INDEPENDENT_VARIABLE": "Start_UTC", + "INDEPENDENT_VARIABLE_UNITS": "Start_UTC", "ULOD_FLAG": "-7777", "ULOD_VALUE": "N/A", "LLOD_FLAG": "-8888", @@ -4215,7 +4232,7 @@ class TestRasterio: def test_serialization(self): with create_tmp_geotiff(additional_attrs={}) as (tmp_file, expected): # Write it to a netcdf and read again (roundtrip) - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: with create_tmp_file(suffix=".nc") as tmp_nc_file: rioda.to_netcdf(tmp_nc_file) with xr.open_dataarray(tmp_nc_file) as ncds: @@ -4223,7 +4240,7 @@ def test_serialization(self): def test_utm(self): with create_tmp_geotiff() as (tmp_file, expected): - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert rioda.attrs["scales"] == (1.0, 1.0, 1.0) assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0) @@ -4239,7 +4256,9 @@ def test_utm(self): ) # Check no parse coords - with xr.open_rasterio(tmp_file, parse_coordinates=False) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + tmp_file, parse_coordinates=False + ) as rioda: assert "x" not in rioda.coords assert "y" not in rioda.coords @@ -4251,7 +4270,7 @@ def test_non_rectilinear(self): transform=from_origin(0, 3, 1, 1).rotation(45), crs=None ) as (tmp_file, _): # Default is to not parse coords - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert "x" not in rioda.coords assert "y" not in rioda.coords assert "crs" not in rioda.attrs @@ -4279,7 +4298,7 @@ def test_platecarree(self): crs="+proj=latlong", open_kwargs={"nodata": -9765}, ) as (tmp_file, expected): - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert rioda.attrs["scales"] == (1.0,) assert rioda.attrs["offsets"] == (0.0,) @@ -4327,7 +4346,7 @@ def test_notransform(self): "x": [0.5, 1.5, 2.5, 3.5], }, ) - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert rioda.attrs["scales"] == (1.0, 1.0, 1.0) assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0) @@ -4342,7 +4361,9 @@ def test_indexing(self): with create_tmp_geotiff( 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" ) as (tmp_file, expected): - with xr.open_rasterio(tmp_file, cache=False) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + tmp_file, cache=False + ) as actual: # tests # assert_allclose checks all data + coordinates @@ -4458,7 +4479,7 @@ def test_caching(self): 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" ) as (tmp_file, expected): # Cache is the default - with xr.open_rasterio(tmp_file) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as actual: # This should cache everything assert_allclose(actual, expected) @@ -4474,7 +4495,9 @@ def test_chunks(self): 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" ) as (tmp_file, expected): # Chunk at open time - with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + tmp_file, chunks=(1, 2, 2) + ) as actual: import dask.array as da @@ -4496,7 +4519,7 @@ def test_chunks(self): def test_pickle_rasterio(self): # regression test for https://github.com/pydata/xarray/issues/2121 with create_tmp_geotiff() as (tmp_file, expected): - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: temp = pickle.dumps(rioda) with pickle.loads(temp) as actual: assert_equal(actual, rioda) @@ -4548,7 +4571,7 @@ def test_ENVI_tags(self): } expected = DataArray(data, dims=("band", "y", "x"), coords=coords) - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert_allclose(rioda, expected) assert isinstance(rioda.attrs["crs"], str) assert isinstance(rioda.attrs["res"], tuple) @@ -4563,7 +4586,7 @@ def test_ENVI_tags(self): def test_geotiff_tags(self): # Create a geotiff file with some tags with create_tmp_geotiff() as (tmp_file, _): - with xr.open_rasterio(tmp_file) as rioda: + with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: assert isinstance(rioda.attrs["AREA_OR_POINT"], str) @requires_dask @@ -4578,7 +4601,9 @@ def test_no_mftime(self): 8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong" ) as (tmp_file, expected): with mock.patch("os.path.getmtime", side_effect=OSError): - with xr.open_rasterio(tmp_file, chunks=(1, 2, 2)) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + tmp_file, chunks=(1, 2, 2) + ) as actual: import dask.array as da assert isinstance(actual.data, da.Array) @@ -4589,10 +4614,12 @@ def test_http_url(self): # more examples urls here # http://download.osgeo.org/geotiff/samples/ url = "http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif" - with xr.open_rasterio(url) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio(url) as actual: assert actual.shape == (1, 512, 512) # make sure chunking works - with xr.open_rasterio(url, chunks=(1, 256, 256)) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + url, chunks=(1, 256, 256) + ) as actual: import dask.array as da assert isinstance(actual.data, da.Array) @@ -4604,7 +4631,9 @@ def test_rasterio_environment(self): # Should fail with error since suffix not allowed with pytest.raises(Exception): with rasterio.Env(GDAL_SKIP="GTiff"): - with xr.open_rasterio(tmp_file) as actual: + with pytest.warns(DeprecationWarning), xr.open_rasterio( + tmp_file + ) as actual: assert_allclose(actual, expected) @pytest.mark.xfail(reason="rasterio 1.1.1 is broken. GH3573") @@ -4621,7 +4650,7 @@ def test_rasterio_vrt(self): # Value of single pixel in center of image lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2) expected_val = next(vrt.sample([(lon, lat)])) - with xr.open_rasterio(vrt) as da: + with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: actual_shape = (da.sizes["x"], da.sizes["y"]) actual_crs = da.crs actual_res = da.res @@ -4675,7 +4704,7 @@ def test_rasterio_vrt_with_src_crs(self): with rasterio.open(tmp_file) as src: assert src.crs is None with rasterio.vrt.WarpedVRT(src, src_crs=src_crs) as vrt: - with xr.open_rasterio(vrt) as da: + with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: assert da.crs == src_crs @network @@ -4695,7 +4724,7 @@ def test_rasterio_vrt_network(self): # Value of single pixel in center of image lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2) expected_val = next(vrt.sample([(lon, lat)])) - with xr.open_rasterio(vrt) as da: + with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da: actual_shape = da.sizes["x"], da.sizes["y"] actual_res = da.res actual_val = da.sel(dict(x=lon, y=lat), method="nearest").data diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index cd62ebd4239..352ec6c10f1 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -26,10 +26,11 @@ def test_custom_engine() -> None: class CustomBackend(xr.backends.BackendEntrypoint): def open_dataset( + self, filename_or_obj, drop_variables=None, **kwargs, - ): + ) -> xr.Dataset: return expected.copy(deep=True) actual = xr.open_dataset("fake_filename", engine=CustomBackend) diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 725b5efee75..619fb0acdc4 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1,4 +1,5 @@ from datetime import timedelta +from distutils.version import LooseVersion from textwrap import dedent import numpy as np @@ -345,65 +346,86 @@ def test_get_loc(date_type, index): @requires_cftime -@pytest.mark.parametrize("kind", ["loc", "getitem"]) -def test_get_slice_bound(date_type, index, kind): - result = index.get_slice_bound("0001", "left", kind) +def test_get_slice_bound(date_type, index): + # The kind argument is required in earlier versions of pandas even though it + # is not used by CFTimeIndex. This logic can be removed once our minimum + # version of pandas is at least 1.3. + if LooseVersion(pd.__version__) < LooseVersion("1.3"): + kind_args = ("getitem",) + else: + kind_args = () + + result = index.get_slice_bound("0001", "left", *kind_args) expected = 0 assert result == expected - result = index.get_slice_bound("0001", "right", kind) + result = index.get_slice_bound("0001", "right", *kind_args) expected = 2 assert result == expected - result = index.get_slice_bound(date_type(1, 3, 1), "left", kind) + result = index.get_slice_bound(date_type(1, 3, 1), "left", *kind_args) expected = 2 assert result == expected - result = index.get_slice_bound(date_type(1, 3, 1), "right", kind) + result = index.get_slice_bound(date_type(1, 3, 1), "right", *kind_args) expected = 2 assert result == expected @requires_cftime -@pytest.mark.parametrize("kind", ["loc", "getitem"]) -def test_get_slice_bound_decreasing_index(date_type, monotonic_decreasing_index, kind): - result = monotonic_decreasing_index.get_slice_bound("0001", "left", kind) +def test_get_slice_bound_decreasing_index(date_type, monotonic_decreasing_index): + # The kind argument is required in earlier versions of pandas even though it + # is not used by CFTimeIndex. This logic can be removed once our minimum + # version of pandas is at least 1.3. + if LooseVersion(pd.__version__) < LooseVersion("1.3"): + kind_args = ("getitem",) + else: + kind_args = () + + result = monotonic_decreasing_index.get_slice_bound("0001", "left", *kind_args) expected = 2 assert result == expected - result = monotonic_decreasing_index.get_slice_bound("0001", "right", kind) + result = monotonic_decreasing_index.get_slice_bound("0001", "right", *kind_args) expected = 4 assert result == expected result = monotonic_decreasing_index.get_slice_bound( - date_type(1, 3, 1), "left", kind + date_type(1, 3, 1), "left", *kind_args ) expected = 2 assert result == expected result = monotonic_decreasing_index.get_slice_bound( - date_type(1, 3, 1), "right", kind + date_type(1, 3, 1), "right", *kind_args ) expected = 2 assert result == expected @requires_cftime -@pytest.mark.parametrize("kind", ["loc", "getitem"]) -def test_get_slice_bound_length_one_index(date_type, length_one_index, kind): - result = length_one_index.get_slice_bound("0001", "left", kind) +def test_get_slice_bound_length_one_index(date_type, length_one_index): + # The kind argument is required in earlier versions of pandas even though it + # is not used by CFTimeIndex. This logic can be removed once our minimum + # version of pandas is at least 1.3. + if LooseVersion(pd.__version__) <= LooseVersion("1.3"): + kind_args = ("getitem",) + else: + kind_args = () + + result = length_one_index.get_slice_bound("0001", "left", *kind_args) expected = 0 assert result == expected - result = length_one_index.get_slice_bound("0001", "right", kind) + result = length_one_index.get_slice_bound("0001", "right", *kind_args) expected = 1 assert result == expected - result = length_one_index.get_slice_bound(date_type(1, 3, 1), "left", kind) + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "left", *kind_args) expected = 1 assert result == expected - result = length_one_index.get_slice_bound(date_type(1, 3, 1), "right", kind) + result = length_one_index.get_slice_bound(date_type(1, 3, 1), "right", *kind_args) expected = 1 assert result == expected diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 3ca964b94e1..cbe09aab815 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -1097,7 +1097,12 @@ def test_combine_by_coords_raises_for_differing_calendars(): da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() if LooseVersion(cftime.__version__) >= LooseVersion("1.5"): - error_msg = "Cannot combine along dimension 'time' with mixed types." + error_msg = ( + "Cannot combine along dimension 'time' with mixed types." + " Found:.*" + " If importing data directly from a file then setting" + " `use_cftime=True` may fix this issue." + ) else: error_msg = r"cannot compare .* \(different calendars\)" diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c3223432b38..b1bd7576a12 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3378,19 +3378,10 @@ def test_roll_coords(self): def test_roll_no_coords(self): arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") - actual = arr.roll(x=1, roll_coords=False) + actual = arr.roll(x=1) expected = DataArray([3, 1, 2], coords=[("x", [0, 1, 2])]) assert_identical(expected, actual) - def test_roll_coords_none(self): - arr = DataArray([1, 2, 3], coords={"x": range(3)}, dims="x") - - with pytest.warns(FutureWarning): - actual = arr.roll(x=1, roll_coords=None) - - expected = DataArray([3, 1, 2], coords=[("x", [2, 0, 1])]) - assert_identical(expected, actual) - def test_copy_with_data(self): orig = DataArray( np.random.random(size=(2, 2)), diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7b17eae89c8..61b404275bf 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3395,6 +3395,9 @@ def test_setitem(self): # override an existing value data1["A"] = 3 * data2["A"] assert_equal(data1["A"], 3 * data2["A"]) + # can't assign a dataset to a single key + with pytest.raises(TypeError, match="Cannot assign a Dataset to a single key"): + data1["D"] = xr.Dataset() # test assignment with positional and label-based indexing data3 = data1[["var1", "var2"]] @@ -5098,25 +5101,13 @@ def test_roll_no_coords(self): coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} attrs = {"meta": "data"} ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) - actual = ds.roll(x=1, roll_coords=False) + actual = ds.roll(x=1) expected = Dataset({"foo": ("x", [3, 1, 2])}, coords, attrs) assert_identical(expected, actual) with pytest.raises(ValueError, match=r"dimensions"): - ds.roll(abc=321, roll_coords=False) - - def test_roll_coords_none(self): - coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} - attrs = {"meta": "data"} - ds = Dataset({"foo": ("x", [1, 2, 3])}, coords, attrs) - - with pytest.warns(FutureWarning): - actual = ds.roll(x=1, roll_coords=None) - - ex_coords = {"bar": ("x", list("cab")), "x": [2, -4, 3]} - expected = Dataset({"foo": ("x", [3, 1, 2])}, ex_coords, attrs) - assert_identical(expected, actual) + ds.roll(abc=321) def test_roll_multidim(self): # regression test for 2445 diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index ef1ce50d6ea..92f39069aa3 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -163,7 +163,7 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> def test_dask_distributed_rasterio_integration_test(loop) -> None: with create_tmp_geotiff() as (tmp_file, expected): with cluster() as (s, [a, b]): - with Client(s["address"], loop=loop): + with pytest.warns(DeprecationWarning), Client(s["address"], loop=loop): da_tiff = xr.open_rasterio(tmp_file, chunks={"band": 1}) assert isinstance(da_tiff.data, da.Array) actual = da_tiff.compute() diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b822ba42ce5..3260b92bd71 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -804,8 +804,9 @@ def test_xlabel_uses_name(self): assert "testpoints [testunits]" == plt.gca().get_xlabel() def test_title_is_histogram(self): + self.darray.coords["d"] = 10 self.darray.plot.hist() - assert "Histogram" == plt.gca().get_title() + assert "d = 10" == plt.gca().get_title() def test_can_pass_in_kwargs(self): nbins = 5 diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py index 411ad52368d..e4c4378afdd 100644 --- a/xarray/tests/test_tutorial.py +++ b/xarray/tests/test_tutorial.py @@ -32,11 +32,11 @@ def test_download_rasterio_from_github_load_without_cache( self, tmp_path, monkeypatch ): cache_dir = tmp_path / tutorial._default_cache_dir_name - - arr_nocache = tutorial.open_rasterio( - "RGB.byte", cache=False, cache_dir=cache_dir - ).load() - arr_cache = tutorial.open_rasterio( - "RGB.byte", cache=True, cache_dir=cache_dir - ).load() + with pytest.warns(DeprecationWarning): + arr_nocache = tutorial.open_rasterio( + "RGB.byte", cache=False, cache_dir=cache_dir + ).load() + arr_cache = tutorial.open_rasterio( + "RGB.byte", cache=True, cache_dir=cache_dir + ).load() assert_identical(arr_cache, arr_nocache) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 543100ef98c..7bde6ce8b9f 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -13,6 +13,7 @@ assert_duckarray_allclose, assert_equal, assert_identical, + requires_dask, requires_matplotlib, ) from .test_plot import PlotTestCase @@ -5579,6 +5580,24 @@ def test_merge(self, variant, unit, error, dtype): assert_equal(expected, actual) +@requires_dask +class TestPintWrappingDask: + def test_duck_array_ops(self): + import dask.array + + d = dask.array.array([1, 2, 3]) + q = pint.Quantity(d, units="m") + da = xr.DataArray(q, dims="x") + + actual = da.mean().compute() + actual.name = None + expected = xr.DataArray(pint.Quantity(np.array(2.0), units="m")) + + assert_units_equal(expected, actual) + # Don't use isinstance b/c we don't want to allow subclasses through + assert type(expected.data) == type(actual.data) # noqa + + @requires_matplotlib class TestPlots(PlotTestCase): def test_units_in_line_plot_labels(self): diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index b80175273e0..7f6eed55e9b 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -20,7 +20,6 @@ from .core.dataarray import DataArray as _DataArray from .core.dataset import Dataset as _Dataset -from .core.duck_array_ops import _dask_or_eager_func from .core.groupby import GroupBy as _GroupBy from .core.pycompat import dask_array_type as _dask_array_type from .core.variable import Variable as _Variable @@ -71,7 +70,7 @@ def __call__(self, *args, **kwargs): new_args = tuple(reversed(args)) if res is _UNDEFINED: - f = _dask_or_eager_func(self._name, array_args=slice(len(args))) + f = getattr(_np, self._name) res = f(*new_args, **kwargs) if res is NotImplemented: raise TypeError(