Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/cusparse/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ function(add_cuda_example example_name)
endfunction()

# Register each cuSPARSE example by name through the add_cuda_example helper.
add_cuda_example(cusparse_simple_spmv)
add_cuda_example(cusparse_simple_spmm)
119 changes: 119 additions & 0 deletions examples/cusparse/cusparse_simple_spmm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include <iostream>
#include <spblas/spblas.hpp>

#include <cuda_runtime.h>

#include "util.hpp"

#include <fmt/core.h>
#include <fmt/ranges.h>

int main(int argc, char** argv) {
namespace md = spblas::__mdspan;

using value_t = float;
using index_t = spblas::index_t;
using offset_t = spblas::offset_t;

spblas::index_t m = 100;
spblas::index_t n = 10;
spblas::index_t k = 100;
spblas::index_t nnz_in = 10;

fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n\t### Running SpMM Example:");
fmt::print("\n\t###");
fmt::print("\n\t### Y = alpha * A * X");
fmt::print("\n\t###");
fmt::print("\n\t### with ");
fmt::print("\n\t### A, in CSR format, of size ({}, {}) with nnz = {}", m, n,
nnz_in);
fmt::print("\n\t### X, a dense matrix, of size ({}, {})", n, k);
fmt::print("\n\t### Y, a dense matrix, of size ({}, {})", m, k);
fmt::print("\n\t### using float and spblas::index_t (size = {} bytes)",
sizeof(spblas::index_t));
fmt::print("\n\t###########################################################"
"######################");
fmt::print("\n");

auto&& [values, rowptr, colind, shape, nnz] =
spblas::generate_csr<value_t, index_t, offset_t>(m, n, nnz_in);

value_t* d_values;
offset_t* d_rowptr;
index_t* d_colind;

CUDA_CHECK(cudaMalloc(&d_values, values.size() * sizeof(value_t)));
CUDA_CHECK(cudaMalloc(&d_rowptr, rowptr.size() * sizeof(offset_t)));
CUDA_CHECK(cudaMalloc(&d_colind, colind.size() * sizeof(index_t)));

CUDA_CHECK(cudaMemcpy(d_values, values.data(),
values.size() * sizeof(value_t), cudaMemcpyDefault));
CUDA_CHECK(cudaMemcpy(d_rowptr, rowptr.data(),
rowptr.size() * sizeof(offset_t), cudaMemcpyDefault));
CUDA_CHECK(cudaMemcpy(d_colind, colind.data(),
colind.size() * sizeof(index_t), cudaMemcpyDefault));

spblas::csr_view<value_t, index_t, offset_t> a(d_values, d_rowptr, d_colind,
shape, nnz);

// Scale every value of `a` by 5 in place.
// scale(5.f, a);

std::vector<value_t> x(n * k, 1);
std::vector<value_t> y(m * k, 0);

value_t* d_x;
value_t* d_y;

CUDA_CHECK(cudaMalloc(&d_x, x.size() * sizeof(value_t)));
CUDA_CHECK(cudaMalloc(&d_y, y.size() * sizeof(value_t)));

CUDA_CHECK(
cudaMemcpy(d_x, x.data(), x.size() * sizeof(value_t), cudaMemcpyDefault));
CUDA_CHECK(
cudaMemcpy(d_y, y.data(), y.size() * sizeof(value_t), cudaMemcpyDefault));

md::mdspan x_span(d_x, n, k);
md::mdspan y_span(d_y, m, k);

// Y = A * X
spblas::operation_info_t info;
spblas::multiply(info, a, x_span, y_span);

CUDA_CHECK(
cudaMemcpy(y.data(), d_y, y.size() * sizeof(value_t), cudaMemcpyDefault));

// CPU reference
std::vector<value_t> y_ref(m * k, 0);
for (index_t i = 0; i < m; i++) {
for (offset_t j = rowptr[i]; j < rowptr[i + 1]; j++) {
index_t col = colind[j];
value_t val = values[j];
for (index_t l = 0; l < k; l++) {
y_ref[i * k + l] += val * x[col * k + l];
}
}
}

bool failed = false;

for (size_t i = 0; i < y.size(); ++i) {
if (y[i] != y_ref[i]) {
fprintf(stderr, "Value mismatch at index %ld: y_ref[%ld] = %f, y[%ld] = %f\n", i, i, y_ref[i], i, y_ref[i]);
failed = true;
}
}

if (failed) {
fmt::print("\tValidation failed!\n");
}
else {
fmt::print("\tValidation succeeded!\n");
}

fmt::print("\tExample is completed!\n");

return failed;
}
77 changes: 77 additions & 0 deletions include/spblas/vendor/cusparse/detail/create_matrix_handle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#pragma once

#include <cusparse.h>

#include <stdexcept>

#include <spblas/detail/view_inspectors.hpp>

namespace spblas {

namespace __cusparse {

// Builds a cuSPARSE sparse-matrix descriptor for a CSR view.
//
// The descriptor refers to the row-offset, column-index and value arrays held
// by `m`; index, offset and scalar types are mapped to their cuSPARSE / CUDA
// enumerators, and zero-based indexing is requested. Any non-success status
// from cusparseCreateCsr is reported through throw_if_error.
template <matrix M>
  requires __detail::is_csr_view_v<M>
cusparseSpMatDescr_t create_matrix_handle(M&& m) {
  using offset_type = tensor_offset_t<M>;
  using index_type = tensor_index_t<M>;
  using scalar_type = tensor_scalar_t<M>;

  auto&& dims = __backend::shape(m);
  const auto stored_entries = m.values().size();

  cusparseSpMatDescr_t descriptor;
  __cusparse::throw_if_error(cusparseCreateCsr(
      &descriptor, dims[0], dims[1], stored_entries, m.rowptr().data(),
      m.colind().data(), m.values().data(),
      detail::cusparse_index_type_v<offset_type>,
      detail::cusparse_index_type_v<index_type>, CUSPARSE_INDEX_BASE_ZERO,
      detail::cuda_data_type_v<scalar_type>));

  return descriptor;
}

// Builds a cuSPARSE sparse-matrix descriptor for a CSC view.
//
// cusparseCreateCsc expects, in order, the column-offset array, the row-index
// array and the value array, so the view's colptr() / rowind() accessors are
// passed here (not the CSR-style rowptr() / colind()). Zero-based indexing is
// requested, and any non-success status is reported through throw_if_error.
template <matrix M>
  requires __detail::is_csc_view_v<M>
cusparseSpMatDescr_t create_matrix_handle(M&& m) {
  cusparseSpMatDescr_t mat_descr;
  __cusparse::throw_if_error(cusparseCreateCsc(
      &mat_descr, __backend::shape(m)[0], __backend::shape(m)[1],
      m.values().size(), m.colptr().data(), m.rowind().data(),
      m.values().data(), detail::cusparse_index_type_v<tensor_offset_t<M>>,
      detail::cusparse_index_type_v<tensor_index_t<M>>,
      CUSPARSE_INDEX_BASE_ZERO, detail::cuda_data_type_v<tensor_scalar_t<M>>));

  return mat_descr;
}

// Overload for views that wrap another matrix: unwraps one level with
// m.base() and creates the descriptor from the wrapped matrix.
template <matrix M>
requires __detail::has_base<M>
cusparseSpMatDescr_t create_matrix_handle(M&& m) {
// The result of base() is passed straight through so that overload
// resolution sees exactly the type and value category base() returns.
return create_matrix_handle(m.base());
}

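//
// Takes in a CSR or CSR_transpose (aka CSC) or CSC or CSC_transpose
// and returns the transpose value associated with it being represented
// in the CSR format (since oneMKL SYCL currently does not have a CSC
// format).
//
// CSR = CSR + nontrans
// CSR_transpose = CSR + trans
// CSC = CSR + trans
// CSC_transpose = CSR + nontrans
//
// NOTE(review): the disabled helper below still refers to oneMKL types;
// presumably it was carried over from the oneMKL backend -- confirm before
// enabling it for cuSPARSE.
//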
// template <matrix M>
// oneapi::mkl::transpose get_transpose(M&& m) {
// static_assert(__detail::has_csr_base<M> || __detail::has_csc_base<M>);

// const bool conjugate = __detail::is_conjugated(m);
// if constexpr (__detail::has_csr_base<M>) {
// if (conjugate) {
// throw std::runtime_error(
// "oneMKL SYCL backend does not support conjugation for CSR views.");
// }
// return oneapi::mkl::transpose::nontrans;
// } else if constexpr (__detail::has_csc_base<M>) {
// return conjugate ? oneapi::mkl::transpose::conjtrans
// : oneapi::mkl::transpose::trans;
// }
// }

} // namespace __cusparse

} // namespace spblas
4 changes: 4 additions & 0 deletions include/spblas/vendor/cusparse/detail/detail.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

#include "create_matrix_handle.hpp"
#include "get_matrix_handle.hpp"
44 changes: 44 additions & 0 deletions include/spblas/vendor/cusparse/detail/get_matrix_handle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <cusparse.h>

#include <spblas/detail/log.hpp>
#include <spblas/detail/operation_info_t.hpp>
#include <spblas/detail/ranges.hpp>
#include <spblas/detail/view_inspectors.hpp>
#include <spblas/views/matrix_opt.hpp>

#include <spblas/vendor/cusparse/detail/create_matrix_handle.hpp>

namespace spblas {

namespace __cusparse {

// Returns a cuSPARSE sparse-matrix descriptor for `m`, picking the source in
// this order:
//   1. `m` is a matrix_opt: use its cached matrix_handle_, creating it from
//      m.base() on first use.
//   2. `m` wraps another matrix: recurse on m.base(), forwarding `handle`.
//   3. `handle` is non-null: return it unchanged.
//   4. otherwise: create a fresh descriptor from `m`.
//
// The first two tests are `if constexpr`, so only the branch that applies to
// M is instantiated; the `handle` test is an ordinary runtime check.
//
// NOTE(review): in case 4 nothing here records or destroys the new
// descriptor; presumably the caller takes ownership -- confirm.
template <matrix M>
cusparseSpMatDescr_t
get_matrix_handle(M&& m,
cusparseSpMatDescr_t handle = nullptr) {
if constexpr (__detail::is_matrix_opt_v<decltype(m)>) {
log_trace("using A as matrix_opt");

// Create the descriptor lazily and cache it on the matrix_opt.
if (m.matrix_handle_ == nullptr) {
m.matrix_handle_ = create_matrix_handle(m.base());
}

return m.matrix_handle_;
} else if constexpr (__detail::has_base<M>) {
// Peel off one wrapping view and try again.
return get_matrix_handle(m.base(), handle);
} else if (handle != nullptr) {
log_trace("using A from operation_info_t");

return handle;
} else {
log_trace("using A as csr_base");

return create_matrix_handle(m);
}
}

} // namespace __cusparse

} // namespace spblas
65 changes: 65 additions & 0 deletions include/spblas/vendor/cusparse/detail/spgemm_state_t.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#pragma once

#include <cusparse.h>
#include <memory>

#include "abstract_operation_state.hpp"

namespace spblas {
namespace __cusparse {

// Holds the cuSPARSE descriptors used by a sparse * sparse (SpGEMM)
// operation: the three sparse-matrix descriptors for A, B and C plus the
// SpGEMM descriptor. Every descriptor that is non-null when the state is
// destroyed is released in the destructor.
class spgemm_state_t : public abstract_operation_state_t {
public:
  spgemm_state_t() = default;

  // The state owns its descriptors, so copying would release each of them
  // twice; copies are therefore disabled.
  spgemm_state_t(const spgemm_state_t&) = delete;
  spgemm_state_t& operator=(const spgemm_state_t&) = delete;

  ~spgemm_state_t() {
    if (a_descr_) {
      cusparseDestroySpMat(a_descr_);
    }
    if (b_descr_) {
      cusparseDestroySpMat(b_descr_);
    }
    if (c_descr_) {
      cusparseDestroySpMat(c_descr_);
    }
    if (spgemm_descr_) {
      cusparseSpGEMM_destroyDescr(spgemm_descr_);
    }
  }

  // Accessors for the descriptors. All three operands are sparse matrices,
  // so they share the cusparseSpMatDescr_t type of the stored members.
  cusparseSpMatDescr_t a_descriptor() const {
    return a_descr_;
  }
  cusparseSpMatDescr_t b_descriptor() const {
    return b_descr_;
  }
  cusparseSpMatDescr_t c_descriptor() const {
    return c_descr_;
  }
  cusparseSpGEMMDescr_t spgemm_descriptor() const {
    return spgemm_descr_;
  }

  // Setters for the descriptors. A descriptor stored here is destroyed by
  // the destructor; a previously stored descriptor is overwritten without
  // being destroyed.
  void set_a_descriptor(cusparseSpMatDescr_t descr) {
    a_descr_ = descr;
  }
  void set_b_descriptor(cusparseSpMatDescr_t descr) {
    b_descr_ = descr;
  }
  void set_c_descriptor(cusparseSpMatDescr_t descr) {
    c_descr_ = descr;
  }
  void set_spgemm_descriptor(cusparseSpGEMMDescr_t descr) {
    spgemm_descr_ = descr;
  }

private:
  cusparseSpMatDescr_t a_descr_ = nullptr;
  cusparseSpMatDescr_t b_descr_ = nullptr;
  cusparseSpMatDescr_t c_descr_ = nullptr;
  cusparseSpGEMMDescr_t spgemm_descr_ = nullptr;
};

} // namespace __cusparse
} // namespace spblas
55 changes: 55 additions & 0 deletions include/spblas/vendor/cusparse/detail/spmm_state_t.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#pragma once

#include <cusparse.h>
#include <memory>

#include "abstract_operation_state.hpp"

namespace spblas {
namespace __cusparse {

// Holds the cuSPARSE descriptors used by a sparse * dense (SpMM) operation:
// the sparse-matrix descriptor for A and the dense-matrix descriptors for X
// and Y. Every descriptor that is non-null when the state is destroyed is
// released in the destructor.
class spmm_state_t : public abstract_operation_state_t {
public:
  spmm_state_t() = default;

  // The state owns its descriptors, so copying would release each of them
  // twice; copies are therefore disabled.
  spmm_state_t(const spmm_state_t&) = delete;
  spmm_state_t& operator=(const spmm_state_t&) = delete;

  ~spmm_state_t() {
    if (a_descr_) {
      cusparseDestroySpMat(a_descr_);
    }
    if (x_descr_) {
      cusparseDestroyDnMat(x_descr_);
    }
    if (y_descr_) {
      cusparseDestroyDnMat(y_descr_);
    }
  }

  // Accessors for the descriptors
  cusparseSpMatDescr_t a_descriptor() const {
    return a_descr_;
  }
  cusparseDnMatDescr_t x_descriptor() const {
    return x_descr_;
  }
  cusparseDnMatDescr_t y_descriptor() const {
    return y_descr_;
  }

  // Setters for the descriptors. A descriptor stored here is destroyed by
  // the destructor; a previously stored descriptor is overwritten without
  // being destroyed.
  void set_a_descriptor(cusparseSpMatDescr_t descr) {
    a_descr_ = descr;
  }
  void set_x_descriptor(cusparseDnMatDescr_t descr) {
    x_descr_ = descr;
  }
  void set_y_descriptor(cusparseDnMatDescr_t descr) {
    y_descr_ = descr;
  }

private:
  cusparseSpMatDescr_t a_descr_ = nullptr;
  cusparseDnMatDescr_t x_descr_ = nullptr;
  cusparseDnMatDescr_t y_descr_ = nullptr;
};

} // namespace __cusparse
} // namespace spblas
1 change: 1 addition & 0 deletions include/spblas/vendor/cusparse/multiply.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#pragma once

#include "spmv_impl.hpp"
#include "spmm_impl.hpp"
Loading