-
Notifications
You must be signed in to change notification settings - Fork 10
[ONEMKL_SYCL][REF] Add operation_info_t and matrix_opt to more cases and add two examples to cover it #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,54 @@ | ||||
| #include <spblas/spblas.hpp> | ||||
|
|
||||
| #include <fmt/core.h> | ||||
| #include <fmt/ranges.h> | ||||
|
|
||||
| int main(int argc, char** argv) { | ||||
| using namespace spblas; | ||||
| namespace md = spblas::__mdspan; | ||||
|
|
||||
| using T = float; | ||||
|
|
||||
| spblas::index_t m = 10; | ||||
| spblas::index_t n = 10; | ||||
| spblas::index_t k = 10; | ||||
| spblas::index_t nnz_in = 20; | ||||
|
|
||||
| fmt::print("\n\t###########################################################" | ||||
| "######################"); | ||||
| fmt::print("\n\t### Running Advanced SpMM Example:"); | ||||
| fmt::print("\n\t###"); | ||||
| fmt::print("\n\t### Y = alpha * A * X"); | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| fmt::print("\n\t###"); | ||||
| fmt::print("\n\t### with "); | ||||
| fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k, | ||||
| nnz_in); | ||||
| fmt::print("\n\t### x, a dense matrix, of size ({}, {})", k, n); | ||||
| fmt::print("\n\t### y, a dense vector, of size ({}, {})", m, n); | ||||
| fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", | ||||
| sizeof(spblas::index_t)); | ||||
| fmt::print("\n\t###########################################################" | ||||
| "######################"); | ||||
| fmt::print("\n"); | ||||
|
|
||||
| auto&& [values, rowptr, colind, shape, nnz] = generate_csr<T>(m, k, nnz_in); | ||||
|
|
||||
| csr_view<T> a(values, rowptr, colind, shape, nnz); | ||||
| matrix_opt a_opt(a); | ||||
|
|
||||
| std::vector<T> x_values(k * n, 1); | ||||
| std::vector<T> y_values(m * n, 0); | ||||
|
|
||||
| md::mdspan x(x_values.data(), k, n); | ||||
| md::mdspan y(y_values.data(), m, n); | ||||
|
|
||||
| // Y = A * X | ||||
| auto state = multiply_inspect(a_opt, x, y); | ||||
| multiply(state, a_opt, x, y); | ||||
|
|
||||
| fmt::print("{}\n", spblas::__backend::values(y)); | ||||
|
|
||||
| fmt::print("\tExample is completed!\n"); | ||||
|
|
||||
| return 0; | ||||
| } | ||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| #include <spblas/spblas.hpp> | ||
|
|
||
| #include <fmt/core.h> | ||
| #include <fmt/ranges.h> | ||
|
|
||
| int main(int argc, char** argv) { | ||
| using namespace spblas; | ||
|
|
||
| using T = float; | ||
|
|
||
| spblas::index_t m = 100; | ||
| spblas::index_t nnz_in = 20; | ||
|
|
||
| fmt::print("\n\t###########################################################" | ||
| "######################"); | ||
| fmt::print("\n\t### Running Full SpTRSV Example:"); | ||
| fmt::print("\n\t###"); | ||
| fmt::print("\n\t### solve for x: A * x = alpha * b"); | ||
| fmt::print("\n\t###"); | ||
| fmt::print("\n\t### with "); | ||
| fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, m, | ||
| nnz_in); | ||
| fmt::print("\n\t### x, a dense vector, of size ({}, {})", m, 1); | ||
| fmt::print("\n\t### b, a dense vector, of size ({}, {})", m, 1); | ||
| fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)", | ||
| sizeof(spblas::index_t)); | ||
| fmt::print("\n\t###########################################################" | ||
| "######################"); | ||
| fmt::print("\n"); | ||
|
|
||
| auto&& [values, rowptr, colind, shape, nnz] = | ||
| generate_csr<T, spblas::index_t>(m, m, nnz_in); | ||
|
|
||
| // scale values of matrix to make the implicit unit diagonal matrix | ||
| // be diagonally dominant, so it is solveable | ||
| T scale_factor = 1e-3f; | ||
| std::transform(values.begin(), values.end(), values.begin(), | ||
| [scale_factor](T val) { return scale_factor * val; }); | ||
|
|
||
| csr_view<T, spblas::index_t> a(values, rowptr, colind, shape, nnz); | ||
|
|
||
| matrix_opt a_opt(a); | ||
|
|
||
| // Scale every value of `a` by 5 in place. | ||
| // scale(5.f, a); | ||
|
|
||
| std::vector<T> x(m, 0); | ||
| std::vector<T> b(m, 1); | ||
|
|
||
| T alpha = 1.2f; | ||
| auto b_scaled = scaled(alpha, b); | ||
|
|
||
| // solve for x: lower(A) * x = alpha * b | ||
| triangular_solve_inspect(a_opt, spblas::upper_triangle_t{}, | ||
| spblas::implicit_unit_diagonal_t{}, b_scaled, x); | ||
|
|
||
| triangular_solve(a_opt, spblas::upper_triangle_t{}, | ||
| spblas::implicit_unit_diagonal_t{}, b_scaled, x); | ||
|
Comment on lines
+54
to
+58
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought we would like to change it. Is this PR to add |
||
|
|
||
| fmt::print("\tExample is completed!\n"); | ||
|
|
||
| return 0; | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,19 @@ | |||||
|
|
||||||
| namespace spblas { | ||||||
|
|
||||||
| // SpMV inspect | ||||||
| template <matrix A, vector B, vector C> | ||||||
| operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| return operation_info_t{}; | ||||||
| } | ||||||
|
|
||||||
| // SpMV inspect | ||||||
| template <matrix A, vector B, vector C> | ||||||
| operation_info_t multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMV | ||||||
| template <matrix A, vector B, vector C> | ||||||
|
|
@@ -39,6 +52,15 @@ void multiply(A&& a, B&& b, C&& c) { | |||||
| }); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMV with info input | ||||||
| template <matrix A, vector B, vector C> | ||||||
| requires(__backend::lookupable<B> && __backend::lookupable<C>) | ||||||
| void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| multiply(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c)); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMM | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
|
|
@@ -52,37 +74,61 @@ void multiply(A&& a, B&& b, C&& c) { | |||||
| "multiply: matrix dimensions are incompatible."); | ||||||
| } | ||||||
|
|
||||||
| // initializes c to zero so we can use += everywhere | ||||||
| __backend::for_each(c, [](auto&& e) { | ||||||
| auto&& [_, v] = e; | ||||||
| v = 0; | ||||||
| }); | ||||||
|
|
||||||
| // traverses elements of a and performs appropriate | ||||||
| // multiplication with B rows | ||||||
| __backend::for_each(a, [&](auto&& e) { | ||||||
| auto&& [idx, a_v] = e; | ||||||
| auto&& [i, k] = idx; | ||||||
| for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { | ||||||
| for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
maybe not needed? or move it to before |
||||||
| __backend::lookup(c, i, j) += a_v * __backend::lookup(b, k, j); | ||||||
| } | ||||||
| }); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMM with info | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| requires(__backend::lookupable<B> && __backend::lookupable<C>) | ||||||
| void multiply(operation_info_t& info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| multiply(std::forward<A>(a), std::forward<B>(b), std::forward<C>(c)); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMM or SpGEMM multiply_inspect variants end up here | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| operation_info_t multiply_inspect(A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| return operation_info_t{}; | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpMM or SpGEMM multiply_inspect variants end up here | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c){}; | ||||||
| void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| }; | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpGEMM compute stage with CSR output | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| requires(__backend::row_iterable<A> && __backend::row_iterable<B> && | ||||||
| __detail::is_csr_view_v<C>) | ||||||
| void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
| auto new_info = multiply_compute(std::forward<A>(a), std::forward<B>(b), | ||||||
| std::forward<C>(c)); | ||||||
| info.update_impl_(new_info.result_shape(), new_info.result_nnz()); | ||||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpGEMM compute stage with CSC output | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| requires(__backend::column_iterable<A> && __backend::column_iterable<B> && | ||||||
| __detail::is_csc_view_v<C>) | ||||||
|
|
@@ -93,6 +139,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) { | |||||
| } | ||||||
|
|
||||||
| // C = AB | ||||||
| // SpGEMM fill stage with CSR or CSC output | ||||||
| template <matrix A, matrix B, matrix C> | ||||||
| void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) { | ||||||
| log_trace(""); | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit