Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Kvikio backend entrypoint #10

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9deadb7
Add Kvikio backend entrypoint
dcherian Aug 2, 2022
aa2dc91
Add demo notebook
dcherian Aug 2, 2022
7fb4b94
Update kvikio notebook
dcherian Aug 16, 2022
743fe7d
Merge branch 'main' into kvikio-entrypoint
dcherian Aug 17, 2022
5d501e4
Merge branch 'main' into kvikio-entrypoint
andersy005 Aug 17, 2022
facf5f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2022
f3f5189
Update cupy_xarray/kvikio.py
dcherian Aug 17, 2022
9c98d19
Merge branch 'main' into kvikio-entrypoint
andersy005 Jan 3, 2023
dd8bc57
Merge branch 'main' into kvikio-entrypoint
dcherian Jan 20, 2023
d2da1e4
Add url, description.
dcherian Jan 21, 2023
b87c3c2
Working
dcherian Aug 18, 2023
87cb74e
Updated notebook
dcherian Aug 22, 2023
d7394ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2023
1b23fef
Merge remote-tracking branch 'upstream/main' into kvikio-entrypoint
dcherian Nov 3, 2023
ca0cf45
Add tests
dcherian Nov 3, 2023
97260d6
Merge branch 'main' into kvikio-entrypoint
weiji14 Jun 21, 2024
5d27b26
Move kvikio notebook under docs/source
weiji14 Jun 21, 2024
85491d7
Add zarr as a dependency in ci/doc.yml
weiji14 Jun 22, 2024
c470b97
Add entry for KvikioBackendEntrypoint in API docs
weiji14 Jun 22, 2024
95efa18
Fix input argument into CupyZarrArrayWrapper
weiji14 Jun 22, 2024
d684dad
Merge branch 'main' into kvikio-entrypoint
weiji14 Dec 14, 2024
ae2a7f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 14, 2024
15fbafd
Re-add kvikio backend entrypoint to pyproject.toml
weiji14 Dec 14, 2024
f3df115
Fix C408 and E402
weiji14 Dec 14, 2024
4e1857a
Use get_duck_array instead of get_array
weiji14 Dec 14, 2024
7345b61
Fix SIM108 Use ternary operator
weiji14 Dec 16, 2024
e2b410e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions cupy_xarray/kvikio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import os
import warnings

import cupy as cp
import numpy as np
import zarr
from xarray import Variable
from xarray.backends import zarr as zarr_backend
from xarray.backends.common import _normalize_path # TODO: can this be public
from xarray.backends.store import StoreBackendEntrypoint
from xarray.backends.zarr import ZarrArrayWrapper, ZarrBackendEntrypoint, ZarrStore
from xarray.core import indexing
from xarray.core.utils import close_on_error # TODO: can this be public.

try:
import kvikio.zarr

has_kvikio = True
except ImportError:
has_kvikio = False


class DummyZarrArrayWrapper(ZarrArrayWrapper):
    """Make an in-memory numpy array quack like a wrapped Zarr array.

    Used for data that has already been transferred from device to host
    (see ``EagerCupyZarrArrayWrapper``), so downstream code that expects a
    ``ZarrArrayWrapper`` keeps working.
    """

    def __init__(self, array: np.ndarray):
        # Intentionally does NOT call super().__init__: this wrapper owns a
        # plain numpy array instead of a (variable_name, datastore) pair.
        assert isinstance(array, np.ndarray)
        self._array = array
        # Mirror the attributes xarray's backend machinery reads off a
        # real Zarr array wrapper.
        self.dtype = array.dtype
        self.shape = array.shape
        self.filters = None

    def __array__(self):
        return self._array

    def get_array(self):
        # Hand back the host array directly; nothing lazy remains.
        return self._array

    def __getitem__(self, key):
        return self._array[key]


class CupyZarrArrayWrapper(ZarrArrayWrapper):
    """Lazy wrapper whose ``__array__`` materializes the device-backed data."""

    def __array__(self):
        # Delegate to the parent class; with a GDS-backed group this yields
        # a cupy array.
        data = self.get_array()
        return data


class EagerCupyZarrArrayWrapper(ZarrArrayWrapper):
    """Used to wrap dimension coordinates.

    Dimension coordinates are loaded eagerly: the variable is read through
    the store (onto the device) and immediately copied to host memory with
    ``.get()``, since xarray needs numpy arrays for indexes.
    """

    def _read_to_host(self):
        # Single place for the device read + device->host copy that both
        # __array__ and get_array need (was duplicated before).
        # NOTE(review): this still re-reads the store on every call; caching
        # self.datastore.zarr_group[self.variable_name] would avoid pinging
        # the store repeatedly — TODO confirm the datastore stays open long
        # enough for that to be safe.
        return self.datastore.zarr_group[self.variable_name][:].get()

    def __array__(self):
        return self._read_to_host()

    def get_array(self):
        # total hack: make a numpy array look like a Zarr array
        return DummyZarrArrayWrapper(self._read_to_host())


class GDSZarrStore(ZarrStore):
    """``ZarrStore`` variant that reads through kvikio's ``GDSStore``.

    The Zarr group is opened with ``meta_array=cp.empty(())`` so that zarr
    allocates output buffers as cupy (device) arrays.
    """

    @classmethod
    def open_group(
        cls,
        store,
        mode="r",
        synchronizer=None,
        group=None,
        consolidated=False,
        consolidate_on_close=False,
        chunk_store=None,
        storage_options=None,
        append_dim=None,
        write_region=None,
        safe_chunks=True,
        stacklevel=2,  # kept for signature compatibility; currently unused
    ):
        """Open a Zarr group via kvikio GDS and wrap it in a ``GDSZarrStore``.

        Parameters mirror xarray's ``ZarrStore.open_group``. ``consolidated``
        must be falsy — consolidated metadata is not supported yet.
        """
        # zarr doesn't support pathlib.Path objects yet. zarr-python#601
        if isinstance(store, os.PathLike):
            store = os.fspath(store)

        open_kwargs = dict(
            mode=mode,
            synchronizer=synchronizer,
            path=group,
            storage_options=storage_options,
            ########## NEW STUFF
            # Ask zarr to allocate result arrays like this meta array,
            # i.e. on the GPU as cupy arrays.
            meta_array=cp.empty(()),
        )

        # TODO: handle consolidated
        assert not consolidated

        if chunk_store:
            open_kwargs["chunk_store"] = chunk_store

        store = kvikio.zarr.GDSStore(store)

        # ``consolidated`` is guaranteed falsy by the assert above, so only
        # the non-consolidated path is reachable.  (The previous
        # ``open_consolidated`` fallback branches were dead code.)
        zarr_group = zarr.open_group(store, **open_kwargs)

        return cls(
            zarr_group,
            mode,
            consolidate_on_close,
            append_dim,
            write_region,
            safe_chunks,
        )

def open_store_variable(self, name, zarr_array):
try_nczarr = self._mode == "r"
dimensions, attributes = zarr_backend._get_zarr_dims_and_attrs(
zarr_array, zarr_backend.DIMENSION_KEY, try_nczarr
)

#### Changed from zarr array wrapper
if name in dimensions:
# we want indexed dimensions to be loaded eagerly
# Right now we load in to device and then transfer to host
# But these should be small-ish arrays
# TODO: can we tell GDSStore to load as numpy array directly
# not cupy array?
array_wrapper = EagerCupyZarrArrayWrapper
else:
array_wrapper = CupyZarrArrayWrapper
data = indexing.LazilyIndexedArray(array_wrapper(name, self))

attributes = dict(attributes)
encoding = {
"chunks": zarr_array.chunks,
"preferred_chunks": dict(zip(dimensions, zarr_array.chunks)),
"compressor": zarr_array.compressor,
"filters": zarr_array.filters,
}
# _FillValue needs to be in attributes, not encoding, so it will get
# picked up by decode_cf
if getattr(zarr_array, "fill_value") is not None:
attributes["_FillValue"] = zarr_array.fill_value

return Variable(dimensions, data, attributes, encoding)


class KvikioBackendEntrypoint(ZarrBackendEntrypoint):
    """Xarray backend entrypoint that opens Zarr stores through kvikio/GDS.

    Behaves like the regular Zarr backend, except arrays are read directly
    into GPU memory as cupy arrays via ``GDSZarrStore``.
    """

    available = has_kvikio
    description = "Open zarr files (.zarr) using Kvikio"
    url = "https://docs.rapids.ai/api/kvikio/nightly/api.html#zarr"

    # disabled by default
    # We need to provide this because of the subclassing from
    # ZarrBackendEntrypoint
    def guess_can_open(self, filename_or_obj):
        return False

    def open_dataset(
        self,
        filename_or_obj,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
        group=None,
        mode="r",
        synchronizer=None,
        consolidated=None,
        chunk_store=None,
        storage_options=None,
        stacklevel=3,
    ):
        """Open ``filename_or_obj`` as a Dataset backed by a GDS Zarr store."""
        path = _normalize_path(filename_or_obj)
        gds_store = GDSZarrStore.open_group(
            path,
            group=group,
            mode=mode,
            synchronizer=synchronizer,
            consolidated=consolidated,
            consolidate_on_close=False,
            chunk_store=chunk_store,
            storage_options=storage_options,
            stacklevel=stacklevel + 1,
        )

        # close_on_error closes the store if dataset construction fails.
        with close_on_error(gds_store):
            return StoreBackendEntrypoint().open_dataset(
                gds_store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
Loading