diff --git a/examples/cusparse/CMakeLists.txt b/examples/cusparse/CMakeLists.txt
index 829f44b..436e67d 100644
--- a/examples/cusparse/CMakeLists.txt
+++ b/examples/cusparse/CMakeLists.txt
@@ -4,3 +4,4 @@ function(add_cuda_example example_name)
 endfunction()
 
 add_cuda_example(cusparse_simple_spmv)
+add_cuda_example(cusparse_simple_spmm)
diff --git a/examples/cusparse/cusparse_simple_spmm.cpp b/examples/cusparse/cusparse_simple_spmm.cpp
new file mode 100644
index 0000000..c785732
--- /dev/null
+++ b/examples/cusparse/cusparse_simple_spmm.cpp
@@ -0,0 +1,130 @@
+// Example: sparse-times-dense matrix product (SpMM) on the cuSPARSE backend.
+//
+//   Y = alpha * A * X
+//
+// A is a CSR matrix; X and Y are dense row-major matrices.
+
+// NOTE(review): the original include list was garbled in transit; this set is
+// reconstructed — confirm against cusparse_simple_spmv.cpp.
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+#include <spblas/spblas.hpp>
+
+#include "util.hpp"
+
+#include <fmt/core.h>
+#include <vector>
+
+int main(int argc, char** argv) {
+  namespace md = spblas::__mdspan;
+
+  using value_t = float;
+  using index_t = spblas::index_t;
+  using offset_t = spblas::offset_t;
+
+  spblas::index_t m = 100;
+  spblas::index_t n = 10;
+  spblas::index_t k = 100;
+  spblas::index_t nnz_in = 10;
+
+  fmt::print("\n\t###########################################################"
+             "######################");
+  fmt::print("\n\t### Running SpMM Example:");
+  fmt::print("\n\t###");
+  fmt::print("\n\t### Y = alpha * A * X");
+  fmt::print("\n\t###");
+  fmt::print("\n\t### with ");
+  fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, n,
+             nnz_in);
+  fmt::print("\n\t### X, a dense matrix, of size ({}, {})", n, k);
+  fmt::print("\n\t### Y, a dense matrix, of size ({}, {})", m, k);
+  fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
+             sizeof(spblas::index_t));
+  fmt::print("\n\t###########################################################"
+             "######################");
+  fmt::print("\n");
+
+  // NOTE(review): template arguments were stripped in transit; reconstructed
+  // to match the SpMV example's usage of generate_csr.
+  auto&& [values, rowptr, colind, shape, nnz] =
+      spblas::generate_csr<value_t, index_t>(m, n, nnz_in);
+
+  value_t* d_values;
+  offset_t* d_rowptr;
+  index_t* d_colind;
+
+  CUDA_CHECK(cudaMalloc(&d_values, values.size() * sizeof(value_t)));
+  CUDA_CHECK(cudaMalloc(&d_rowptr, rowptr.size() * sizeof(offset_t)));
+  CUDA_CHECK(cudaMalloc(&d_colind, colind.size() * sizeof(index_t)));
+
+  CUDA_CHECK(cudaMemcpy(d_values, values.data(),
+                        values.size() * sizeof(value_t), cudaMemcpyDefault));
+  CUDA_CHECK(cudaMemcpy(d_rowptr, rowptr.data(),
+                        rowptr.size() * sizeof(offset_t), cudaMemcpyDefault));
+  CUDA_CHECK(cudaMemcpy(d_colind, colind.data(),
+                        colind.size() * sizeof(index_t), cudaMemcpyDefault));
+
+  spblas::csr_view<value_t, index_t, offset_t> a(d_values, d_rowptr, d_colind,
+                                                 shape, nnz);
+
+  // Scale every value of `a` by 5 in place.
+  // scale(5.f, a);
+
+  std::vector<value_t> x(n * k, 1);
+  std::vector<value_t> y(m * k, 0);
+
+  value_t* d_x;
+  value_t* d_y;
+
+  CUDA_CHECK(cudaMalloc(&d_x, x.size() * sizeof(value_t)));
+  CUDA_CHECK(cudaMalloc(&d_y, y.size() * sizeof(value_t)));
+
+  CUDA_CHECK(cudaMemcpy(d_x, x.data(), x.size() * sizeof(value_t),
+                        cudaMemcpyDefault));
+  CUDA_CHECK(cudaMemcpy(d_y, y.data(), y.size() * sizeof(value_t),
+                        cudaMemcpyDefault));
+
+  md::mdspan x_span(d_x, n, k);
+  md::mdspan y_span(d_y, m, k);
+
+  // Y = A * X
+  spblas::operation_info_t info;
+  spblas::multiply(info, a, x_span, y_span);
+
+  CUDA_CHECK(cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t),
+                        cudaMemcpyDefault));
+
+  // CPU reference: row-by-row CSR * dense, accumulated in row-major order.
+  std::vector<value_t> y_ref(m * k, 0);
+  for (index_t i = 0; i < m; i++) {
+    for (offset_t j = rowptr[i]; j < rowptr[i + 1]; j++) {
+      index_t col = colind[j];
+      value_t val = values[j];
+      for (index_t l = 0; l < k; l++) {
+        y_ref[i * k + l] += val * x[col * k + l];
+      }
+    }
+  }
+
+  bool failed = false;
+
+  for (size_t i = 0; i < y.size(); ++i) {
+    if (y[i] != y_ref[i]) {
+      // Fixed: the original printed y_ref[i] for both the expected and the
+      // actual value, and used %ld for size_t (%zu is the portable form).
+      fprintf(stderr,
+              "Value mismatch at index %zu: y_ref[%zu] = %f, y[%zu] = %f\n",
+              i, i, y_ref[i], i, y[i]);
+      failed = true;
+    }
+  }
+
+  if (failed) {
+    fmt::print("\tValidation failed!\n");
+  } else {
+    fmt::print("\tValidation succeeded!\n");
+  }
+
+  fmt::print("\tExample is completed!\n");
+
+  return failed;
+}
diff --git a/include/spblas/vendor/cusparse/detail/create_matrix_handle.hpp b/include/spblas/vendor/cusparse/detail/create_matrix_handle.hpp
new file mode 100644
index 0000000..72290e6
--- /dev/null
+++ b/include/spblas/vendor/cusparse/detail/create_matrix_handle.hpp
@@ -0,0 +1,86 @@
+#pragma once
+
+// NOTE(review): include targets were stripped in transit; reconstructed.
+#include <cusparse.h>
+
+#include <spblas/detail/view_inspectors.hpp>
+
+#include <spblas/vendor/cusparse/types.hpp>
+
+namespace spblas {
+
+namespace __cusparse {
+
+// Create a cuSPARSE descriptor for a CSR view.  The caller owns the returned
+// descriptor and must release it with cusparseDestroySpMat.
+template <typename M>
+  requires __detail::is_csr_view_v<M>
+cusparseSpMatDescr_t create_matrix_handle(M&& m) {
+  cusparseSpMatDescr_t mat_descr;
+  __cusparse::throw_if_error(cusparseCreateCsr(
+      &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1],
+      m.values().size(), m.rowptr().data(), m.colind().data(),
+      m.values().data(),
+      detail::cusparse_index_type_v<tensor_offset_t<M>>,
+      detail::cusparse_index_type_v<tensor_index_t<M>>,
+      CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>));
+
+  return mat_descr;
+}
+
+// Create a cuSPARSE descriptor for a CSC view.
+// NOTE(review): this passes m.rowptr()/m.colind() as the CSC column-offset and
+// row-index arrays — confirm those accessor names against csc_view's API
+// (cusparseCreateCsc expects cscColOffsets / cscRowInd).
+template <typename M>
+  requires __detail::is_csc_view_v<M>
+cusparseSpMatDescr_t create_matrix_handle(M&& m) {
+  cusparseSpMatDescr_t mat_descr;
+  __cusparse::throw_if_error(cusparseCreateCsc(
+      &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1],
+      m.values().size(), m.rowptr().data(), m.colind().data(),
+      m.values().data(),
+      detail::cusparse_index_type_v<tensor_offset_t<M>>,
+      detail::cusparse_index_type_v<tensor_index_t<M>>,
+      CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>));
+
+  return mat_descr;
+}
+
+// Views that wrap another view (scaled, transposed, ...) delegate to their
+// base.
+template <typename M>
+  requires __detail::has_base<M>
+cusparseSpMatDescr_t create_matrix_handle(M&& m) {
+  return create_matrix_handle(m.base());
+}
+
+//
+// Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose
+// and returns the transpose value associated with it being represented
+// in the CSR format (since oneMKL SYCL currently does not have CSC
+// format)
+//
+// CSR = CSR + nontrans
+// CSR_transpose = CSR + trans
+// CSC = CSR + trans
+// CSC_transpose -> CSR + nontrans
+//
+// template <typename M>
+// oneapi::mkl::transpose get_transpose(M&& m) {
+//   static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>);
+
+//   const bool conjugate = __detail::is_conjugated(m);
+//   if constexpr (__detail::has_csr_base<M>) {
+//     if (conjugate) {
+//       throw std::runtime_error(
+//           "oneMKL SYCL backend does not support conjugation for CSR views.");
+//     }
+//     return oneapi::mkl::transpose::nontrans;
+//   } else if constexpr (__detail::has_csc_base<M>) {
+//     return conjugate ? oneapi::mkl::transpose::conjtrans
+//                      : oneapi::mkl::transpose::trans;
+//   }
+// }
+
+} // namespace __cusparse
+
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/detail/detail.hpp b/include/spblas/vendor/cusparse/detail/detail.hpp
new file mode 100644
index 0000000..da2fc72
--- /dev/null
+++ b/include/spblas/vendor/cusparse/detail/detail.hpp
@@ -0,0 +1,4 @@
+#pragma once
+
+#include "create_matrix_handle.hpp"
+#include "get_matrix_handle.hpp"
diff --git a/include/spblas/vendor/cusparse/detail/get_matrix_handle.hpp b/include/spblas/vendor/cusparse/detail/get_matrix_handle.hpp
new file mode 100644
index 0000000..cec7b22
--- /dev/null
+++ b/include/spblas/vendor/cusparse/detail/get_matrix_handle.hpp
@@ -0,0 +1,48 @@
+#pragma once
+
+// NOTE(review): include targets were stripped in transit; reconstructed.
+#include <cusparse.h>
+
+#include <spblas/detail/log.hpp>
+#include <spblas/detail/view_inspectors.hpp>
+#include <spblas/views/matrix_opt_impl.hpp>
+
+#include "create_matrix_handle.hpp"
+
+namespace spblas {
+
+namespace __cusparse {
+
+// Resolve the cuSPARSE descriptor for `m`:
+//  * matrix_opt:   lazily create and cache the descriptor on the wrapper
+//                  (the wrapper's destructor releases it);
+//  * wrapped view: recurse into the base;
+//  * cached:       return the descriptor already held by operation state;
+//  * otherwise:    create a fresh descriptor — the caller (or the operation
+//                  state it is handed to) must destroy it.
+template <typename M>
+cusparseSpMatDescr_t get_matrix_handle(M&& m,
+                                       cusparseSpMatDescr_t handle = nullptr) {
+  if constexpr (__detail::is_matrix_opt_v<std::remove_cvref_t<M>>) {
+    log_trace("using A as matrix_opt");
+
+    if (m.matrix_handle_ == nullptr) {
+      m.matrix_handle_ = create_matrix_handle(m.base());
+    }
+
+    return m.matrix_handle_;
+  } else if constexpr (__detail::has_base<M>) {
+    return get_matrix_handle(m.base(), handle);
+  } else if (handle != nullptr) {
+    log_trace("using A from operation_info_t");
+
+    return handle;
+  } else {
+    log_trace("using A as csr_base");
+
+    return create_matrix_handle(m);
+  }
+}
+
+} // namespace __cusparse
+
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/detail/spgemm_state_t.hpp b/include/spblas/vendor/cusparse/detail/spgemm_state_t.hpp
new file mode 100644
index 0000000..91864ad
--- /dev/null
+++ b/include/spblas/vendor/cusparse/detail/spgemm_state_t.hpp
@@ -0,0 +1,90 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+#include "abstract_operation_state.hpp"
+
+namespace spblas {
+namespace __cusparse {
+
+// Per-operation state for SpGEMM (C = A * B with A, B, C sparse).
+// Owns the cuSPARSE descriptors it is handed and destroys them when the
+// owning operation_info_t goes out of scope.
+class spgemm_state_t : public abstract_operation_state_t {
+public:
+  spgemm_state_t() = default;
+  ~spgemm_state_t() {
+    if (a_descr_) {
+      cusparseDestroySpMat(a_descr_);
+    }
+    if (b_descr_) {
+      cusparseDestroySpMat(b_descr_);
+    }
+    if (c_descr_) {
+      cusparseDestroySpMat(c_descr_);
+    }
+    if (spgemm_descr_) {
+      cusparseSpGEMM_destroyDescr(spgemm_descr_);
+    }
+    if (buffer_) {
+      cudaFree(buffer_);
+    }
+  }
+
+  // Accessors for the descriptors.
+  // Fixed: B and C are *sparse* matrices in SpGEMM, so these return
+  // cusparseSpMatDescr_t (the original declared cusparseDnVecDescr_t, which
+  // did not match the members below).
+  cusparseSpMatDescr_t a_descriptor() const {
+    return a_descr_;
+  }
+  cusparseSpMatDescr_t b_descriptor() const {
+    return b_descr_;
+  }
+  cusparseSpMatDescr_t c_descriptor() const {
+    return c_descr_;
+  }
+  cusparseSpGEMMDescr_t spgemm_descriptor() const {
+    return spgemm_descr_;
+  }
+  // Scratch row-pointer array used when C's own rowptr is too small;
+  // consumed (and freed) by multiply_fill.
+  void* c_rowptr() const {
+    return c_rowptr_;
+  }
+
+  // Setters for the descriptors.  Pass nullptr to leave ownership with the
+  // creator (e.g. a descriptor cached inside a matrix_opt).
+  void set_a_descriptor(cusparseSpMatDescr_t descr) {
+    a_descr_ = descr;
+  }
+  void set_b_descriptor(cusparseSpMatDescr_t descr) {
+    b_descr_ = descr;
+  }
+  void set_c_descriptor(cusparseSpMatDescr_t descr) {
+    c_descr_ = descr;
+  }
+  void set_spgemm_descriptor(cusparseSpGEMMDescr_t descr) {
+    spgemm_descr_ = descr;
+  }
+  void set_c_rowptr(void* ptr) {
+    c_rowptr_ = ptr;
+  }
+  // External SpGEMM work buffer; cuSPARSE requires it to stay alive until
+  // cusparseSpGEMM_copy has run, so the state owns and frees it.
+  void set_buffer(void* buffer) {
+    buffer_ = buffer;
+  }
+
+private:
+  cusparseSpMatDescr_t a_descr_ = nullptr;
+  cusparseSpMatDescr_t b_descr_ = nullptr;
+  cusparseSpMatDescr_t c_descr_ = nullptr;
+  cusparseSpGEMMDescr_t spgemm_descr_ = nullptr;
+  void* c_rowptr_ = nullptr; // not owned: may alias C's own rowptr array
+  void* buffer_ = nullptr;   // owned device allocation
+};
+
+} // namespace __cusparse
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/detail/spmm_state_t.hpp b/include/spblas/vendor/cusparse/detail/spmm_state_t.hpp
new file mode 100644
index 0000000..24af221
--- /dev/null
+++ b/include/spblas/vendor/cusparse/detail/spmm_state_t.hpp
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+#include "abstract_operation_state.hpp"
+
+namespace spblas {
+namespace __cusparse {
+
+// Per-operation state for SpMM (Y = A * X with A sparse, X/Y dense).
+// Owns the descriptors it is handed and destroys them on teardown.
+class spmm_state_t : public abstract_operation_state_t {
+public:
+  spmm_state_t() = default;
+  ~spmm_state_t() {
+    if (a_descr_) {
+      cusparseDestroySpMat(a_descr_);
+    }
+    if (x_descr_) {
+      cusparseDestroyDnMat(x_descr_);
+    }
+    if (y_descr_) {
+      cusparseDestroyDnMat(y_descr_);
+    }
+  }
+
+  // Accessors for the descriptors
+  cusparseSpMatDescr_t a_descriptor() const {
+    return a_descr_;
+  }
+  cusparseDnMatDescr_t x_descriptor() const {
+    return x_descr_;
+  }
+  cusparseDnMatDescr_t y_descriptor() const {
+    return y_descr_;
+  }
+
+  // Setters for the descriptors
+  void set_a_descriptor(cusparseSpMatDescr_t descr) {
+    a_descr_ = descr;
+  }
+  void set_x_descriptor(cusparseDnMatDescr_t descr) {
+    x_descr_ = descr;
+  }
+  void set_y_descriptor(cusparseDnMatDescr_t descr) {
+    y_descr_ = descr;
+  }
+
+private:
+  cusparseSpMatDescr_t a_descr_ = nullptr;
+  cusparseDnMatDescr_t x_descr_ = nullptr;
+  cusparseDnMatDescr_t y_descr_ = nullptr;
+};
+
+} // namespace __cusparse
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/multiply.hpp b/include/spblas/vendor/cusparse/multiply.hpp
index 4b68a98..1785448 100644
--- a/include/spblas/vendor/cusparse/multiply.hpp
+++ b/include/spblas/vendor/cusparse/multiply.hpp
@@ -1,3 +1,4 @@
 #pragma once
 
 #include "spmv_impl.hpp"
+#include "spmm_impl.hpp"
diff --git a/include/spblas/vendor/cusparse/spgemm_impl.hpp b/include/spblas/vendor/cusparse/spgemm_impl.hpp
new file mode 100644
index 0000000..dfaa51a
--- /dev/null
+++ b/include/spblas/vendor/cusparse/spgemm_impl.hpp
@@ -0,0 +1,240 @@
+#pragma once
+
+// NOTE(review): include targets were stripped in transit; reconstructed.
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+#include <stdexcept>
+
+#include <spblas/detail/log.hpp>
+#include <spblas/detail/operation_info_t.hpp>
+#include <spblas/detail/view_inspectors.hpp>
+
+#include <spblas/vendor/cusparse/detail/detail.hpp>
+#include <spblas/vendor/cusparse/detail/spgemm_state_t.hpp>
+#include <spblas/vendor/cusparse/operation_state_t.hpp>
+#include <spblas/vendor/cusparse/types.hpp>
+
+//
+// Defines the following APIs for SpGEMM:
+//
+// C = op(A) * op(B)
+//
+// where A, B and C are sparse matrices of CSR format
+//
+// operation_info_t multiply_compute(policy, A, B, C)
+// void multiply_fill(policy, operation_info_t, A, B, C)
+//
+
+namespace spblas {
+
+// Symbolic + numeric SpGEMM: computes the product and the output nnz.
+// Returns an operation_info_t whose state owns every cuSPARSE object needed
+// by the subsequent multiply_fill call.
+template <typename ExecutionPolicy, typename A, typename B, typename C>
+  requires(__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+          (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
+          __detail::is_csr_view_v<C>
+operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b,
+                                  C&& c) {
+  log_trace("");
+
+  if (__detail::is_conjugated(c)) {
+    throw std::runtime_error(
+        "cusparse backend does not support conjugated output matrices.");
+  }
+
+  // Fixed: this function *produces* the operation_info_t, so the state is
+  // created here rather than fetched from an undeclared `info`.
+  auto state = std::make_unique<__cusparse::spgemm_state_t>();
+
+  auto handle = state->handle();
+
+  // Create or get matrix descriptors.
+  auto a_handle = __cusparse::get_matrix_handle(a);
+  auto b_handle = __cusparse::get_matrix_handle(b);
+  auto c_handle = __cusparse::get_matrix_handle(c);
+
+  // Get operation type based on matrix format.
+  auto a_transpose = __cusparse::get_transpose(a);
+  auto b_transpose = __cusparse::get_transpose(b);
+
+  // Fixed: the scaling factor comes from (a, b); the original referenced an
+  // undeclared `x`.
+  auto alpha_optional = __detail::get_scaling_factor(a, b);
+  tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
+  tensor_scalar_t<A> beta = 0;
+
+  using T = tensor_scalar_t<A>;
+  using O = tensor_offset_t<A>;
+
+  // C's row-pointer array: reuse the caller's buffer when it is big enough,
+  // otherwise allocate a scratch array that multiply_fill copies back.
+  O* c_rowptr;
+  if (c.rowptr().size() >= __backend::shape(c)[0] + 1) {
+    c_rowptr = c.rowptr().data();
+  } else {
+    cudaMalloc(&c_rowptr, (__backend::shape(c)[0] + 1) * sizeof(O));
+  }
+
+  // Create SpGEMM descriptor.
+  cusparseSpGEMMDescr_t spgemm_descr;
+  __cusparse::throw_if_error(cusparseSpGEMM_createDescr(&spgemm_descr));
+
+  // Work estimation (get buffer size).
+  size_t bufferSize1 = 0;
+  __cusparse::throw_if_error(cusparseSpGEMM_workEstimation(
+      handle, a_transpose, b_transpose, &alpha, a_handle, b_handle, &beta,
+      c_handle, detail::cuda_data_type_v<T>, CUSPARSE_SPGEMM_DEFAULT,
+      spgemm_descr, &bufferSize1, nullptr));
+
+  void* buffer1 = nullptr;
+  if (bufferSize1 > 0) {
+    cudaMalloc(&buffer1, bufferSize1);
+  }
+
+  // Work estimation (execute).
+  __cusparse::throw_if_error(cusparseSpGEMM_workEstimation(
+      handle, a_transpose, b_transpose, &alpha, a_handle, b_handle, &beta,
+      c_handle, detail::cuda_data_type_v<T>, CUSPARSE_SPGEMM_DEFAULT,
+      spgemm_descr, &bufferSize1, buffer1));
+
+  // Compute (get buffer size).
+  size_t bufferSize2 = 0;
+  __cusparse::throw_if_error(cusparseSpGEMM_compute(
+      handle, a_transpose, b_transpose, &alpha, a_handle, b_handle, &beta,
+      c_handle, detail::cuda_data_type_v<T>, CUSPARSE_SPGEMM_DEFAULT,
+      spgemm_descr, &bufferSize2, nullptr));
+
+  void* buffer2 = nullptr;
+  if (bufferSize2 > 0) {
+    cudaMalloc(&buffer2, bufferSize2);
+  }
+
+  // Compute (execute).
+  __cusparse::throw_if_error(cusparseSpGEMM_compute(
+      handle, a_transpose, b_transpose, &alpha, a_handle, b_handle, &beta,
+      c_handle, detail::cuda_data_type_v<T>, CUSPARSE_SPGEMM_DEFAULT,
+      spgemm_descr, &bufferSize2, buffer2));
+
+  // Fixed: buffer1 is only needed during work estimation; buffer2 must stay
+  // alive until cusparseSpGEMM_copy runs in multiply_fill, so the state owns
+  // it (the original leaked both buffers).
+  cudaFree(buffer1);
+  state->set_buffer(buffer2);
+
+  // Get output nnz.
+  size_t C_rows, C_cols, C_nnz;
+  __cusparse::throw_if_error(
+      cusparseSpMatGetSize(c_handle, &C_rows, &C_cols, &C_nnz));
+
+  log_info("computed c_nnz = %ld", C_nnz);
+
+  // Hand every descriptor to the state so operation_info_t owns their
+  // lifetime; descriptors cached inside a matrix_opt remain owned by it.
+  state->set_a_descriptor(__detail::has_matrix_opt(a) ? nullptr : a_handle);
+  state->set_b_descriptor(__detail::has_matrix_opt(b) ? nullptr : b_handle);
+  state->set_c_descriptor(c_handle);
+  state->set_spgemm_descriptor(spgemm_descr);
+  state->set_c_rowptr(c_rowptr);
+
+  // Fixed: the original returned a __mkl::operation_state_t and referenced
+  // undeclared `nnz`, `descr` and `q` — oneMKL SYCL copy-paste leftovers.
+  return operation_info_t{
+      index<>{__backend::shape(c)[0], __backend::shape(c)[1]},
+      static_cast<index_t>(C_nnz),
+      __cusparse::operation_state_t(std::move(state))};
+}
+
+// Numeric fill: copies the product computed by multiply_compute into C's
+// (now correctly sized) colind/values arrays.
+template <typename ExecutionPolicy, typename A, typename B, typename C>
+  requires(__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+          (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
+          __detail::is_csr_view_v<C>
+void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
+                   B&& b, C&& c) {
+  log_trace("");
+
+  if (__detail::is_conjugated(c)) {
+    throw std::runtime_error(
+        "cusparse backend does not support conjugated output matrices.");
+  }
+
+  using T = tensor_scalar_t<A>;
+  using O = tensor_offset_t<A>;
+
+  auto state = info.state_.get_state<__cusparse::spgemm_state_t>();
+
+  auto handle = state->handle();
+
+  // Reuse the descriptors created by multiply_compute.
+  auto a_handle = __cusparse::get_matrix_handle(a, state->a_descriptor());
+  auto b_handle = __cusparse::get_matrix_handle(b, state->b_descriptor());
+  auto c_handle = state->c_descriptor();
+  auto spgemm_descr = state->spgemm_descriptor();
+
+  // Get operation type based on matrix format.
+  auto a_transpose = __cusparse::get_transpose(a);
+  auto b_transpose = __cusparse::get_transpose(b);
+
+  auto alpha_optional = __detail::get_scaling_factor(a, b);
+  T alpha = alpha_optional.value_or(1);
+  T beta = 0;
+
+  // Update C descriptor with the now-allocated colind and values.
+  O* c_rowptr = static_cast<O*>(state->c_rowptr());
+  __cusparse::throw_if_error(cusparseCsrSetPointers(
+      c_handle, c_rowptr, c.colind().data(), c.values().data()));
+
+  // Copy computed results into C's arrays.  alpha is applied here by
+  // cuSPARSE itself, so no extra scale() pass is needed (the original
+  // applied alpha twice: once in SpGEMM_copy and once via scale()).
+  __cusparse::throw_if_error(cusparseSpGEMM_copy(
+      handle, a_transpose, b_transpose, &alpha, a_handle, b_handle, &beta,
+      c_handle, detail::cuda_data_type_v<T>, CUSPARSE_SPGEMM_DEFAULT,
+      spgemm_descr));
+
+  if (c_rowptr != c.rowptr().data()) {
+    // Fixed: cudaMemcpy returns cudaError_t (there is no `.wait()`) and
+    // requires an explicit copy-kind argument; the scratch rowptr is freed
+    // once it has been copied back.
+    cudaMemcpy(c.rowptr().data(), c_rowptr,
+               sizeof(O) * (__backend::shape(c)[0] + 1),
+               cudaMemcpyDeviceToDevice);
+    cudaFree(c_rowptr);
+    state->set_c_rowptr(nullptr);
+  }
+}
+
+// CSC output: compute C^T = B^T * A^T as a CSR product.
+// Fixed: the originals dropped the execution policy and recursed into
+// themselves (wrong arity -> infinite recursion).
+template <typename ExecutionPolicy, typename A, typename B, typename C>
+  requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+           (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
+           __detail::is_csc_view_v<C>)
+operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b,
+                                  C&& c) {
+  return multiply_compute(std::forward<ExecutionPolicy>(policy),
+                          transposed(b), transposed(a), transposed(c));
+}
+
+template <typename ExecutionPolicy, typename A, typename B, typename C>
+  requires((__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+           (__detail::has_csr_base<B> || __detail::has_csc_base<B>) &&
+           __detail::is_csc_view_v<C>)
+void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
+                   B&& b, C&& c) {
+  multiply_fill(std::forward<ExecutionPolicy>(policy), info, transposed(b),
+                transposed(a), transposed(c));
+}
+
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/spmm_impl.hpp b/include/spblas/vendor/cusparse/spmm_impl.hpp
new file mode 100644
index 0000000..9734c12
--- /dev/null
+++ b/include/spblas/vendor/cusparse/spmm_impl.hpp
@@ -0,0 +1,136 @@
+#pragma once
+
+// NOTE(review): include targets were stripped in transit; reconstructed.
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+#include <stdexcept>
+#include <type_traits>
+
+#include <spblas/detail/log.hpp>
+#include <spblas/detail/operation_info_t.hpp>
+#include <spblas/detail/view_inspectors.hpp>
+
+#include <spblas/vendor/cusparse/detail/detail.hpp>
+#include <spblas/vendor/cusparse/detail/spmm_state_t.hpp>
+#include <spblas/vendor/cusparse/operation_state_t.hpp>
+#include <spblas/vendor/cusparse/types.hpp>
+
+//
+// Defines the following APIs for SpMM:
+//
+// Y = alpha * op(A) * X
+//
+// where A is a sparse matrix of CSR format and
+// X/Y are dense matrices of row_major format
+//
+// void multiply(operation_info_t, A, x, y)
+// void multiply(A, x, y)
+//
+
+namespace spblas {
+
+template <typename A, typename X, typename Y>
+  requires(
+      (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+      __detail::has_mdspan_matrix_base<X> && __detail::is_matrix_mdspan_v<Y> &&
+      // NOTE(review): layout constraints reconstructed — confirm they are
+      // applied to the ultimate base types, not the possibly-wrapped views.
+      std::is_same_v<typename std::remove_cvref_t<X>::layout_type,
+                     __mdspan::layout_right> &&
+      std::is_same_v<typename std::remove_cvref_t<Y>::layout_type,
+                     __mdspan::layout_right>)
+void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) {
+  log_trace("");
+
+  auto x_base = __detail::get_ultimate_base(x);
+  auto y_base = __detail::get_ultimate_base(y);
+
+  if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) {
+    throw std::runtime_error(
+        "cusparse backend does not support conjugated dense matrices.");
+  }
+
+  auto alpha_optional = __detail::get_scaling_factor(a, x);
+  tensor_scalar_t<A> alpha = alpha_optional.value_or(1);
+  tensor_scalar_t<A> beta = 0;
+
+  // Get or create state
+  auto state = info.state_.get_state<__cusparse::spmm_state_t>();
+  if (!state) {
+    info.state_ = __cusparse::operation_state_t(
+        std::make_unique<__cusparse::spmm_state_t>());
+    state = info.state_.get_state<__cusparse::spmm_state_t>();
+  }
+
+  // Reuse a cached sparse descriptor when available; a freshly created one
+  // is handed to the state so it is destroyed with the info (the original
+  // leaked one descriptor per call for plain CSR inputs).
+  auto a_handle = __cusparse::get_matrix_handle(a, state->a_descriptor());
+  if (!__detail::has_matrix_opt(a)) {
+    state->set_a_descriptor(a_handle);
+  }
+  auto a_transpose = __cusparse::get_transpose(a);
+
+  cusparseConstDnMatDescr_t x_handle;
+  cusparseDnMatDescr_t y_handle;
+
+  // Row-major dense matrices: the leading dimension is the column count.
+  __cusparse::throw_if_error(cusparseCreateConstDnMat(
+      &x_handle, x_base.extent(0), x_base.extent(1), x_base.extent(1),
+      x_base.data_handle(), detail::cuda_data_type_v<tensor_scalar_t<X>>,
+      CUSPARSE_ORDER_ROW));
+
+  __cusparse::throw_if_error(cusparseCreateDnMat(
+      &y_handle, y_base.extent(0), y_base.extent(1), y_base.extent(1),
+      y_base.data_handle(), detail::cuda_data_type_v<tensor_scalar_t<Y>>,
+      CUSPARSE_ORDER_ROW));
+
+  // Get buffer size
+  size_t buffer_size;
+  __cusparse::throw_if_error(cusparseSpMM_bufferSize(
+      state->handle(), a_transpose, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
+      a_handle, x_handle, &beta, y_handle,
+      detail::cuda_data_type_v<tensor_scalar_t<Y>>, CUSPARSE_SPMM_ALG_DEFAULT,
+      &buffer_size));
+
+  // Allocate buffer if needed
+  void* buffer = nullptr;
+  if (buffer_size > 0) {
+    cudaMalloc(&buffer, buffer_size);
+  }
+
+  // Execute SpMM
+  __cusparse::throw_if_error(cusparseSpMM(
+      state->handle(), a_transpose, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
+      a_handle, x_handle, &beta, y_handle,
+      detail::cuda_data_type_v<tensor_scalar_t<Y>>, CUSPARSE_SPMM_ALG_DEFAULT,
+      buffer));
+
+  // Fixed: destroy the per-call dense descriptors (the original leaked both)
+  // and free the work buffer.  NOTE(review): cudaFree is assumed to
+  // synchronize with the in-flight SpMM before releasing — confirm stream
+  // semantics if a non-default stream is ever used here.
+  cusparseDestroyDnMat(y_handle);
+  cusparseDestroyDnMat(x_handle);
+  if (buffer) {
+    cudaFree(buffer);
+  }
+}
+
+template <typename A, typename X, typename Y>
+  requires(
+      (__detail::has_csr_base<A> || __detail::has_csc_base<A>) &&
+      __detail::has_mdspan_matrix_base<X> && __detail::is_matrix_mdspan_v<Y> &&
+      std::is_same_v<typename std::remove_cvref_t<X>::layout_type,
+                     __mdspan::layout_right> &&
+      std::is_same_v<typename std::remove_cvref_t<Y>::layout_type,
+                     __mdspan::layout_right>)
+void multiply(A&& a, X&& x, Y&& y) {
+  operation_info_t info;
+  multiply(info, std::forward<A>(a), std::forward<X>(x), std::forward<Y>(y));
+}
+
+} // namespace spblas
diff --git a/include/spblas/vendor/cusparse/spmv_impl.hpp b/include/spblas/vendor/cusparse/spmv_impl.hpp
index bcbc68d..0e297ea 100644
--- a/include/spblas/vendor/cusparse/spmv_impl.hpp
+++ b/include/spblas/vendor/cusparse/spmv_impl.hpp
@@ -14,6 +14,8 @@
 #include <spblas/vendor/cusparse/operation_state_t.hpp>
 #include <spblas/vendor/cusparse/types.hpp>
 
+#include <spblas/vendor/cusparse/detail/detail.hpp>
+
 namespace spblas {
 
 template <typename A, typename X, typename Y>
@@ -48,16 +50,13 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) {
   }
 
   // Create or get matrix descriptor
-  if (!state->a_descriptor()) {
-    cusparseSpMatDescr_t a_descr = __cusparse::create_cusparse_handle(a_base);
-    state->set_a_descriptor(a_descr);
-  }
+  auto a_handle = __cusparse::get_matrix_handle(a, state->a_descriptor());
+  if (!__detail::has_matrix_opt(a)) {
+    // A freshly created descriptor is owned (and destroyed) by the state.
+    state->set_a_descriptor(a_handle);
+  }
 
   // Create vector descriptors
-  cusparseDnVecDescr_t b_descr = __cusparse::create_cusparse_handle(x_base);
-  cusparseDnVecDescr_t c_descr = __cusparse::create_cusparse_handle(y);
-  state->set_b_descriptor(b_descr);
-  state->set_c_descriptor(c_descr);
+  auto b_handle = __cusparse::create_cusparse_handle(x_base);
+  auto c_handle = __cusparse::create_cusparse_handle(y);
 
   // Get operation type based on matrix format
   auto a_transpose = __cusparse::get_transpose(a);
@@ -65,8 +64,8 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) {
   // Get buffer size
   size_t buffer_size;
   __cusparse::throw_if_error(cusparseSpMV_bufferSize(
-      state->handle(), a_transpose, &alpha, state->a_descriptor(),
-      state->b_descriptor(), &beta, state->c_descriptor(),
+      state->handle(), a_transpose, &alpha, a_handle,
+      b_handle, &beta, c_handle,
       detail::cuda_data_type_v<tensor_scalar_t<A>>, CUSPARSE_SPMV_ALG_DEFAULT,
       &buffer_size));
 
@@ -78,10 +77,15 @@ void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) {
 
   // Execute SpMV
   __cusparse::throw_if_error(
-      cusparseSpMV(state->handle(), a_transpose, &alpha, state->a_descriptor(),
-                   state->b_descriptor(), &beta, state->c_descriptor(),
+      cusparseSpMV(state->handle(), a_transpose, &alpha, a_handle,
+                   b_handle, &beta, c_handle,
                    detail::cuda_data_type_v<tensor_scalar_t<A>>,
                    CUSPARSE_SPMV_ALG_DEFAULT, buffer));
 
+  // Fixed: the dense-vector descriptors are per-call objects and were leaked
+  // once they stopped being stored in the state.
+  cusparseDestroyDnVec(c_handle);
+  cusparseDestroyDnVec(b_handle);
+
diff --git a/include/spblas/views/matrix_opt_impl.hpp b/include/spblas/views/matrix_opt_impl.hpp
index 2bc9dc8..1be6696 100644
--- a/include/spblas/views/matrix_opt_impl.hpp
+++ b/include/spblas/views/matrix_opt_impl.hpp
@@ -7,6 +7,8 @@
 #ifdef SPBLAS_ENABLE_ONEMKL_SYCL
 #include <oneapi/mkl.hpp>
 #include <sycl/sycl.hpp>
+#elif defined(SPBLAS_ENABLE_CUSPARSE)
+#include <cusparse.h>
 #endif
 
 namespace spblas {
@@ -24,6 +26,8 @@ class matrix_opt : public view_base {
   matrix_opt(M matrix) : matrix_(matrix) {
 #ifdef SPBLAS_ENABLE_ONEMKL_SYCL
     matrix_handle_ = nullptr;
+#elif defined(SPBLAS_ENABLE_CUSPARSE)
+    matrix_handle_ = nullptr;
 #endif
   }
 
@@ -36,6 +40,11 @@ class matrix_opt : public view_base {
       oneapi::mkl::sparse::release_matrix_handle(q, &matrix_handle_, {}).wait();
       matrix_handle_ = nullptr;
     }
+#elif defined(SPBLAS_ENABLE_CUSPARSE)
+    if (matrix_handle_) {
+      cusparseDestroySpMat(matrix_handle_);
+      matrix_handle_ = nullptr;
+    }
 #endif
   }
 
@@ -89,6 +98,8 @@ class matrix_opt : public view_base {
 
 #ifdef SPBLAS_ENABLE_ONEMKL_SYCL
   oneapi::mkl::sparse::matrix_handle_t matrix_handle_;
+#elif defined(SPBLAS_ENABLE_CUSPARSE)
+  cusparseSpMatDescr_t matrix_handle_;
 #endif
 };