diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py
index c568519cd7..15cb65fda7 100644
--- a/src/scanpy/readwrite.py
+++ b/src/scanpy/readwrite.py
@@ -6,6 +6,7 @@
 from pathlib import Path, PurePath
 from typing import TYPE_CHECKING
 
+import anndata
 import anndata.utils
 import h5py
 import numpy as np
@@ -32,6 +33,12 @@
     read_mtx,
     read_text,
 )
 
+import multiprocessing as mp
+import threading
+from dataclasses import dataclass
+
+import numba
+import scipy
 from anndata import AnnData
 from matplotlib.image import imread
@@ -46,6 +53,14 @@
     from ._utils import Empty
 
+indices_type = np.int64
+indices_shm_type = "l"
+
+semDataLoaded = None  # will be initialized later
+semDataCopied = None  # will be initialized later
+
+thread_workload = 4000000  # empirically chosen value
+
 # .gz and .bz2 suffixes are also allowed for text formats
 text_exts = {
     "csv",
@@ -67,6 +82,224 @@
 """Available file formats for reading data.
 """
 
+def get_1d_index(row: int, col: int, num_cols: int) -> int:
+    """
+    Convert 2D coordinates to a 1D index.
+
+    Parameters:
+        row (int): Row index in the 2D array.
+        col (int): Column index in the 2D array.
+        num_cols (int): Number of columns in the 2D array.
+
+    Returns:
+        int: The corresponding 1D index.
+    """
+    return row * num_cols + col
+
+
+@dataclass
+class LoadHelperData:
+    i: int
+    k: int
+    datalen: int
+    dataArray: mp.Array
+    indicesArray: mp.Array
+    startsArray: mp.Array
+    endsArray: mp.Array
+
+
+def _load_helper(fname: str, helper_data: LoadHelperData):
+    i = helper_data.i
+    k = helper_data.k
+    datalen = helper_data.datalen
+    dataArray = helper_data.dataArray
+    indicesArray = helper_data.indicesArray
+    startsArray = helper_data.startsArray
+    endsArray = helper_data.endsArray
+
+    f = h5py.File(fname, "r")
+    dataA = np.frombuffer(dataArray, dtype=np.float32)
+    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
+    startsA = np.frombuffer(startsArray, dtype=np.int64)
+    endsA = np.frombuffer(endsArray, dtype=np.int64)
+    for j in range(datalen // (k * thread_workload) + 1):
+        # compute the start and end of this worker's chunk for round j
+        s = i * datalen // k + j * thread_workload
+        e = min(s + thread_workload, (i + 1) * datalen // k)
+        length = e - s
+        startsA[i] = s
+        endsA[i] = e
+        # read directly into the shared staging buffers
+        f["X"]["data"].read_direct(
+            dataA, np.s_[s:e], np.s_[i * thread_workload : i * thread_workload + length]
+        )
+        f["X"]["indices"].read_direct(
+            indicesA,
+            np.s_[s:e],
+            np.s_[i * thread_workload : i * thread_workload + length],
+        )
+
+        # coordinate with copy threads
+        semDataLoaded[i].release()  # signal: chunk loaded
+        semDataCopied[i].acquire()  # wait until the data has been copied out
+
+
+def _waitload(i):
+    semDataLoaded[i].acquire()
+
+
+def _signalcopy(i):
+    semDataCopied[i].release()
+
+
+@dataclass
+class CopyData:
+    data: np.ndarray
+    dataA: np.ndarray
+    indices: np.ndarray
+    indicesA: np.ndarray
+    startsA: np.ndarray
+    endsA: np.ndarray
+
+
+def _fast_copy(copy_data: CopyData, k: int, m: int):
+    # access the arrays through copy_data
+    data = copy_data.data
+    dataA = copy_data.dataA
+    indices = copy_data.indices
+    indicesA = copy_data.indicesA
+    starts = copy_data.startsA
+    ends = copy_data.endsA
+
+    def thread_fun(i, m):
+        for j in range(m):
+            with numba.objmode():
+                _waitload(i)
+            length = ends[i] - starts[i]
+            data[starts[i] : ends[i]] = dataA[
+                i * thread_workload : i * thread_workload + length
+            ]
+            indices[starts[i] : ends[i]] = indicesA[
+                i * thread_workload : i * thread_workload + length
+            ]
+            with numba.objmode():
+                _signalcopy(i)
+
+    threads = [threading.Thread(target=thread_fun, args=(i, m)) for i in range(k)]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+
+def fastload(fname, backed):
+    f = h5py.File(fname, backed)
+    assert "X" in f, "'X' is missing from f"
+    assert "var" in f, "'var' is missing from f"
+    assert "obs" in f, "'obs' is missing from f"
+
+    if not scipy.sparse.issparse(f["X"]):
+        f.close()
+        return read_h5ad(fname, backed=backed)
+
+    # get obs dataframe
+    rows = f["obs"][list(f["obs"].keys())[0]].size
+    # load index pointers, prepare shared arrays
+    indptr = f["X"]["indptr"][0 : rows + 1]
+    datalen = int(indptr[-1])
+
+    if datalen < thread_workload:
+        f.close()
+        return read_h5ad(fname, backed=backed)
+    if "_index" in f["obs"]:
+        dfobsind = pd.Series(f["obs"]["_index"].asstr()[0:rows])
+        dfobs = pd.DataFrame(index=dfobsind)
+    else:
+        dfobs = pd.DataFrame()
+    for k in f["obs"]:
+        if k == "_index":
+            continue
+        dfobs[k] = f["obs"][k].asstr()[...]
+
+    # get var dataframe
+    if "_index" in f["var"]:
+        dfvarind = pd.Series(f["var"]["_index"].asstr()[...])
+        dfvar = pd.DataFrame(index=dfvarind)
+    else:
+        dfvar = pd.DataFrame()
+    for k in f["var"]:
+        if k == "_index":
+            continue
+        dfvar[k] = f["var"][k].asstr()[...]
+
+    f.close()
+    k = numba.get_num_threads()
+    dataArray = mp.Array(
+        "f", k * thread_workload, lock=False
+    )  # should be in shared memory
+    indicesArray = mp.Array(
+        indices_shm_type, k * thread_workload, lock=False
+    )  # should be in shared memory
+    startsArray = mp.Array("l", k, lock=False)  # start index of data read
+    endsArray = mp.Array("l", k, lock=False)  # end index (noninclusive) of data read
+    global semDataLoaded
+    global semDataCopied
+    semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
+    semDataCopied = [mp.Semaphore(0) for _ in range(k)]
+    dataA = np.frombuffer(dataArray, dtype=np.float32)
+    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
+    startsA = np.frombuffer(startsArray, dtype=np.int64)
+    endsA = np.frombuffer(endsArray, dtype=np.int64)
+    data = np.empty(datalen, dtype=np.float32)
+    indices = np.empty(datalen, dtype=indices_type)
+
+    procs = [
+        mp.Process(
+            target=_load_helper,
+            args=(
+                fname,
+                LoadHelperData(
+                    i=i,
+                    k=k,
+                    datalen=datalen,
+                    dataArray=dataArray,
+                    indicesArray=indicesArray,
+                    startsArray=startsArray,
+                    endsArray=endsArray,
+                ),
+            ),
+        )
+        for i in range(k)
+    ]
+
+    for p in procs:
+        p.start()
+
+    copy_data = CopyData(
+        data=data,
+        dataA=dataA,
+        indices=indices,
+        indicesA=indicesA,
+        startsA=startsA,
+        endsA=endsA,
+    )
+
+    _fast_copy(copy_data, k, datalen // (k * thread_workload) + 1)
+
+    for p in procs:
+        p.join()
+
+    X = scipy.sparse.csr_matrix((0, 0))
+    X.data = data
+    X.indices = indices
+    X.indptr = indptr
+    X._shape = (rows, dfvar.shape[0])
+
+    # create AnnData
+    adata = anndata.AnnData(X, dfobs, dfvar)
+    return adata
+
+
 # --------------------------------------------------------------------------------
 # Reading and Writing data files and AnnData objects
 # --------------------------------------------------------------------------------
@@ -83,7 +316,7 @@
 )
 def read(
     filename: Path | str,
-    backed: Literal["r", "r+"] | None = None,
+    backed: Literal["r", "r+"] | None = "r+",
     *,
     sheet: str | None = None,
     ext: str | None = None,
@@ -164,7 +397,7 @@ def read(
             "or pass the parameter `ext`."
         )
         raise ValueError(msg)
-    return read_h5ad(filename, backed=backed)
+    return fastload(filename, backed=backed)
 
 
 @old_positionals("genome", "gex_only", "backup_url")
@@ -346,7 +579,9 @@ def _read_v3_10x_h5(filename, *, start=None):
             (
                 feature_metadata_name,
                 dsets[feature_metadata_name].astype(
-                    bool if feature_metadata_item.dtype.kind == "b" else str
+                    bool
+                    if feature_metadata_item.dtype.kind == "b"
+                    else str
                 ),
             )
             for feature_metadata_name, feature_metadata_item in f["matrix"][
@@ -791,7 +1026,7 @@ def _read(
     # read hdf5 files
     if ext in {"h5", "h5ad"}:
         if sheet is None:
-            return read_h5ad(filename, backed=backed)
+            return fastload(filename, backed)
         else:
            logg.debug(f"reading sheet {sheet} from file {filename}")
             return read_hdf(filename, sheet)
@@ -803,7 +1038,7 @@ def _read(
         path_cache = path_cache.with_suffix("")
     if cache and path_cache.is_file():
         logg.info(f"... reading from cache file {path_cache}")
-        return read_h5ad(path_cache)
+        return fastload(path_cache, backed)
 
     if not is_present:
         msg = f"Did not find file {filename}."
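
Notes on the loading scheme follow.

How `_load_helper` partitions the nonzeros: each of the `k` reader processes owns the contiguous slice `[i * datalen // k, (i + 1) * datalen // k)` and walks it in `thread_workload`-sized chunks, one chunk per handshake round. Below is a worked sketch with made-up sizes; the `if s < e` guard is added here to skip empty trailing rounds and is not part of the patch:

```python
# Made-up sizes, for illustration only.
datalen = 10_000         # total nonzeros in X
k = 2                    # reader processes
thread_workload = 3_000  # elements staged per round

for i in range(k):
    lo, hi = i * datalen // k, (i + 1) * datalen // k
    chunks = []
    for j in range(datalen // (k * thread_workload) + 1):  # rounds, as in the patch
        s = lo + j * thread_workload
        e = min(s + thread_workload, hi)
        if s < e:  # guard added for this sketch only
            chunks.append((s, e))
    print(f"reader {i} stages {chunks}")
# reader 0 stages [(0, 3000), (3000, 5000)]
# reader 1 stages [(5000, 8000), (8000, 10000)]
```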
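The coordination between reader processes and copy threads is a per-slot producer–consumer handshake: `semDataLoaded[i]` says slot `i` holds fresh data, and `semDataCopied[i]` says slot `i` may be overwritten. Here is a minimal, self-contained sketch of that handshake; the names and sizes are invented, and it assumes the `fork` start method (Linux default) so that the lock-free `mp.Array` and the semaphores are inherited by the child processes, as the patch also relies on:

```python
import multiprocessing as mp
import threading

import numpy as np

N_WORKERS, CHUNK, ROUNDS = 2, 4, 3  # made-up sizes

# One CHUNK-sized staging slot per worker, plus a semaphore pair per slot.
staging = mp.Array("f", N_WORKERS * CHUNK, lock=False)
sem_loaded = [mp.Semaphore(0) for _ in range(N_WORKERS)]
sem_copied = [mp.Semaphore(0) for _ in range(N_WORKERS)]

def reader(i: int) -> None:
    """Producer process: stage a chunk, then wait for it to be drained."""
    buf = np.frombuffer(staging, dtype=np.float32)
    for r in range(ROUNDS):
        buf[i * CHUNK : (i + 1) * CHUNK] = i * 100 + r  # stand-in for read_direct
        sem_loaded[i].release()  # chunk is staged
        sem_copied[i].acquire()  # block until the copier drained the slot

def copier(i: int, out: np.ndarray) -> None:
    """Consumer thread: drain slot i into its place in the destination array."""
    buf = np.frombuffer(staging, dtype=np.float32)
    for r in range(ROUNDS):
        sem_loaded[i].acquire()  # wait for a staged chunk
        dst = (r * N_WORKERS + i) * CHUNK
        out[dst : dst + CHUNK] = buf[i * CHUNK : (i + 1) * CHUNK]
        sem_copied[i].release()  # slot may be reused

if __name__ == "__main__":
    out = np.empty(N_WORKERS * CHUNK * ROUNDS, dtype=np.float32)
    procs = [mp.Process(target=reader, args=(i,)) for i in range(N_WORKERS)]
    copiers = [threading.Thread(target=copier, args=(i, out)) for i in range(N_WORKERS)]
    for w in procs + copiers:
        w.start()
    for w in procs + copiers:
        w.join()
    print(out.reshape(ROUNDS, N_WORKERS, CHUNK))
```

In the patch, the same acquire/release pair is routed through the `_waitload` and `_signalcopy` helpers.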
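Finally, the three arrays `fastload` fills (`data`, `indices`, `indptr`) are exactly the CSR triple: row `i` of `X` owns `data[indptr[i]:indptr[i+1]]`, with matching column indices. A tiny sketch with made-up values, using scipy's public constructor rather than the in-place attribute assignment the patch performs:

```python
import numpy as np
import scipy.sparse

data = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # nonzero values
indices = np.array([0, 2, 1], dtype=np.int64)       # column of each value
indptr = np.array([0, 2, 2, 3], dtype=np.int64)     # row i spans data[indptr[i]:indptr[i+1]]

X = scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 3))
print(X.toarray())
# [[1. 0. 2.]
#  [0. 0. 0.]
#  [0. 3. 0.]]
```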