Skip to content

faster reading of h5ad file (~18X faster) #3365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
247 changes: 241 additions & 6 deletions src/scanpy/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import anndata
import anndata.utils
import h5py
import numpy as np
Expand All @@ -32,6 +33,12 @@
read_mtx,
read_text,
)
import multiprocessing as mp
import threading
from dataclasses import dataclass

import numba
import scipy
from anndata import AnnData
from matplotlib.image import imread

Expand All @@ -46,6 +53,14 @@

from ._utils import Empty

# Element type used for the CSR ``indices`` staging buffer, and the ctypes
# typecode used to allocate that buffer with ``mp.Array``.
indices_type = np.int64
# "q" (C long long) is 64 bits on every platform. "l" (C long) is only 32 bits
# on LLP64 platforms such as Windows, which would make the ``np.int64`` view of
# the shared buffer half as long as intended.
indices_shm_type = "q"

# Per-loader semaphores; both are replaced by lists of ``mp.Semaphore`` inside
# ``fastload`` before any loader process is started.
semDataLoaded = None  # will be initialized later
semDataCopied = None  # will be initialized later

# Number of elements each loader stages per round trip.
thread_workload = 4000000  # experimented value

# .gz and .bz2 suffixes are also allowed for text formats
text_exts = {
"csv",
Expand All @@ -67,6 +82,224 @@
"""Available file formats for reading data. """


def get_1d_index(row: int, col: int, num_cols: int) -> int:
    """Flatten a ``(row, col)`` position into a row-major 1D offset.

    Parameters
    ----------
    row
        Row index in the 2D array.
    col
        Column index in the 2D array.
    num_cols
        Number of columns in the 2D array.

    Returns
    -------
    The corresponding index into the flattened array.
    """
    row_offset = num_cols * row
    return row_offset + col

Check warning on line 97 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L97

Added line #L97 was not covered by tests


@dataclass
class LoadHelperData:
    """Bundle of arguments handed to one ``_load_helper`` process."""

    # Index of this loader; selects its slice of the file and its staging slot.
    i: int
    # Total number of loaders the file is split across.
    k: int
    # Total number of stored values in X (``fastload`` passes ``indptr[-1]``).
    datalen: int
    # Staging buffer for X/data, allocated in ``fastload`` with ``mp.Array``.
    dataArray: mp.Array
    # Staging buffer for X/indices, allocated in ``fastload`` with ``mp.Array``.
    indicesArray: mp.Array
    # Per-loader start offset of the chunk currently staged.
    startsArray: mp.Array
    # Per-loader end offset (exclusive) of the chunk currently staged.
    endsArray: mp.Array


def _load_helper(fname: str, helper_data: LoadHelperData):
    """Stream one loader's share of ``X/data`` and ``X/indices`` into staging.

    Loader ``i`` of ``k`` owns the element range
    ``[i * datalen // k, (i + 1) * datalen // k)``. It reads that range in
    pieces of at most ``thread_workload`` elements into its own slot of the
    shared staging buffers, publishes the piece's bounds in
    ``startsArray[i]`` / ``endsArray[i]``, releases ``semDataLoaded[i]`` and then
    waits on ``semDataCopied[i]`` before overwriting the slot again.

    Parameters
    ----------
    fname
        Path of the HDF5 file to read.
    helper_data
        Loader index, loader count, total element count and the shared buffers.

    Notes
    -----
    The semaphores are looked up in the module globals ``semDataLoaded`` and
    ``semDataCopied``, so this function only works in a process that sees the
    lists assigned by ``fastload``.
    """
    i = helper_data.i
    k = helper_data.k
    datalen = helper_data.datalen

    dataA = np.frombuffer(helper_data.dataArray, dtype=np.float32)
    indicesA = np.frombuffer(helper_data.indicesArray, dtype=indices_type)
    startsA = np.frombuffer(helper_data.startsArray, dtype=np.int64)
    endsA = np.frombuffer(helper_data.endsArray, dtype=np.int64)

    chunk_start = i * datalen // k
    chunk_end = (i + 1) * datalen // k
    slot = i * thread_workload  # first element of this loader's staging slot

    # The context manager closes the file even if a read fails.
    with h5py.File(fname, "r") as f:
        x_data = f["X"]["data"]
        x_indices = f["X"]["indices"]
        for j in range(datalen // (k * thread_workload) + 1):
            # compute start, end
            s = chunk_start + j * thread_workload
            e = min(s + thread_workload, chunk_end)
            length = e - s
            startsA[i] = s
            endsA[i] = e

            # The last round can be empty; there is nothing to read then, but
            # the handshake below must still happen once per round.
            if length > 0:
                dest = np.s_[slot : slot + length]
                x_data.read_direct(dataA, np.s_[s:e], dest)
                x_indices.read_direct(indicesA, np.s_[s:e], dest)

            # coordinate with copy threads
            semDataLoaded[i].release()  # done data load
            semDataCopied[i].acquire()  # wait until data copied

Check warning on line 144 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L143-L144

Added lines #L143 - L144 were not covered by tests


def _waitload(i):
    """Block until loader ``i`` has released its "data loaded" semaphore."""
    loaded = semDataLoaded[i]
    loaded.acquire()

Check warning on line 148 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L148

Added line #L148 was not covered by tests


def _signalcopy(i):
    """Release loader ``i``'s "data copied" semaphore so it can load again."""
    copied = semDataCopied[i]
    copied.release()

Check warning on line 152 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L152

Added line #L152 was not covered by tests


@dataclass
class CopyData:
    """Arrays used by ``_fast_copy`` to move staged chunks into the result."""

    # Destination for X/data values (``fastload`` allocates ``datalen`` items).
    data: np.ndarray
    # NumPy view of the shared staging buffer for X/data.
    dataA: np.ndarray
    # Destination for X/indices values.
    indices: np.ndarray
    # NumPy view of the shared staging buffer for X/indices.
    indicesA: np.ndarray
    # Per-loader start offset of the chunk currently staged.
    startsA: np.ndarray
    # Per-loader end offset (exclusive) of the chunk currently staged.
    endsA: np.ndarray


def _fast_copy(copy_data: CopyData, k: int, m: int):
    """Drain the staging buffers into the destination arrays.

    One thread is started per loader. Each thread performs ``m`` rounds: wait
    for its loader to signal a staged chunk, copy that chunk from the loader's
    staging slot to ``[starts[i], ends[i])`` of the destination arrays, then
    signal the loader that the slot may be reused.

    Parameters
    ----------
    copy_data
        Destination arrays, staging views and per-loader chunk bounds.
    k
        Number of loaders, and therefore of copy threads.
    m
        Number of rounds each thread performs.
    """
    dst_data = copy_data.data
    staged_data = copy_data.dataA
    dst_indices = copy_data.indices
    staged_indices = copy_data.indicesA
    chunk_starts = copy_data.startsA
    chunk_ends = copy_data.endsA

    def thread_fun(i, m):
        slot = i * thread_workload  # first element of loader i's staging slot
        for _ in range(m):
            # NOTE(review): thread_fun is not jit-compiled here, so the objmode
            # context presumably has no effect - confirm before removing.
            with numba.objmode():
                _waitload(i)
            lo, hi = chunk_starts[i], chunk_ends[i]
            staged = slice(slot, slot + (hi - lo))
            dst_data[lo:hi] = staged_data[staged]
            dst_indices[lo:hi] = staged_indices[staged]
            with numba.objmode():
                _signalcopy(i)

    workers = []
    for i in range(k):
        workers.append(threading.Thread(target=thread_fun, args=(i, m)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

Check warning on line 192 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L188-L192

Added lines #L188 - L192 were not covered by tests


def _fastload_is_string_frame(group) -> bool:
    """Return True if ``group`` is an HDF5 group holding only string datasets."""
    if not isinstance(group, h5py.Group):
        return False
    return all(
        isinstance(node, h5py.Dataset)
        and h5py.check_string_dtype(node.dtype) is not None
        for node in group.values()
    )


def _fastload_read_frame(group):
    """Build a DataFrame from the string datasets of an ``obs``/``var`` group."""
    columns = {
        name: group[name].asstr()[...] for name in group if name != "_index"
    }
    index = pd.Series(group["_index"].asstr()[...]) if "_index" in group else None
    return pd.DataFrame(columns, index=index)


def _fastload_read_header(fname):
    """Return ``(indptr, obs, var)`` if the fast path can load ``fname`` faithfully.

    Returns ``None`` whenever the parallel reader would lose or alter content,
    so that the caller can fall back to ``read_h5ad``.
    """
    required = ("X", "obs", "var")
    # Read-only: nothing here modifies the file, whatever mode the caller wants.
    with h5py.File(fname, "r") as f:
        if any(key not in f for key in required):
            return None
        # The fast path only rebuilds X, obs and var; any other populated
        # element (layers, obsm, uns, raw, ...) would be dropped silently.
        for key, node in f.items():
            if key in required:
                continue
            if not isinstance(node, h5py.Group) or len(node) > 0:
                return None

        x = f["X"]
        # A sparse X is stored as a group of data/indices/indptr; only CSR is
        # supported because the result is assembled as a csr_matrix.
        if not isinstance(x, h5py.Group):
            return None
        if x.attrs.get("encoding-type") != "csr_matrix":
            return None
        if any(part not in x for part in ("data", "indices", "indptr")):
            return None
        # The staging buffers are float32; other dtypes would be converted.
        if x["data"].dtype != np.float32:
            return None
        # Columns are read with ``asstr()``, which only works on string datasets.
        if not _fastload_is_string_frame(f["obs"]):
            return None
        if not _fastload_is_string_frame(f["var"]):
            return None

        indptr = x["indptr"][...]
        # Small matrices are not worth the process start-up cost.
        if indptr.size == 0 or int(indptr[-1]) < thread_workload:
            return None
        return indptr, _fastload_read_frame(f["obs"]), _fastload_read_frame(f["var"])


def fastload(fname, backed):
    """Read an ``.h5ad`` file, loading a large CSR ``X`` with parallel readers.

    The parallel path is used only when it reproduces the file's content: an
    in-memory read (``backed is None``) of a file that holds just ``X``
    (float32 CSR with at least ``thread_workload`` stored values), ``obs`` and
    ``var`` with string-only columns, on a platform whose multiprocessing start
    method is ``fork``. In every other case the call is forwarded unchanged to
    ``read_h5ad``.

    Parameters
    ----------
    fname
        Path of the ``.h5ad`` file.
    backed
        Backing mode forwarded to ``read_h5ad``; any value other than ``None``
        disables the parallel path, which always materialises X in memory.

    Returns
    -------
    The loaded AnnData object.
    """
    global semDataLoaded
    global semDataCopied

    header = None if backed is not None else _fastload_read_header(fname)
    # Loader processes find their semaphores through module globals, which
    # only reach the children when they are forked from this process. The
    # start method is queried only once the parallel path is otherwise chosen.
    if header is None or mp.get_start_method() != "fork":
        return read_h5ad(fname, backed=backed)

    indptr, dfobs, dfvar = header
    rows = indptr.size - 1
    datalen = int(indptr[-1])

    k = numba.get_num_threads()
    rounds = datalen // (k * thread_workload) + 1

    # Shared staging buffers: one ``thread_workload``-sized slot per loader.
    dataArray = mp.Array("f", k * thread_workload, lock=False)
    indicesArray = mp.Array(indices_shm_type, k * thread_workload, lock=False)
    # "q" is a 64-bit integer on every platform, matching the int64 views below.
    startsArray = mp.Array("q", k, lock=False)  # start index of data read
    endsArray = mp.Array("q", k, lock=False)  # end index (noninclusive)

    semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
    semDataCopied = [mp.Semaphore(0) for _ in range(k)]

    copy_data = CopyData(
        data=np.empty(datalen, dtype=np.float32),
        dataA=np.frombuffer(dataArray, dtype=np.float32),
        indices=np.empty(datalen, dtype=indices_type),
        indicesA=np.frombuffer(indicesArray, dtype=indices_type),
        startsA=np.frombuffer(startsArray, dtype=np.int64),
        endsA=np.frombuffer(endsArray, dtype=np.int64),
    )

    procs = [
        mp.Process(
            target=_load_helper,
            args=(
                fname,
                LoadHelperData(
                    i=i,
                    k=k,
                    datalen=datalen,
                    dataArray=dataArray,
                    indicesArray=indicesArray,
                    startsArray=startsArray,
                    endsArray=endsArray,
                ),
            ),
        )
        for i in range(k)
    ]
    for p in procs:
        p.start()

    _fast_copy(copy_data, k, rounds)

    for p in procs:
        p.join()

    # Use the public constructor instead of assigning private attributes.
    X = scipy.sparse.csr_matrix(
        (copy_data.data, copy_data.indices, indptr),
        shape=(rows, dfvar.shape[0]),
    )

    # create AnnData
    return anndata.AnnData(X, obs=dfobs, var=dfvar)

Check warning on line 300 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L299-L300

Added lines #L299 - L300 were not covered by tests


# --------------------------------------------------------------------------------
# Reading and Writing data files and AnnData objects
# --------------------------------------------------------------------------------
Expand All @@ -83,7 +316,7 @@
)
def read(
filename: Path | str,
backed: Literal["r", "r+"] | None = None,
backed: Literal["r", "r+"] | None = "r+",
*,
sheet: str | None = None,
ext: str | None = None,
Expand Down Expand Up @@ -164,7 +397,7 @@
"or pass the parameter `ext`."
)
raise ValueError(msg)
return read_h5ad(filename, backed=backed)
return fastload(filename, backed=backed)

Check warning on line 400 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L400

Added line #L400 was not covered by tests


@old_positionals("genome", "gex_only", "backup_url")
Expand Down Expand Up @@ -346,7 +579,9 @@
(
feature_metadata_name,
dsets[feature_metadata_name].astype(
bool if feature_metadata_item.dtype.kind == "b" else str
bool
if feature_metadata_item.dtype.kind == "thread_workload"
else str
),
)
for feature_metadata_name, feature_metadata_item in f["matrix"][
Expand Down Expand Up @@ -791,7 +1026,7 @@
# read hdf5 files
if ext in {"h5", "h5ad"}:
if sheet is None:
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)
else:
logg.debug(f"reading sheet {sheet} from file {filename}")
return read_hdf(filename, sheet)
Expand All @@ -803,7 +1038,7 @@
path_cache = path_cache.with_suffix("")
if cache and path_cache.is_file():
logg.info(f"... reading from cache file {path_cache}")
return read_h5ad(path_cache)
return fastload(path_cache, backed)

Check warning on line 1041 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L1041

Added line #L1041 was not covered by tests

if not is_present:
msg = f"Did not find file {filename}."
Expand Down Expand Up @@ -1044,7 +1279,7 @@
total = resp.info().get("content-length", None)
with (
tqdm(
unit="B",
unit="thread_workload",
unit_scale=True,
miniters=1,
unit_divisor=1024,
Expand Down
Loading