diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index fcf3a82..ead25c4 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -11,6 +11,11 @@ if (SPBLAS_CPU_BACKEND) add_example(simple_sptrsv) add_example(spmm_csc) add_example(matrix_opt_example) + if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND ) + # needs CPU + matrix_opt + operation_info_t to run + add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run + add_example(spmm_csr) # needs multiply{_inspect} to run + endif() endif() # GPU examples diff --git a/examples/spmm_csr.cpp b/examples/spmm_csr.cpp new file mode 100644 index 0000000..57ceeb9 --- /dev/null +++ b/examples/spmm_csr.cpp @@ -0,0 +1,54 @@ +#include + +#include +#include + +int main(int argc, char** argv) { + using namespace spblas; + namespace md = spblas::__mdspan; + + using T = float; + + spblas::index_t m = 10; + spblas::index_t n = 10; + spblas::index_t k = 10; + spblas::index_t nnz_in = 20; + + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n\t### Running Advanced SpMM Example:"); + fmt::print("\n\t###"); + fmt::print("\n\t### Y = alpha * A * X"); + fmt::print("\n\t###"); + fmt::print("\n\t### with "); + fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k, + nnz_in); + fmt::print("\n\t### x, a dense matrix, of size ({}, {})", k, n); + fmt::print("\n\t### y, a dense vector, of size ({}, {})", m, n); + fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", + sizeof(spblas::index_t)); + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n"); + + auto&& [values, rowptr, colind, shape, nnz] = generate_csr(m, k, nnz_in); + + csr_view a(values, rowptr, colind, shape, nnz); + matrix_opt a_opt(a); + + std::vector x_values(k * n, 1); + std::vector y_values(m * n, 0); + + md::mdspan x(x_values.data(), k, n); + md::mdspan y(y_values.data(), m, n); + + // Y = A * X + auto state = multiply_inspect(a_opt, x, y); + multiply(state, a_opt, x, y); + + fmt::print("{}\n", spblas::__backend::values(y)); + + fmt::print("\tExample is completed!\n"); + + return 0; +} diff --git a/examples/sptrsv_csr.cpp b/examples/sptrsv_csr.cpp new file mode 100644 index 0000000..ab74c92 --- /dev/null +++ b/examples/sptrsv_csr.cpp @@ -0,0 +1,63 @@ +#include + +#include +#include + +int main(int argc, char** argv) { + using namespace spblas; + + using T = float; + + spblas::index_t m = 100; + spblas::index_t nnz_in = 20; + + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n\t### Running Full SpTRSV Example:"); + fmt::print("\n\t###"); + fmt::print("\n\t### solve for x: A * x = alpha * b"); + fmt::print("\n\t###"); + fmt::print("\n\t### with "); + fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, m, + nnz_in); + fmt::print("\n\t### x, a dense vector, of size ({}, {})", m, 1); + fmt::print("\n\t### b, a dense vector, of size ({}, {})", m, 1); + fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", + sizeof(spblas::index_t)); + fmt::print("\n\t###########################################################" + "######################"); + fmt::print("\n"); + + auto&& [values, rowptr, colind, shape, nnz] = + generate_csr(m, m, nnz_in); + + // scale values of matrix to make the implicit unit diagonal matrix + // be diagonally dominant, so it is solveable + T scale_factor = 1e-3f; + std::transform(values.begin(), values.end(), values.begin(), + [scale_factor](T val) { return scale_factor * val; }); + + csr_view a(values, rowptr, colind, shape, nnz); + + matrix_opt a_opt(a); + + // Scale every value of `a` by 5 in place. + // scale(5.f, a); + + std::vector x(m, 0); + std::vector b(m, 1); + + T alpha = 1.2f; + auto b_scaled = scaled(alpha, b); + + // solve for x: lower(A) * x = alpha * b + triangular_solve_inspect(a_opt, spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}, b_scaled, x); + + triangular_solve(a_opt, spblas::upper_triangle_t{}, + spblas::implicit_unit_diagonal_t{}, b_scaled, x); + + fmt::print("\tExample is completed!\n"); + + return 0; +} diff --git a/include/spblas/algorithms/multiply.hpp b/include/spblas/algorithms/multiply.hpp index f15748e..be4b255 100644 --- a/include/spblas/algorithms/multiply.hpp +++ b/include/spblas/algorithms/multiply.hpp @@ -5,18 +5,46 @@ namespace spblas { +// SpMV variants +template +operation_info_t multiply_inspect(A&& a, B&& b, C&& c); + +template +void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c); + template void multiply(A&& a, B&& b, C&& c); +template +void multiply(operation_into_t& info, A&& a, B&& b, C&& c); + +// SpMM variants template void multiply(A&& a, B&& b, C&& c); +template +void multiply(operation_info_t& info, A&& a, B&& b, C&& c); + +// SpMM and SpGEMM multiply_inspect variants template operation_info_t multiply_inspect(A&& a, B&& b, C&& c); template void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c); +// SpGEMM variants +template +operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b, + C&& c); + +template +void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c); + +template +void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c); + template operation_info_t multiply_compute(A&& a, B&& b, C&& c); diff --git a/include/spblas/algorithms/multiply_impl.hpp b/include/spblas/algorithms/multiply_impl.hpp index d56da6b..856c97b 100644 --- a/include/spblas/algorithms/multiply_impl.hpp +++ b/include/spblas/algorithms/multiply_impl.hpp @@ -15,6 +15,19 @@ namespace spblas { +// SpMV inspect +template +operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { + log_trace(""); + return operation_info_t{}; +} + +// SpMV inspect +template +operation_info_t multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); +} + // C = AB // SpMV template @@ -39,6 +52,15 @@ void multiply(A&& a, B&& b, C&& c) { }); } +// C = AB +// SpMV with info input +template + requires(__backend::lookupable && __backend::lookupable) +void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); + multiply(std::forward(a), std::forward(b), std::forward(c)); +} + // C = AB // SpMM template @@ -52,37 +74,61 @@ void multiply(A&& a, B&& b, C&& c) { "multiply: matrix dimensions are incompatible."); } + // initializes c to zero so we can use += everywhere __backend::for_each(c, [](auto&& e) { auto&& [_, v] = e; v = 0; }); + // traverses elements of a and performs appropriate + // multiplication with B rows __backend::for_each(a, [&](auto&& e) { auto&& [idx, a_v] = e; auto&& [i, k] = idx; - for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { + for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row __backend::lookup(c, i, j) += a_v * __backend::lookup(b, k, j); } }); } +// C = AB +// SpMM with info +template + requires(__backend::lookupable && __backend::lookupable) +void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); + multiply(std::forward(a), std::forward(b), std::forward(c)); +} + +// C = AB +// SpMM or SpGEMM multiply_inspect variants end up here template operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { + log_trace(""); return operation_info_t{}; } +// C = AB +// SpMM or SpGEMM multiply_inspect variants end up here template -void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){}; +void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); +}; +// C = AB +// SpGEMM compute stage with CSR output template requires(__backend::row_iterable && __backend::row_iterable && __detail::is_csr_view_v) void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); auto new_info = multiply_compute(std::forward(a), std::forward(b), std::forward(c)); info.update_impl_(new_info.result_shape(), new_info.result_nnz()); } +// C = AB +// SpGEMM compute stage with CSC output template requires(__backend::column_iterable && __backend::column_iterable && __detail::is_csc_view_v) @@ -93,6 +139,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { } // C = AB +// SpGEMM fill stage with CSR or CSC output template void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) { log_trace(""); diff --git a/include/spblas/algorithms/triangular_solve.hpp b/include/spblas/algorithms/triangular_solve.hpp index 5bf1d88..88821df 100644 --- a/include/spblas/algorithms/triangular_solve.hpp +++ b/include/spblas/algorithms/triangular_solve.hpp @@ -3,13 +3,16 @@ #include #include -template -void triangular_matrix_vector_solve(ExecutionPolicy&& exec, InMat A, Triangle t, - DiagonalStorage d, InVec b, OutVec x); - namespace spblas { +template +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x); + +template +operation_info_t triangular_solve_inspect(A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x); + template void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x); diff --git a/include/spblas/algorithms/triangular_solve_impl.hpp b/include/spblas/algorithms/triangular_solve_impl.hpp index 52be891..4b4f0f3 100644 --- a/include/spblas/algorithms/triangular_solve_impl.hpp +++ b/include/spblas/algorithms/triangular_solve_impl.hpp @@ -8,10 +8,41 @@ namespace spblas { +// X = inv(A) B +// SpTRSV inspect stage +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d, + B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + assert(__backend::shape(a)[0] == __backend::shape(a)[1]); + + return operation_info_t{}; +} + +// X = inv(A) B +// SpTRSV inspect stage +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + assert(__backend::shape(a)[0] == __backend::shape(a)[1]); +} + +// X = inv(A) B +// SpTRSV solve stage template requires(__backend::row_iterable && __backend::lookupable && __backend::lookupable) void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { + log_trace(""); static_assert(std::is_same_v || std::is_same_v); assert(__backend::shape(a)[0] == __backend::shape(a)[1]); @@ -62,4 +93,17 @@ void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) { } } +// X = inv(A) B +// SpTRSV solve stage with info +template + requires(__backend::row_iterable && __backend::lookupable && + __backend::lookupable) +void triangular_solve(operation_info_t& info, A&& a, Triangle t, + DiagonalStorage d, B&& b, X&& x) { + log_trace(""); + triangular_solve(std::forward(a), std::forward(t), + std::forward(d), std::forward(b), + std::forward(x)); +} + } // namespace spblas diff --git a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp index 2413766..1e72b90 100644 --- a/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp +++ b/include/spblas/vendor/onemkl_sycl/detail/create_matrix_handle.hpp @@ -18,8 +18,14 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::init_matrix_handle(&handle); oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[0], m.shape()[1], oneapi::mkl::index_base::zero, - m.rowptr().data(), m.colind().data(), m.values().data()) + q, handle, m.shape()[0], m.shape()[1], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) + m.size(), // nnz added in 2025.3, and without deprecated +#endif + oneapi::mkl::index_base::zero, m.rowptr().data(), m.colind().data(), + m.values().data()) .wait(); return handle; @@ -33,8 +39,14 @@ oneapi::mkl::sparse::matrix_handle_t create_matrix_handle(sycl::queue& q, oneapi::mkl::sparse::init_matrix_handle(&handle); oneapi::mkl::sparse::set_csr_data( - q, handle, m.shape()[1], m.shape()[0], oneapi::mkl::index_base::zero, - m.colptr().data(), m.rowind().data(), m.values().data()) + q, handle, m.shape()[1], m.shape()[0], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) + m.size(), // nnz added in 2025.3, and without deprecated +#endif + oneapi::mkl::index_base::zero, m.colptr().data(), m.rowind().data(), + m.values().data()) .wait(); return handle; diff --git a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp index 4ee63c9..11e265a 100644 --- a/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spgemm_impl.hpp @@ -29,6 +29,9 @@ namespace spblas { +// +// multiply_compute -- csr/csc * csr/csc -> csr with ExecutionPolicy +// template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && @@ -68,6 +71,11 @@ operation_info_t oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) + __backend::size(c), // nnz added in 2025.3, and without deprecated +#endif oneapi::mkl::index_base::zero, c_rowptr, (I*) nullptr, (T*) nullptr) .wait(); @@ -117,8 +125,37 @@ operation_info_t __mkl::operation_state_t{__detail::has_matrix_opt(a) ? nullptr : a_handle, __detail::has_matrix_opt(b) ? nullptr : b_handle, c_handle, nullptr, descr, (void*) c_rowptr, q}}; -} +} // multiply_compute + +// +// multiply_compute -- csr/csc * csr/csc -> csr with ExecutionPolicy +// +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + B&& b, C&& c) { + log_trace(""); + + auto tmp_info = multiply_compute(std::forward(policy), + std::forward(a), std::forward(b), + std::forward(c)); + // fill the normal bucket of state stuf based on creating model for now. + info.update_impl_(tmp_info.result_shape(), tmp_info.result_nnz()); + info.state_.a_handle = tmp_info.state_.a_handle; + info.state_.b_handle = tmp_info.state_.b_handle; + info.state_.c_handle = tmp_info.state_.c_handle; + info.state_.descr = tmp_info.state_.descr; + info.state_.c_rowptr = tmp_info.state_.c_rowptr; + info.state_.q = tmp_info.state_.q; + +} // multiply_compute + +// +// multiply_fill -- csr/csc * csr/csc -> csr with ExecutionPolicy +// template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && @@ -155,6 +192,11 @@ void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a, auto ev_setC = oneapi::mkl::sparse::set_csr_data( q, c_handle, __backend::shape(c)[0], __backend::shape(c)[1], +#if defined(__INTEL_MKL__) && \ + ((__INTEL_MKL__ == 2025) && (__INTEL_MKL_MINOR__ == 3) || \ + (__INTEL_MKL__ > 2025)) + __backend::size(c), // nnz added in 2025.3, and without deprecated +#endif oneapi::mkl::index_base::zero, c_rowptr, c.colind().data(), c.values().data()); @@ -186,6 +228,17 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { std::forward(c)); } +template + requires(__detail::has_csr_base || __detail::has_csc_base) && + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::is_csr_view_v +void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { + log_trace(""); + return multiply_compute(mkl::par, std::forward(info), + std::forward(a), std::forward(b), + std::forward(c)); +} + template requires(__detail::has_csr_base || __detail::has_csc_base) && (__detail::has_csr_base || __detail::has_csc_base) && diff --git a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp index 82d1415..441ce28 100644 --- a/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmm_impl.hpp @@ -37,9 +37,65 @@ template __mdspan::layout_right> && std::is_same_v::layout_type, __mdspan::layout_right>) -void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { +void multiply_inspect(ExecutionPolicy&& policy, operation_info_t& info, A&& a, + X&& x, Y&& y) { + log_trace(""); + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense matrices."); + } + + if (__detail::has_matrix_opt(a)) { + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_transpose = __mkl::get_transpose(a); + + auto x_base = __detail::get_ultimate_base(x); + + oneapi::mkl::sparse::optimize_gemm( + q, oneapi::mkl::layout::row_major, a_transpose, + oneapi::mkl::transpose::nontrans, a_handle, + static_cast(x_base.extent(1))) + .wait(); + } else { + // do nothing, since it would be immediately discarded + log_info( + "No work done, since no matrix_opt to store optimized results into!"); + } +} // multiply_inspect + +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +operation_info_t multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, + Y&& y) { + log_trace(""); + operation_info_t info{}; + + multiply_inspect(std::forward(policy), info, + std::forward(a), std::forward(x), std::forward(y)); + + return info; +} + +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply(ExecutionPolicy&& policy, operation_info_t& info, A&& a, X&& x, + Y&& y) { log_trace(""); - auto x_base = __detail::get_ultimate_base(x); if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { throw std::runtime_error( @@ -55,6 +111,8 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { auto a_handle = __mkl::get_matrix_handle(q, a); auto a_transpose = __mkl::get_transpose(a); + auto x_base = __detail::get_ultimate_base(x); + oneapi::mkl::sparse::gemm(q, oneapi::mkl::layout::row_major, a_transpose, oneapi::mkl::transpose::nontrans, alpha, a_handle, x_base.data_handle(), x_base.extent(1), @@ -66,6 +124,64 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { } } +// +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution +// policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +operation_info_t multiply_inspect(A&& a, X&& x, Y&& y) { + log_trace(""); + auto info = multiply_inspect(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); + return info; +} + +// +// multiply_inspect - CSR/CSC with row major dense matrix rhs without execution +// policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply_inspect(operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); + multiply_inspect(mkl::par, info, std::forward(a), std::forward(x), + std::forward(y)); +} + +// +// multiply - CSR/CSC with row major dense matrix rhs without execution policy +// +template + requires( + (__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_mdspan_matrix_base && __detail::is_matrix_mdspan_v && + std::is_same_v::layout_type, + __mdspan::layout_right> && + std::is_same_v::layout_type, + __mdspan::layout_right>) +void multiply(operation_info_t& info, A&& a, X&& x, Y&& y) { + log_trace(""); + multiply(mkl::par, info, std::forward(a), std::forward(x), + std::forward(y)); +} + +// +// multiply - CSR/CSC with row major dense matrix rhs without execution policy +// or state object +// template requires( (__detail::has_csr_base || __detail::has_csc_base) && @@ -75,7 +191,9 @@ template std::is_same_v::layout_type, __mdspan::layout_right>) void multiply(A&& a, X&& x, Y&& y) { - multiply(mkl::par, std::forward(a), std::forward(x), + log_trace(""); + operation_info_t info{}; + multiply(mkl::par, info, std::forward(a), std::forward(x), std::forward(y)); } diff --git a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp index c6b73c1..6377397 100644 --- a/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/spmv_impl.hpp @@ -28,6 +28,40 @@ namespace spblas { +// +// multiply_inspect with CSR/CSC and single rhs +// +template + requires((__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_contiguous_range_base && + __ranges::contiguous_range) +void multiply_inspect(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { + log_trace(""); + + if (__detail::is_conjugated(x) || __detail::is_conjugated(y)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } + + if (__detail::has_matrix_opt(a)) { + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_transpose = __mkl::get_transpose(a); + + oneapi::mkl::sparse::optimize_gemv(q, a_transpose, a_handle).wait(); + } else { + // do nothing, since it would be trashed immediately after + log_info( + "No work done, since no matrix_opt to store optimized results into!"); + } + +} // multiply_inspect + +// +// multiply with CSR/CSC and single rhs +// template requires((__detail::has_csr_base || __detail::has_csc_base) && __detail::has_contiguous_range_base && @@ -45,7 +79,6 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { tensor_scalar_t alpha = alpha_optional.value_or(1); auto a_data = __detail::get_ultimate_base(a).values().data(); - auto&& q = __mkl::get_queue(policy, a_data); auto a_handle = __mkl::get_matrix_handle(q, a); @@ -60,6 +93,23 @@ void multiply(ExecutionPolicy&& policy, A&& a, X&& x, Y&& y) { } } +// +// multiply_inspect -- CSR/CSC + single rhs vector +// with no ExecutionPolicy +// +template + requires((__detail::has_csr_base || __detail::has_csc_base) && + __detail::has_contiguous_range_base && + __ranges::contiguous_range) +void multiply_inspect(A&& a, X&& x, Y&& y) { + log_trace(""); + multiply_inspect(mkl::par, std::forward(a), std::forward(x), + std::forward(y)); +} + +// +// multiply -- CSR/CSC + single rhs vector +// template requires((__detail::has_csr_base || __detail::has_csc_base) && __detail::has_contiguous_range_base && diff --git a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp index 4d9bd05..295ff0b 100644 --- a/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp +++ b/include/spblas/vendor/onemkl_sycl/triangular_solve_impl.hpp @@ -26,40 +26,36 @@ namespace spblas { // lower + conjtrans (D+U)^H -> conjtrans + upper (D+U)^H // -template +// +// CSR triangular solve inspection step +// +template requires __detail::has_csr_base && __detail::has_contiguous_range_base && __ranges::contiguous_range -void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, - X&& x) { +void triangular_solve_inspect(ExecutionPolicy&& policy, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { log_trace(""); static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v || std::is_same_v); - auto a_base = __detail::get_ultimate_base(a); - auto b_base = __detail::get_ultimate_base(b); + if (__detail::is_conjugated(b) || __detail::is_conjugated(x)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } using T = tensor_scalar_t; using I = tensor_index_t; using O = tensor_offset_t; - auto alpha_optional = __detail::get_scaling_factor(a, b); - T alpha = alpha_optional.value_or(1); - - sycl::queue q(sycl::cpu_selector_v); + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); - oneapi::mkl::sparse::matrix_handle_t a_handle = nullptr; - oneapi::mkl::sparse::init_matrix_handle(&a_handle); - - oneapi::mkl::sparse::set_csr_data( - q, a_handle, __backend::shape(a_base)[0], __backend::shape(a_base)[1], - oneapi::mkl::index_base::zero, a_base.rowptr().data(), - a_base.colind().data(), a_base.values().data()) - .wait(); - - auto op = oneapi::mkl::transpose::nontrans; + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_op = __mkl::get_transpose(a); auto uplo_val = std::is_same_v @@ -70,12 +66,95 @@ void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, ? oneapi::mkl::diag::nonunit : oneapi::mkl::diag::unit; - oneapi::mkl::sparse::trsv(q, uplo_val, op, diag_val, alpha, a_handle, + oneapi::mkl::sparse::optimize_trsv(q, uplo_val, a_op, diag_val, a_handle) + .wait(); + + if (!__detail::has_matrix_opt(a)) { + oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + } +} + +// +// CSR triangular solve execution step +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(ExecutionPolicy&& policy, A&& a, Triangle uplo, + DiagonalStorage diag, B&& b, X&& x) { + log_trace(""); + static_assert(std::is_same_v || + std::is_same_v); + static_assert(std::is_same_v || + std::is_same_v); + + if (__detail::is_conjugated(b) || __detail::is_conjugated(x)) { + throw std::runtime_error( + "oneMKL SYCL backend does not support conjugated dense vectors."); + } + + using T = tensor_scalar_t; + using I = tensor_index_t; + using O = tensor_offset_t; + + auto alpha_optional = __detail::get_scaling_factor(a, b); + tensor_scalar_t alpha = alpha_optional.value_or(1); + + auto a_data = __detail::get_ultimate_base(a).values().data(); + auto&& q = __mkl::get_queue(policy, a_data); + + auto a_handle = __mkl::get_matrix_handle(q, a); + auto a_op = __mkl::get_transpose(a); + + auto uplo_val = std::is_same_v + ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + auto diag_val = std::is_same_v + ? oneapi::mkl::diag::nonunit + : oneapi::mkl::diag::unit; + + auto b_base = __detail::get_ultimate_base(b); + + oneapi::mkl::sparse::trsv(q, uplo_val, a_op, diag_val, alpha, a_handle, __ranges::data(b_base), __ranges::data(x)) .wait(); - oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + if (!__detail::has_matrix_opt(a)) { + oneapi::mkl::sparse::release_matrix_handle(q, &a_handle).wait(); + } } // triangular_solve +// +// CSR triangular_solve_inspect with no exception policy +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve_inspect(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + X&& x) { + triangular_solve_inspect(mkl::par, std::forward(a), + std::forward(uplo), + std::forward(diag), + std::forward(b), std::forward(x)); +} // triangular_solve_inspect + +// +// CSR triangular_solve with no exception policy +// +template + requires __detail::has_csr_base && + __detail::has_contiguous_range_base && + __ranges::contiguous_range +void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, + X&& x) { + triangular_solve(mkl::par, std::forward(a), std::forward(uplo), + std::forward(diag), std::forward(b), + std::forward(x)); +} // triangular_solve + } // namespace spblas