Skip to content

Commit

Permalink
[Lang] Support LU sparse solver on CUDA backend (taichi-dev#6967)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#2906 

### Brief Summary
To be consistent with API on CPU backend, this pr provides LU sparse
solver on CUDA backend. CuSolver just provides a CPU version API of LU
sparse solver which is used in this PR. The cuSolverRF provides a GPU
version LU solve, but it only supports `double` datatype. Thus, it's not
used in this PR.

Besides, the `print_triplets` is refactored to resolve the ndarray
`read` constraints (the `read` and `write` data should be the same
datatype).

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent e2da2ac commit fcb2edf
Show file tree
Hide file tree
Showing 10 changed files with 352 additions and 285 deletions.
6 changes: 5 additions & 1 deletion python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def _get_ndarray_addr(self):

def print_triplets(self):
"""Print the triplets stored in the builder"""
self.ptr.print_triplets()
taichi_arch = get_runtime().prog.config().arch
if taichi_arch == _ti_core.Arch.x64 or taichi_arch == _ti_core.Arch.arm64:
self.ptr.print_triplets_eigen()
elif taichi_arch == _ti_core.Arch.cuda:
self.ptr.print_triplets_cuda()

def build(self, dtype=f32, _format='CSR'):
"""Create a sparse matrix using the triplets"""
Expand Down
23 changes: 2 additions & 21 deletions python/taichi/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,32 +100,13 @@ def solve(self, b): # pylint: disable=R1710
return self.solver.solve(b)
if isinstance(b, Ndarray):
x = ScalarNdarray(b.dtype, [self.matrix.m])
self.solve_rf(self.matrix, b, x)
self.solver.solve_rf(get_runtime().prog, self.matrix.matrix, b.arr,
x.arr)
return x
raise TaichiRuntimeError(
f"The parameter type: {type(b)} is not supported in linear solvers for now."
)

def solve_cu(self, sparse_matrix, b):
if isinstance(sparse_matrix, SparseMatrix) and isinstance(b, Ndarray):
x = ScalarNdarray(b.dtype, [sparse_matrix.m])
self.solver.solve_cu(get_runtime().prog, sparse_matrix.matrix,
b.arr, x.arr)
return x
raise TaichiRuntimeError(
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
)

def solve_rf(self, sparse_matrix, b, x):
if isinstance(sparse_matrix, SparseMatrix) and isinstance(
b, Ndarray) and isinstance(x, Ndarray):
self.solver.solve_rf(get_runtime().prog, sparse_matrix.matrix,
b.arr, x.arr)
else:
raise TaichiRuntimeError(
f"The parameter type: {type(sparse_matrix)}, {type(b)} and {type(x)} is not supported in linear solvers for now."
)

def info(self):
"""Check if the linear systems are solved successfully.
Expand Down
48 changes: 42 additions & 6 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,53 @@ SparseMatrixBuilder::SparseMatrixBuilder(int rows,
prog_, dtype_, std::vector<int>{3 * (int)max_num_triplets_ + 1});
}

void SparseMatrixBuilder::print_triplets() {
num_triplets_ = ndarray_data_base_ptr_->read_int(std::vector<int>{0});
template <typename T, typename G>
void SparseMatrixBuilder::print_triplets_template() {
auto ptr = get_ndarray_data_ptr();
G *data = reinterpret_cast<G *>(ptr);
num_triplets_ = data[0];
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
num_triplets_, max_num_triplets_);
data += 1;
for (int i = 0; i < num_triplets_; i++) {
auto idx = 3 * i + 1;
auto row = ndarray_data_base_ptr_->read_int(std::vector<int>{idx});
auto col = ndarray_data_base_ptr_->read_int(std::vector<int>{idx + 1});
auto val = ndarray_data_base_ptr_->read_float(std::vector<int>{idx + 2});
fmt::print("[{}, {}] = {}\n", data[i * 3], data[i * 3 + 1],
taichi_union_cast<T>(data[i * 3 + 2]));
}
}

void SparseMatrixBuilder::print_triplets_eigen() {
auto element_size = data_type_size(dtype_);
switch (element_size) {
case 4:
print_triplets_template<float32, int32>();
break;
case 8:
print_triplets_template<float64, int64>();
break;
default:
TI_ERROR("Unsupported sparse matrix data type!");
break;
}
}

void SparseMatrixBuilder::print_triplets_cuda() {
#ifdef TI_WITH_CUDA
CUDADriver::get_instance().memcpy_device_to_host(
&num_triplets_, (void *)get_ndarray_data_ptr(), sizeof(int));
fmt::print("n={}, m={}, num_triplets={} (max={})\n", rows_, cols_,
num_triplets_, max_num_triplets_);
auto len = 3 * num_triplets_ + 1;
std::vector<float32> trips(len);
CUDADriver::get_instance().memcpy_device_to_host(
(void *)trips.data(), (void *)get_ndarray_data_ptr(),
len * sizeof(float32));
for (auto i = 0; i < num_triplets_; i++) {
int row = taichi_union_cast<int>(trips[3 * i + 1]);
int col = taichi_union_cast<int>(trips[3 * i + 2]);
auto val = trips[i * 3 + 3];
fmt::print("[{}, {}] = {}\n", row, col, val);
}
#endif
}

intptr_t SparseMatrixBuilder::get_ndarray_data_ptr() const {
Expand Down
6 changes: 5 additions & 1 deletion taichi/program/sparse_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class SparseMatrixBuilder {
const std::string &storage_format,
Program *prog);

void print_triplets();
void print_triplets_eigen();
void print_triplets_cuda();

intptr_t get_ndarray_data_ptr() const;

Expand All @@ -36,6 +37,9 @@ class SparseMatrixBuilder {
template <typename T, typename G>
void build_template(std::unique_ptr<SparseMatrix> &);

template <typename T, typename G>
void print_triplets_template();

private:
uint64 num_triplets_{0};
std::unique_ptr<Ndarray> ndarray_data_base_ptr_{nullptr};
Expand Down
Loading

0 comments on commit fcb2edf

Please sign in to comment.