Skip to content

Commit

Permalink
remove hnsw and optimize sparse matrix multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
YQ-Wang committed Jun 22, 2024
1 parent 4a1bf5a commit 5e55a17
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 116 deletions.
10 changes: 2 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,11 @@ To ensure scBSP functions optimally, the following dependencies are required:
- SciPy (>= 1.10.1)
- scikit-learn (>=1.3.2)

For enhanced scBSP using HNSW for distance calculation:
- hnswlib (>= 0.8.0)

### Installation Commands
For Standard Installation (Using Ball Tree):

`pip install "scbsp"`

For Installation with HNSW (Hierarchical Navigable Small World Graphs):

`pip install "scbsp[hnsw]"`

For Installation with GPU:

`pip install "scbsp[gpu]"`
Expand All @@ -46,7 +39,8 @@ Additional parameters to specify include:

- `d1`: A floating-point number. Default value is 1.0.
- `d2`: A floating-point number. Default value is 3.0.
- `leaf_size`: Optional integer defining the maximum point threshold for the Ball Tree algorithm to revert to brute-force search (default = 80). Not required for installations using HNSW.
- `leaf_size`: Optional integer defining the maximum point threshold for the Ball Tree algorithm to revert to brute-force search (default = 80).
- `use_gpu`: Optional boolean defining whether to use the GPU (default = False).


### Example
Expand Down
1 change: 0 additions & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ numpy>=1.24.4
pandas>=1.3.5
scipy>=1.10.1
scikit-learn>=1.3.2
hnswlib>=0.8.0
torch>=1.10.0

# typing
Expand Down
42 changes: 2 additions & 40 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@
# pip-compile --output-file=requirements.txt requirements.in
#
filelock==3.13.4
# via
# torch
# triton
# via torch
fsspec==2024.3.1
# via torch
hnswlib==0.8.0
# via -r requirements.in
jinja2==3.1.3
# via torch
joblib==1.3.2
Expand All @@ -31,41 +27,9 @@ networkx==3.2.1
numpy==1.26.2
# via
# -r requirements.in
# hnswlib
# pandas
# scikit-learn
# scipy
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
pandas==2.1.3
# via -r requirements.in
python-dateutil==2.8.2
Expand All @@ -86,10 +50,8 @@ threadpoolctl==3.4.0
# via scikit-learn
tomli==2.0.1
# via mypy
torch==2.3.0
torch==2.2.2
# via -r requirements.in
triton==2.3.0
# via torch
typing-extensions==4.8.0
# via
# mypy
Expand Down
105 changes: 40 additions & 65 deletions scbsp/scbsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,19 @@
@author: lijinp yiqingwang
This module utilizes a granularity-based dimension-agnostic tool, single-cell
big-small patch (scBSP), implementing sparse matrix operation and HNSW method
for distance calculation, for the identification of spatially variable genes
on large-scale data.
big-small patch (scBSP), implementing sparse matrix operation for distance
calculation, for the identification of spatially variable genes on
large-scale data.
"""

from typing import List, Tuple, Union
from typing import List, Union

import numpy as np
import pandas as pd # type: ignore
import scipy
import scipy # type: ignore
from scipy.sparse import csr_matrix, diags, identity, isspmatrix_csr # type: ignore
from scipy.stats import gmean, lognorm

try:
import torch
use_gpu = torch.cuda.is_available()
except ImportError:
use_gpu = False

try:
import hnswlib
use_hnsw = True
except ImportError:
from sklearn.neighbors import BallTree
use_hnsw = False
from scipy.stats import gmean, lognorm # type: ignore
from sklearn.neighbors import BallTree # type: ignore


def _scale_sparse_matrix(input_exp_mat: csr_matrix) -> csr_matrix:
Expand Down Expand Up @@ -78,22 +66,14 @@ def _binary_distance_matrix_threshold(
A csr_matrix representing the binary distance matrix.
"""

if use_hnsw is True:
labels, distances = _query_hnsw_index(input_sparse_mat_array)
within_d_val = distances <= d_val**2
rows = np.repeat(np.arange(labels.shape[0]), labels.shape[1])[
within_d_val.flatten()
]
cols = labels.flatten()[within_d_val.flatten()]
else:
ball_tree = BallTree(input_sparse_mat_array, leaf_size=leaf_size)
indices = ball_tree.query_radius(
input_sparse_mat_array, r=d_val, return_distance=False
)
rows = np.repeat(
np.arange(input_sparse_mat_array.shape[0]), [len(i) for i in indices]
)
cols = np.concatenate(indices)
ball_tree = BallTree(input_sparse_mat_array, leaf_size=leaf_size)
indices = ball_tree.query_radius(
input_sparse_mat_array, r=d_val, return_distance=False
)
rows = np.repeat(
np.arange(input_sparse_mat_array.shape[0]), [len(i) for i in indices]
)
cols = np.concatenate(indices)

# Construct binary csr_matrix
data = np.ones_like(rows)
Expand Down Expand Up @@ -126,34 +106,13 @@ def _calculate_sparse_variances(input_csr_mat: csr_matrix, axis: int) -> List[fl
return input_csr_mat_squared.mean(axis) - np.square(input_csr_mat.mean(axis))


def _query_hnsw_index(input_sp_mat: csr_matrix) -> Tuple[np.ndarray, np.ndarray]:
"""
Queries an HNSW index with a sparse matrix to find k-nearest neighbors.
Args:
input_sp_mat: The input sparse matrix.
Returns:
A tuple containing the labels and distances of the k-nearest neighbors.
"""

hnsw_index = hnswlib.Index(space="l2", dim=input_sp_mat.shape[1])
hnsw_index.init_index(max_elements=input_sp_mat.shape[0], ef_construction=200, M=16)
hnsw_index.add_items(input_sp_mat)
hnsw_index.set_ef(100) # ef should always be > k
labels, distances = hnsw_index.knn_query(
input_sp_mat, k=min(80, input_sp_mat.shape[0])
)

return labels, distances


def _get_test_scores(
input_sp_mat: np.ndarray,
input_exp_mat_raw: csr_matrix,
d1: float,
d2: float,
leaf_size: int,
use_gpu: bool,
) -> List[float]:
"""
Calculates test scores for genomic data based on input sparse matrices and distance thresholds.
Expand All @@ -164,6 +123,7 @@ def _get_test_scores(
d1: Distance threshold 1.
d2: Distance threshold 2.
leaf_size: An integer giving the number of points at which the Ball Tree algorithm switches to a brute-force search.
use_gpu: A boolean value that determines whether to use the GPU.
Returns:
A list of test scores.
Expand All @@ -182,6 +142,7 @@ def _var_local_means(
d_val: float,
input_exp_mat_norm: csr_matrix,
leaf_size: int,
use_gpu: bool
) -> List[float]:
patches_cells = _binary_distance_matrix_threshold(
input_sp_mat, d_val, leaf_size
Expand All @@ -197,25 +158,25 @@ def _var_local_means(

if use_gpu is True:
# Convert the csr_matrix to PyTorch tensors and move to GPU
input_exp_mat_norm_torch = torch.tensor(
input_exp_mat_norm_torch = torch.tensor( # type: ignore
input_exp_mat_norm.toarray(), device="cuda"
)
patches_cells_torch = torch.tensor(patches_cells.toarray(), device="cuda")
diag_matrix_sparse_torch = torch.tensor(
patches_cells_torch = torch.tensor(patches_cells.toarray(), device="cuda") # type: ignore
diag_matrix_sparse_torch = torch.tensor( # type: ignore
diag_matrix_sparse.toarray(), device="cuda"
)

result = torch.matmul(
result = torch.matmul( # type: ignore
input_exp_mat_norm_torch,
torch.matmul(patches_cells_torch, diag_matrix_sparse_torch),
torch.matmul(patches_cells_torch, diag_matrix_sparse_torch), # type: ignore
)
x_kj = scipy.sparse.csr_matrix(result.cpu().numpy())
else:
x_kj = input_exp_mat_norm.dot(patches_cells.dot(diag_matrix_sparse))
x_kj = input_exp_mat_norm @ (patches_cells @ diag_matrix_sparse)

return _calculate_sparse_variances(x_kj, axis=1)

var_x = np.column_stack([_var_local_means(input_sp_mat, d_val, input_exp_mat_norm, leaf_size).A.ravel() for d_val in (d1, d2)]) # type: ignore
var_x = np.column_stack([_var_local_means(input_sp_mat, d_val, input_exp_mat_norm, leaf_size, use_gpu).A.ravel() for d_val in (d1, d2)]) # type: ignore
var_x_0_add = _calculate_sparse_variances(input_exp_mat_raw, axis=1).A.ravel() # type: ignore
var_x_0_add /= max(var_x_0_add)
t_matrix = (var_x[:, 1] / var_x[:, 0]) * var_x_0_add
Expand All @@ -228,6 +189,7 @@ def granp(
d1: float = 1.0,
d2: float = 3.0,
leaf_size: int = 80,
use_gpu: bool = False
) -> pd.DataFrame:
"""
Calculates the p-values for genomic data.
Expand All @@ -238,10 +200,23 @@ def granp(
d1: Distance threshold 1.
d2: Distance threshold 2.
leaf_size: An integer giving the number of points at which the Ball Tree algorithm switches to a brute-force search.
use_gpu: A boolean value that determines whether to use the GPU.
Returns:
A Pandas DataFrame with columns ['gene_names', 'p_values'].
"""

# Check if GPU should be used and if it's available
if use_gpu is True:
try:
import torch # type: ignore
if not torch.cuda.is_available():
print("CUDA is not available, setting use_gpu to False.")
use_gpu = False
except ImportError:
print("Torch is not available, setting use_gpu to False.")
use_gpu = False

# Extract column names if input_exp_mat_raw is a Pandas DataFrame, else use indices
if isinstance(input_exp_mat_raw, pd.DataFrame):
gene_names = input_exp_mat_raw.columns.astype(str).tolist()
Expand All @@ -266,7 +241,7 @@ def granp(
d1 *= scale_factor
d2 *= scale_factor

t_matrix_sum = _get_test_scores(input_sp_mat, input_exp_mat_raw, d1, d2, leaf_size)
t_matrix_sum = _get_test_scores(input_sp_mat, input_exp_mat_raw, d1, d2, leaf_size, use_gpu)

# Calculate p-values
t_matrix_sum_upper90 = np.quantile(t_matrix_sum, 0.90)
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
],
install_requires=["numpy >= 1.24.4", "pandas >= 1.3.5", "scipy >= 1.10.1", "scikit-learn >= 1.3.2"],
extras_require={
"hnsw": ["hnswlib >= 0.8.0"],
"gpu": ["torch >= 1.10.0"],
},
python_requires=">=3.8",
Expand Down
3 changes: 2 additions & 1 deletion test/test_scbsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def test_non_empty_matrices(self):
d1 = 1.0
d2 = 3.0
leaf_size = 80
use_gpu = False

result = _get_test_scores(input_sp_mat, input_exp_mat_raw, d1, d2, leaf_size)
result = _get_test_scores(input_sp_mat, input_exp_mat_raw, d1, d2, leaf_size, use_gpu)

# Check if the result is a numpy.ndarray
self.assertIsInstance(result, list)
Expand Down

0 comments on commit 5e55a17

Please sign in to comment.