
Commit

Merge pull request #86 from rabernat/write-with-zarr
Write with Zarr
rabernat authored Apr 2, 2021
2 parents 3314162 + 449d1d8 commit 721dea4
Showing 10 changed files with 292 additions and 81 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7, 3.8]
python-version: [3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Setup Python
@@ -48,7 +48,7 @@ jobs:
- name: install pangeo-forge
shell: bash -l {0}
run: |
python -m pip install -e .
python -m pip install --no-deps -e .
- name: Run Tests
shell: bash -l {0}
run: |
10 changes: 7 additions & 3 deletions ci/py3.7.yml
@@ -3,13 +3,14 @@ channels:
- conda-forge
dependencies:
- python=3.7
- aiohttp
- black
- boto3
- cfgrib
- codecov
- dask
- distributed
- fsspec
# - fsspec
- h5netcdf
- h5py
- hdf5
@@ -25,10 +26,13 @@ dependencies:
- pytest-cov
- pytest-lazy-fixture
- rasterio
- requests
- scipy
- setuptools
- toolz
- xarray>=0.16.2
# - xarray>=0.16.2
- zarr>=2.6.0
- pip:
- git+https://github.com/rabernat/rechunker.git@refactor-executors
- git+https://github.com/rabernat/xarray.git@zarr-chunk-fixes
- git+https://github.com/pangeo-data/rechunker.git@master
- git+https://github.com/intake/filesystem_spec.git@master
10 changes: 7 additions & 3 deletions ci/py3.8.yml
@@ -3,13 +3,14 @@ channels:
- conda-forge
dependencies:
- python=3.8
- aiohttp
- black
- boto3
- cfgrib
- codecov
- dask
- distributed
- fsspec
# - fsspec
- h5netcdf
- h5py
- hdf5
@@ -25,10 +26,13 @@ dependencies:
- pytest-cov
- pytest-lazy-fixture
- rasterio
- requests
- scipy
- setuptools
- toolz
- xarray>=0.16.2
# - xarray>=0.16.2
- zarr>=2.6.0
- pip:
- git+https://github.com/rabernat/rechunker.git@refactor-executors
- git+https://github.com/rabernat/xarray.git@zarr-chunk-fixes
- git+https://github.com/pangeo-data/rechunker.git@master
- git+https://github.com/intake/filesystem_spec.git@master
39 changes: 39 additions & 0 deletions ci/py3.9.yml
@@ -0,0 +1,39 @@
name: pangeo-forge
channels:
- conda-forge
dependencies:
- python=3.9
- aiohttp
- black
- boto3
- cfgrib
- codecov
- dask
# - distributed
- fsspec
- h5netcdf
- h5py
- hdf5
- lxml # Optional dep of pydap
- netcdf4
- numpy
- pandas
- pip
- prefect
- pydap
# bring back eventually once pynio conda-forge package supports py3.9
# - pynio
- pytest
- pytest-cov
- pytest-lazy-fixture
- rasterio
- requests
- scipy
- setuptools
- toolz
# - xarray>=0.16.2
- zarr>=2.6.0
- pip:
- git+https://github.com/rabernat/xarray.git@zarr-chunk-fixes
- git+https://github.com/pangeo-data/rechunker.git@master
- git+https://github.com/intake/filesystem_spec.git@master
100 changes: 75 additions & 25 deletions pangeo_forge/recipe.py
@@ -9,15 +9,17 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, Hashable, Iterable, List, Optional

import dask
import fsspec
import numpy as np
import xarray as xr
import zarr
from rechunker.types import MultiStagePipeline, ParallelPipelines, Stage

from .patterns import ExplicitURLSequence, VariableSequencePattern
from .storage import AbstractTarget, UninitializedTarget
from .utils import (
calc_chunk_conflicts,
chunk_bounds_and_conflicts,
chunked_iterable,
fix_scalar_attr_encoding,
lock_for_conflicts,
@@ -217,6 +219,7 @@ def __post_init__(self):
self._sequence_dim_chunks = self.nitems_per_input * self.inputs_per_chunk

# TODO: more input validation
# for example: check required args (e.g. sequence_dim)

@property
def prepare_target(self) -> Callable:
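As a purely hypothetical illustration of the kind of validation the TODO above refers to (not part of this PR; the class and default below are invented stand-ins), a required-argument check in `__post_init__` might look like:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleRecipe:
    """Hypothetical stand-in for a recipe class; not the PR's actual class."""

    sequence_dim: Optional[str] = None

    def __post_init__(self):
        # the kind of required-argument check the TODO hints at
        if self.sequence_dim is None:
            raise ValueError("sequence_dim must be specified")


ExampleRecipe(sequence_dim="time")   # ok
# ExampleRecipe()                    # would raise ValueError
```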
@@ -237,8 +240,7 @@ def _prepare_target():
# need to rewrite this as an append loop
for chunk_key in self._init_chunks:
with self.open_chunk(chunk_key) as ds:
# need to have the data in memory to avoid weird chunk problems
ds.load()
# ds is already chunked

# https://github.com/pydata/xarray/blob/5287c7b2546fc8848f539bb5ee66bb8d91d8496f/xarray/core/variable.py#L1069
for v in ds.variables:
@@ -253,12 +255,25 @@ def _prepare_target():
chunks = tuple(
chunks.get(n, s) for n, s in enumerate(this_var.shape)
)
ds[v].encoding["chunks"] = chunks
encoding_chunks = chunks
else:
ds[v].encoding["chunks"] = ds[v].shape
encoding_chunks = ds[v].shape
logger.debug(
f"Setting variable {v} encoding chunks to {encoding_chunks}"
)
ds[v].encoding["chunks"] = encoding_chunks

# load all variables that don't have the sequence dim in them
# these are usually coordinates.
# Variables that are loaded will be written even with compute=False
# TODO: make this behavior customizable
for v in ds.variables:
if self.sequence_dim not in ds[v].dims:
ds[v].load()

target_mapper = self.target.get_mapper()
ds.to_zarr(target_mapper, mode="a", compute=False)
logger.debug(f"Storing dataset:\n {ds}")
ds.to_zarr(target_mapper, mode="a", compute=False, safe_chunks=False)

# Regardless of whether there is an existing dataset or we are creating a new one,
# we need to expand the sequence_dim to hold the entire expected size of the data
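For reference, the initialization pattern used above — record the target chunking in the Zarr encoding, eagerly load only the variables without the sequence dimension, then write metadata with `compute=False` — can be sketched standalone roughly as follows (toy dataset, invented names and sizes; the real recipe writes to an fsspec mapper rather than a local path):

```python
import numpy as np
import xarray as xr

# Toy dataset standing in for the first input chunk; "time" plays the role
# of sequence_dim. Chunking keeps the data variable lazy (dask-backed).
ds = xr.Dataset(
    {"tas": (("time", "lon"), np.random.rand(10, 4))},
    coords={
        "time": np.arange(10),
        "lon": np.arange(4),
        "lat": ("lon", np.linspace(-60.0, 60.0, 4)),
    },
).chunk({"time": 5})

# Record the desired on-disk chunking in the Zarr encoding.
ds["tas"].encoding["chunks"] = (5, 4)

# Eagerly load variables without the sequence dim (usually coordinates);
# numpy-backed variables are written even with compute=False.
for name in ds.variables:
    if "time" not in ds[name].dims:
        ds[name].load()

# Writes array metadata plus the loaded variables only; the lazy data
# regions are filled in later by per-chunk region writes.
ds.to_zarr("target.zarr", mode="a", compute=False)
```

The delayed object returned by `to_zarr` is simply discarded at this stage; only the metadata and coordinate writes matter here.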
@@ -305,11 +320,36 @@ def _store_chunk(chunk_key):
v for v in ds_chunk.variables if self.sequence_dim not in ds_chunk[v].dims
]
ds_chunk = ds_chunk.drop_vars(to_drop)

target_mapper = self.target.get_mapper()
write_region, conflicts = self.region_and_conflicts_for_chunk(chunk_key)
with lock_for_conflicts(conflicts):
logger.info(f"Storing chunk '{chunk_key}' to Zarr region {write_region}")
ds_chunk.to_zarr(target_mapper, region=write_region)

zgroup = zarr.open_group(target_mapper)
for vname, var_coded in ds_chunk.variables.items():
zarr_array = zgroup[vname]
# get encoding for variable from zarr attributes
# could this backfire in some way?
var_coded.encoding.update(zarr_array.attrs)
# just delete all attributes from the var;
# they are not used anyway, and there can be conflicts
# related to xarray.coding.variables.safe_setitem
var_coded.attrs = {}
with dask.config.set(
scheduler="single-threaded"
): # make sure we don't use a scheduler
var = xr.backends.zarr.encode_zarr_variable(var_coded)
data = np.asarray(
var.data
) # TODO: can we buffer large data rather than loading it all?
zarr_region = tuple(write_region.get(dim, slice(None)) for dim in var.dims)
lock_keys = [f"{vname}-{c}" for c in conflicts]
logger.debug(f"Acquiring locks {lock_keys}")
with lock_for_conflicts(lock_keys):
logger.info(
f"Storing variable {vname} chunk {chunk_key} "
f"to Zarr region {zarr_region}"
)
zarr_array[zarr_region] = data

return _store_chunk
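The core of the new store path — writing each chunk's data directly into a region of the target Zarr array instead of going through `xr.Dataset.to_zarr` — can be sketched standalone as below (toy store and sizes; the real code additionally runs each variable through `xr.backends.zarr.encode_zarr_variable` and takes a distributed lock for conflicting on-disk chunks):

```python
import numpy as np
import zarr

# Target array as prepare_target would have created it: fixed on-disk chunks,
# full final shape along the sequence dimension.
group = zarr.open_group("target.zarr", mode="a")
tas = group.require_dataset("tas", shape=(10, 4), chunks=(5, 4), dtype="f8")

# One logical chunk of data, already loaded into memory.
data = np.random.rand(5, 4)

# Write region along the sequence dimension (here: rows 5..9).
region = (slice(5, 10), slice(None))

# Direct assignment touches only the on-disk chunks under the region; if two
# workers' regions share an on-disk chunk, that chunk must be locked first
# (see lock_for_conflicts in pangeo_forge/utils.py).
tas[region] = data
```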

@@ -382,12 +422,15 @@ def open_chunk(self, chunk_key):
with ExitStack() as stack:
dsets = [stack.enter_context(self.open_input(i)) for i in inputs]
# explicitly chunking prevents eager evaluation during concat
# dsets = [ds.chunk() for ds in dsets]
# but that leads to corrupted data!

# CONCAT DELETES ENCODING!!!
# OR NO IT DOESN'T! Not in the latest version of xarray?
ds = xr.concat(dsets, self.sequence_dim, **self.xarray_concat_kwargs)
dsets = [ds.chunk() for ds in dsets]
if len(dsets) > 1:
# During concat, attributes and encoding are taken from the first dataset
# https://github.com/pydata/xarray/issues/1614
ds = xr.concat(dsets, self.sequence_dim, **self.xarray_concat_kwargs)
elif len(dsets) == 1:
ds = dsets[0]
else: # pragma: no cover
assert False, "Should never happen"

if self.process_chunk is not None:
ds = self.process_chunk(ds)
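For context, the concatenation pattern adopted above, in isolation (hypothetical file names; `time` stands in for `sequence_dim`):

```python
import xarray as xr

paths = ["input_0.nc", "input_1.nc"]  # hypothetical input files

# .chunk() wraps each input's arrays in dask, so the concat below builds a
# lazy graph instead of eagerly loading every input into memory.
dsets = [xr.open_dataset(p).chunk() for p in paths]

if len(dsets) > 1:
    # attributes and encoding are taken from the first dataset
    # (https://github.com/pydata/xarray/issues/1614)
    ds = xr.concat(dsets, dim="time")
else:
    ds = dsets[0]
```

Skipping `xr.concat` entirely for a single-input chunk presumably avoids an unnecessary copy.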
@@ -404,14 +447,14 @@ def open_target(self):
def expand_target_dim(self, dim, dimsize):
target_mapper = self.target.get_mapper()
zgroup = zarr.open_group(target_mapper)

ds = self.open_target()
sequence_axes = {v: ds[v].get_axis_num(dim) for v in ds.variables if dim in ds[v].dims}

for v, axis in sequence_axes.items():
arr = zgroup[v]
shape = list(arr.shape)
shape[axis] = dimsize
logger.debug(f"resizing array {v} to shape {shape}")
arr.resize(shape)

# now explicitly write the sequence coordinate to avoid missing data
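The resize loop above leans on Zarr's ability to grow an array in place; in isolation (same toy target as the sketches above):

```python
import zarr

group = zarr.open_group("target.zarr", mode="a")
tas = group.require_dataset("tas", shape=(10, 4), chunks=(5, 4), dtype="f8")

# Grow the sequence axis (axis 0) to the full expected length. Existing
# chunks are left untouched; the newly exposed region reads as fill_value
# until it is written.
shape = list(tas.shape)
shape[0] = 100
tas.resize(shape)
```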
@@ -433,22 +476,29 @@ def region_and_conflicts_for_chunk(self, chunk_key):
# also return the conflicts with other chunks

input_keys = self.inputs_for_chunk(chunk_key)

# TODO: refactor into a separate method
if self.nitems_per_input:
stride = self.nitems_per_input * self.inputs_per_chunk
start = self.chunk_position(chunk_key) * stride
stop = start + stride
input_sequence_lens = (self.nitems_per_input,) * self._n_inputs_along_sequence
else:
input_sequence_lens = json.loads(
self.metadata_cache.get_mapper()[_GLOBAL_METADATA_KEY]
)["input_sequence_lens"]
start = sum(input_sequence_lens[: self.input_position(input_keys[0])])
chunk_len = sum([input_sequence_lens[self.input_position(k)] for k in input_keys])
stop = start + chunk_len

all_chunk_conflicts = calc_chunk_conflicts(input_sequence_lens, self._sequence_dim_chunks)
this_chunk_conflicts = [all_chunk_conflicts[self.input_position(k)] for k in input_keys]

chunk_bounds, all_chunk_conflicts = chunk_bounds_and_conflicts(
input_sequence_lens, self._sequence_dim_chunks
)
# for multi-variable recipes, there is something redundant about this
# logic that feels error-prone
start = chunk_bounds[self.input_position(input_keys[0])]
stop = chunk_bounds[self.input_position(input_keys[-1]) + 1]

this_chunk_conflicts = set()
for k in input_keys:
# for multi-variable recipes, the conflicts will usually be the same
# for each variable. using a set avoids duplicate locks
for input_conflict in all_chunk_conflicts[self.input_position(k)]:
this_chunk_conflicts.add(input_conflict)
region_slice = slice(start, stop)
return {self.sequence_dim: region_slice}, this_chunk_conflicts

45 changes: 36 additions & 9 deletions pangeo_forge/utils.py
@@ -1,9 +1,12 @@
import itertools
import logging
from contextlib import contextmanager
from typing import List, Sequence, Tuple

import numpy as np
from dask.distributed import Lock, client
from dask.distributed import Lock, get_client

logger = logging.getLogger(__name__)


# https://alexwlchan.net/2018/12/iterating-in-fixed-size-chunks/
@@ -35,7 +38,21 @@ def _fixed_attrs(d):
return ds


def calc_chunk_conflicts(chunks: Sequence[int], zchunks: int) -> List[Tuple[int, ...]]:
def chunk_bounds_and_conflicts(chunks: Sequence[int], zchunks: int) -> List[Tuple[int, ...]]:
"""
Calculate the boundaries of contiguous but possibly uneven blocks over
a regularly chunked array.

Parameters
----------
chunks : A list of chunk lengths. The array length is the sum of these lengths.
zchunks : The constant on-disk (Zarr) chunk length.

Returns
-------
chunk_bounds : The boundaries of the regions to write (one longer than chunks).
conflicts : A tuple of conflicting on-disk chunk indices for each block; empty if there are none.
"""
n_chunks = len(chunks)

# coerce numpy array to list for mypy
@@ -59,24 +76,34 @@ def calc_chunk_conflicts(chunks: Sequence[int], zchunks: int) -> List[Tuple[int,
conflicts.add(chunk_pair[1])
chunk_conflicts.append(tuple(conflicts))

return chunk_conflicts
return chunk_bounds, chunk_conflicts
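A concrete illustration of what this helper computes (numbers invented): three inputs of lengths 3, 4, and 5 written against an on-disk chunk length of 4 give bounds [0, 3, 7, 12]; the first two inputs both touch on-disk chunk 0, so writing them concurrently requires a lock. A rough standalone recomputation of that idea, not the function itself:

```python
import numpy as np

input_lens = [3, 4, 5]  # length of each input along the sequence dim
zchunk = 4              # constant on-disk (Zarr) chunk length

# Write-region boundaries for each input: [0, 3, 7, 12]
bounds = np.hstack([[0], np.cumsum(input_lens)])

# On-disk chunk indices touched by each input's region
touched = [
    set(range(start // zchunk, (stop - 1) // zchunk + 1))
    for start, stop in zip(bounds[:-1], bounds[1:])
]

# An on-disk chunk touched by more than one input is a conflict
for i, mine in enumerate(touched):
    shared = {c for j, other in enumerate(touched) if j != i for c in mine & other}
    print(f"input {i}: touches {sorted(mine)}, conflicts {sorted(shared)}")
```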


@contextmanager
# TODO: use a recipe-specific base_name to handle multiple recipes potentially
# running at the same time
def lock_for_conflicts(conflicts, base_name="pangeo-forge"):

# https://stackoverflow.com/questions/59070260/dask-client-detect-local-default-cluster-already-running
is_distributed = client._get_global_client() is not None
# Don't bother with locks if we are not in a distributed context
# NOTE! This means we HAVE to use dask.distributed as our parallel execution environment
# That is compatible with Prefect.
try:
global_client = get_client()
is_distributed = True
except ValueError:
# Don't bother with locks if we are not in a distributed context
# NOTE! This means we HAVE to use dask.distributed as our parallel execution environment
# This should be compatible with Prefect.
is_distributed = False
if is_distributed:
locks = [Lock(f"{base_name}-{c}") for c in conflicts]
locks = [Lock(f"{base_name}-{c}", global_client) for c in conflicts]
for lock in locks:
logger.debug(f"Acquiring lock {lock.name}...")
lock.acquire()
logger.debug(f"Acquired lock {lock.name}")
else:
logger.debug(f"Asked to lock {conflicts} but no Dask client found.")
try:
yield
finally:
if is_distributed:
for lock in locks:
lock.release()
logger.debug(f"Released lock {lock.name}")