Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ if (SPBLAS_CPU_BACKEND)
add_example(simple_sptrsv)
add_example(spmm_csc)
add_example(matrix_opt_example)
if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND )
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND )
if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND)

nit

# needs CPU + matrix_opt + operation_info_t to run
add_example(sptrsv_csr) # needs triangular_solve{_inspect} to run
add_example(spmm_csr) # needs multiply{_inspect} to run
endif()
endif()

# GPU examples
Expand Down
54 changes: 54 additions & 0 deletions examples/spmm_csr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include <spblas/spblas.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>

// SpMM example driver: sets the problem sizes and prints a banner describing
// the operation that the rest of main() performs.
int main(int argc, char** argv) {
using namespace spblas;
namespace md = spblas::__mdspan;

using T = float;

// A is (m x k), X is (k x n), Y is (m x n); nnz_in is the number of
// nonzeros requested for A.
spblas::index_t m = 10;
spblas::index_t n = 10;
spblas::index_t k = 10;
spblas::index_t nnz_in = 20;

fmt::print("\n\t###########################################################"
           "######################");
fmt::print("\n\t### Running Advanced SpMM Example:");
fmt::print("\n\t###");
// The multiply call in this example takes no scaling factor, so the banner
// states the unscaled product.
fmt::print("\n\t### Y = A * X");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fmt::print("\n\t### Y = alpha * A * X");

fmt::print("\n\t###");
fmt::print("\n\t### with ");
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, k,
           nnz_in);
// X and Y are both two-dimensional (k x n and m x n), so both are reported
// as dense matrices, using the same upper-case names as the formula.
fmt::print("\n\t### X, a dense matrix, of size ({}, {})", k, n);
fmt::print("\n\t### Y, a dense matrix, of size ({}, {})", m, n);
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
           sizeof(spblas::index_t));
fmt::print("\n\t###########################################################"
           "######################");
fmt::print("\n");

// Build the sparse operand: generate_csr hands back the CSR arrays plus the
// shape and nonzero count that csr_view needs.
auto&& [values, rowptr, colind, shape, nnz] = generate_csr<T>(m, k, nnz_in);

csr_view<T> a(values, rowptr, colind, shape, nnz);
matrix_opt a_opt(a);

// Dense operands: X is filled with ones, Y starts at zero.
std::vector<T> x_values(k * n, 1);
std::vector<T> y_values(m * n, 0);

md::mdspan x(x_values.data(), k, n);
md::mdspan y(y_values.data(), m, n);

// Y = A * X
// Two-stage call: the inspect stage produces `state`, which is then handed
// to the execute stage.
auto state = multiply_inspect(a_opt, x, y);
multiply(state, a_opt, x, y);

fmt::print("{}\n", spblas::__backend::values(y));

fmt::print("\tExample is completed!\n");

return 0;
}
63 changes: 63 additions & 0 deletions examples/sptrsv_csr.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <spblas/spblas.hpp>

#include <fmt/core.h>
#include <fmt/ranges.h>

// SpTRSV example driver: builds an (m x m) CSR matrix, scales its values,
// wraps it in a matrix_opt and runs the inspect and solve stages of
// triangular_solve against a scaled right-hand side.
int main(int argc, char** argv) {
using namespace spblas;

using T = float;

// m is the order of the square system; nnz_in is the number of nonzeros
// requested from generate_csr.
spblas::index_t m = 100;
spblas::index_t nnz_in = 20;

fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n\t### Running Full SpTRSV Example:");
fmt::print("\n\t###");
fmt::print("\n\t### solve for x: A * x = alpha * b");
fmt::print("\n\t###");
fmt::print("\n\t### with ");
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, m,
nnz_in);
fmt::print("\n\t### x, a dense vector, of size ({}, {})", m, 1);
fmt::print("\n\t### b, a dense vector, of size ({}, {})", m, 1);
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
sizeof(spblas::index_t));
fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n");

auto&& [values, rowptr, colind, shape, nnz] =
generate_csr<T, spblas::index_t>(m, m, nnz_in);

// Scale the stored values down so that, together with the implicit unit
// diagonal, the matrix is diagonally dominant and therefore solvable.
T scale_factor = 1e-3f;
std::transform(values.begin(), values.end(), values.begin(),
[scale_factor](T val) { return scale_factor * val; });

csr_view<T, spblas::index_t> a(values, rowptr, colind, shape, nnz);

matrix_opt a_opt(a);

// Disabled alternative: scale every value of `a` by 5 in place.
// scale(5.f, a);

// x is the solution vector (starts at zero); b is the right-hand side.
std::vector<T> x(m, 0);
std::vector<T> b(m, 1);

T alpha = 1.2f;
auto b_scaled = scaled(alpha, b);

// solve for x: upper(A) * x = alpha * b
// (upper_triangle_t is passed, and the diagonal is taken as implicit ones).
// NOTE(review): the value returned by this inspect call is discarded and the
// solve below is the overload without an operation_info_t -- confirm that
// this is intended.
triangular_solve_inspect(a_opt, spblas::upper_triangle_t{},
spblas::implicit_unit_diagonal_t{}, b_scaled, x);

triangular_solve(a_opt, spblas::upper_triangle_t{},
spblas::implicit_unit_diagonal_t{}, b_scaled, x);
Comment on lines +54 to +58
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we would like to change it. Is this PR to add operation_info_t and matrix_opt only? When we introduce the descriptor, then switch to that?


fmt::print("\tExample is completed!\n");

return 0;
}
28 changes: 28 additions & 0 deletions include/spblas/algorithms/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,46 @@

namespace spblas {

// SpMV variants (matrix operand A, vector operands B and C)

// Inspect stage that returns an operation_info_t to the caller.
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c);

// Inspect stage that works on a caller-provided operation_info_t.
template <matrix A, vector B, vector C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c);

// Execute stage without an operation_info_t.
template <matrix A, vector B, vector C>
void multiply(A&& a, B&& b, C&& c);

// Execute stage that additionally receives an operation_info_t.
template <matrix A, vector B, vector C>
void multiply(operation_info_t& info, A&& a, B&& b, C&& c);

// SpMM variants (all three operands satisfy the `matrix` concept)

// Execute stage without an operation_info_t.
template <matrix A, matrix B, matrix C>
void multiply(A&& a, B&& b, C&& c);

// Execute stage that additionally receives an operation_info_t.
template <matrix A, matrix B, matrix C>
void multiply(operation_info_t& info, A&& a, B&& b, C&& c);

// SpMM and SpGEMM multiply_inspect variants
// (one pair of overloads serves both, since every operand is a `matrix`)

// Inspect stage that returns an operation_info_t to the caller.
template <matrix A, matrix B, matrix C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c);

// Inspect stage that works on a caller-provided operation_info_t.
template <matrix A, matrix B, matrix C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c);

// SpGEMM variants

// Compute stage with an execution policy; returns an operation_info_t.
template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
operation_info_t multiply_compute(ExecutionPolicy&& policy, A&& a, B&& b,
C&& c);

// Compute stage with an execution policy, using a caller-provided
// operation_info_t instead of returning one.
template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
void multiply_compute(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
B&& b, C&& c);

// Fill stage with an execution policy; takes the operation_info_t by
// reference.
template <typename ExecutionPolicy, matrix A, matrix B, matrix C>
void multiply_fill(ExecutionPolicy&& policy, operation_info_t& info, A&& a,
B&& b, C&& c);

// Compute stage without an execution policy; returns an operation_info_t.
template <matrix A, matrix B, matrix C>
operation_info_t multiply_compute(A&& a, B&& b, C&& c);

Expand Down
51 changes: 49 additions & 2 deletions include/spblas/algorithms/multiply_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@

namespace spblas {

// SpMV inspect stage, returning overload.
// No analysis of the operands is performed here: the call is traced and a
// value-initialized operation_info_t is handed back.
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c) {
  log_trace("");
  return operation_info_t();
}

// SpMV inspect
template <matrix A, vector B, vector C>
operation_info_t multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) {
log_trace("");
}

// C = AB
// SpMV
template <matrix A, vector B, vector C>
Expand All @@ -39,6 +52,15 @@ void multiply(A&& a, B&& b, C&& c) {
});
}

// C = AB
// SpMV overload that accepts an operation_info_t.
// `info` is not consulted: the call is traced and the operands are forwarded
// unchanged to the multiply overload that takes no info argument.
template <matrix A, vector B, vector C>
  requires(__backend::lookupable<B> && __backend::lookupable<C>)
void multiply(operation_info_t& info, A&& a, B&& b, C&& c) {
  log_trace("");
  return multiply(std::forward<A>(a),
                  std::forward<B>(b),
                  std::forward<C>(c));
}

// C = AB
// SpMM
template <matrix A, matrix B, matrix C>
Expand All @@ -52,37 +74,61 @@ void multiply(A&& a, B&& b, C&& c) {
"multiply: matrix dimensions are incompatible.");
}

// initializes c to zero so we can use += everywhere
__backend::for_each(c, [](auto&& e) {
auto&& [_, v] = e;
v = 0;
});

// traverses elements of a and performs appropriate
// multiplication with B rows
__backend::for_each(a, [&](auto&& e) {
auto&& [idx, a_v] = e;
auto&& [i, k] = idx;
for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) {
for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) { // b_row
for (std::size_t j = 0; j < __backend::shape(b)[1]; j++) {

maybe not needed? or move it to before for?

__backend::lookup(c, i, j) += a_v * __backend::lookup(b, k, j);
}
});
}

// C = AB
// SpMM overload that accepts an operation_info_t.
// `info` is not consulted: the call is traced and the operands are forwarded
// unchanged to the multiply overload that takes no info argument.
template <matrix A, matrix B, matrix C>
  requires(__backend::lookupable<B> && __backend::lookupable<C>)
void multiply(operation_info_t& info, A&& a, B&& b, C&& c) {
  log_trace("");
  return multiply(std::forward<A>(a),
                  std::forward<B>(b),
                  std::forward<C>(c));
}

// C = AB
// Inspect stage shared by SpMM and SpGEMM (all operands are matrices),
// returning overload. The call is traced and a value-initialized
// operation_info_t is handed back; the operands are not examined.
template <matrix A, matrix B, matrix C>
operation_info_t multiply_inspect(A&& a, B&& b, C&& c) {
  log_trace("");
  return operation_info_t();
}

// C = AB
// Inspect stage shared by SpMM and SpGEMM (all operands are matrices),
// overload taking a caller-provided operation_info_t.
// The call is traced; `info` and the operands are left untouched.
// There is exactly one definition of this overload, with no trailing `;`
// after the function body.
template <matrix A, matrix B, matrix C>
void multiply_inspect(operation_info_t& info, A&& a, B&& b, C&& c) {
  log_trace("");
}

// C = AB
// SpGEMM compute stage with CSR output, overload that updates an existing
// operation_info_t. It runs the returning multiply_compute overload and then
// copies that result's shape and nonzero count into `info`.
template <matrix A, matrix B, matrix C>
  requires(__backend::row_iterable<A> && __backend::row_iterable<B> &&
           __detail::is_csr_view_v<C>)
void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
  log_trace("");

  auto computed = multiply_compute(std::forward<A>(a),
                                   std::forward<B>(b),
                                   std::forward<C>(c));

  info.update_impl_(computed.result_shape(), computed.result_nnz());
}

// C = AB
// SpGEMM compute stage with CSC output
template <matrix A, matrix B, matrix C>
requires(__backend::column_iterable<A> && __backend::column_iterable<B> &&
__detail::is_csc_view_v<C>)
Expand All @@ -93,6 +139,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c) {
}

// C = AB
// SpGEMM fill stage with CSR or CSC output
template <matrix A, matrix B, matrix C>
void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) {
log_trace("");
Expand Down
13 changes: 8 additions & 5 deletions include/spblas/algorithms/triangular_solve.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
#include <spblas/concepts.hpp>
#include <spblas/detail/operation_info_t.hpp>

// NOTE(review): this signature is specification-style pseudo-code rather than
// compilable C++ (`in - matrix`, `in - vector` and `out - vector` are not
// valid template-parameter declarations), and it sits outside namespace
// spblas. Presumably it is a leftover interface sketch -- confirm, then
// remove it or keep it only as a comment.
template <class ExecutionPolicy, in - matrix InMat, class Triangle,
class DiagonalStorage, in - vector InVec, out - vector OutVec>
void triangular_matrix_vector_solve(ExecutionPolicy&& exec, InMat A, Triangle t,
DiagonalStorage d, InVec b, OutVec x);

namespace spblas {

// SpTRSV inspect stage that works on a caller-provided operation_info_t.
// `uplo` selects the triangle of `a` and `diag` describes how its diagonal
// is stored; both are passed as tag objects.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle uplo,
DiagonalStorage diag, B&& b, X&& x);

// SpTRSV inspect stage that returns an operation_info_t to the caller.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
operation_info_t triangular_solve_inspect(A&& a, Triangle uplo,
DiagonalStorage diag, B&& b, X&& x);

// SpTRSV solve stage without an operation_info_t: `b` is the right-hand
// side and `x` receives the solution.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
void triangular_solve(A&& a, Triangle uplo, DiagonalStorage diag, B&& b, X&& x);

Expand Down
44 changes: 44 additions & 0 deletions include/spblas/algorithms/triangular_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,41 @@

namespace spblas {

// X = inv(A) B
// SpTRSV inspect stage, returning overload.
// Only validates its inputs: the triangle tag must be upper_triangle_t or
// lower_triangle_t (checked at compile time) and `a` must be square (checked
// with assert). A value-initialized operation_info_t is returned.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
  requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
           __backend::lookupable<X>)
operation_info_t triangular_solve_inspect(A&& a, Triangle t, DiagonalStorage d,
                                          B&& b, X&& x) {
  log_trace("");

  constexpr bool is_upper = std::is_same_v<Triangle, upper_triangle_t>;
  constexpr bool is_lower = std::is_same_v<Triangle, lower_triangle_t>;
  static_assert(is_upper || is_lower);

  assert(__backend::shape(a)[0] == __backend::shape(a)[1]);

  return operation_info_t();
}

// X = inv(A) B
// SpTRSV inspect stage, overload taking a caller-provided operation_info_t.
// Only validates its inputs: the triangle tag must be upper_triangle_t or
// lower_triangle_t (checked at compile time) and `a` must be square (checked
// with assert). `info` is left untouched.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
  requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
           __backend::lookupable<X>)
void triangular_solve_inspect(operation_info_t& info, A&& a, Triangle t,
                              DiagonalStorage d, B&& b, X&& x) {
  log_trace("");

  constexpr bool is_upper = std::is_same_v<Triangle, upper_triangle_t>;
  constexpr bool is_lower = std::is_same_v<Triangle, lower_triangle_t>;
  static_assert(is_upper || is_lower);

  assert(__backend::shape(a)[0] == __backend::shape(a)[1]);
}

// X = inv(A) B
// SpTRSV solve stage
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
__backend::lookupable<X>)
void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) {
log_trace("");
static_assert(std::is_same_v<Triangle, upper_triangle_t> ||
std::is_same_v<Triangle, lower_triangle_t>);
assert(__backend::shape(a)[0] == __backend::shape(a)[1]);
Expand Down Expand Up @@ -62,4 +93,17 @@ void triangular_solve(A&& a, Triangle t, DiagonalStorage d, B&& b, X&& x) {
}
}

// X = inv(A) B
// SpTRSV solve stage, overload that accepts an operation_info_t.
// `info` is not consulted: the call is traced and every other argument is
// forwarded to the triangular_solve overload that takes no info argument.
template <matrix A, class Triangle, class DiagonalStorage, vector B, vector X>
  requires(__backend::row_iterable<A> && __backend::lookupable<B> &&
           __backend::lookupable<X>)
void triangular_solve(operation_info_t& info, A&& a, Triangle t,
                      DiagonalStorage d, B&& b, X&& x) {
  log_trace("");
  return triangular_solve(std::forward<A>(a),
                          std::forward<Triangle>(t),
                          std::forward<DiagonalStorage>(d),
                          std::forward<B>(b),
                          std::forward<X>(x));
}

} // namespace spblas
Loading
Loading