diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d17a73..b8425e2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -94,6 +94,33 @@ jobs: run: | ./build/test/gtest/spblas-tests + MATLAB_mex_checks: + runs-on: 'ubuntu-latest' + steps: + - uses: actions/checkout@v4 + - name: CMake Build and Install + run: | + cmake -B build -D CMAKE_INSTALL_PREFIX=install + make -C build -j `nproc` + cmake --install build/. + - name: Set up MATLAB + uses: matlab-actions/setup-matlab@v2.5.0 + with: + release: R2025b + - name: Build MEX APIs + uses: matlab-actions/run-command@v2.2.1 + with: + command: | + cd("examples/MATLAB_MEX"); + build_mex_APIs("../../install/include") + - name: Run MEX API Tests + uses: matlab-actions/run-command@v2.2.1 + with: + command: | + cd("examples/MATLAB_MEX"); + results = runtests("testSparseBLASMexAPIs", "Verbosity", 0) + assertSuccess(results); + aocl: runs-on: 'cpu_amd' steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2bd722f..539049f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: rev: v16.0.6 hooks: - id: clang-format + types_or: [c, c++, cuda, inc] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/examples/MATLAB_MEX/build_mex_APIs.m b/examples/MATLAB_MEX/build_mex_APIs.m new file mode 100644 index 0000000..0d611b7 --- /dev/null +++ b/examples/MATLAB_MEX/build_mex_APIs.m @@ -0,0 +1,29 @@ +function build_mex_APIs(incl_path, verbose, debug) +%BUILD_MEX_APIS - Function to build all available mex APIs +% +% First input must be the path to the SparseBLAS INCLUDE folder. +% Second and third are optional logical inputs to activate VERBOSE and/or +% DEBUG mode. 
+% +% The calling syntaxes are: +% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE") +% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE", true) +% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE", false, true) + +% Set default options +opts = {['-I' char(incl_path)], "-O", "-R2018a", "CXXFLAGS=$CFLAGS -std=c++23"}; + +% Parse optional VERBOSE option +if nargin > 1 && verbose + opts = [opts, "-v"]; +end + +% Parse optional DEBUG option +if nargin > 2 && debug + opts = [opts, "-g"]; +end + +% Compile all APIs +mex("simple_spmv_mex.cpp", opts{:}); +mex("simple_spmm_mex.cpp", opts{:}); +end diff --git a/examples/MATLAB_MEX/simple_spmm_mex.cpp b/examples/MATLAB_MEX/simple_spmm_mex.cpp new file mode 100644 index 0000000..dcc722b --- /dev/null +++ b/examples/MATLAB_MEX/simple_spmm_mex.cpp @@ -0,0 +1,133 @@ +// Includes from SparseBLAS +#include + +// Includes for MEX +#include +#include + +// General includes +#include // Support complex inputs + +template +void spmmDriver(mxArray* mxY, const mxArray* mxA, const mxArray* mxX, + const mxArray* mxAlpha) { + // Gather dimensions + mwIndex m = mxGetM(mxA); + mwIndex k = mxGetN(mxA); + mwIndex n = mxGetN(mxX); + + // Fill csc_view with: + // - T* values + // - mwIndex* colptr + // - mwIndex* rowind + // - {mwIndex m, mwIndex k} (shape) + // - mwIndex nnz + spblas::csc_view A(static_cast(mxGetData(mxA)), + mxGetJc(mxA), mxGetIr(mxA), {m, k}, + mxGetJc(mxA)[k]); + + // Wrap X in an mdspan of size k-by-n + spblas::mdspan_col_major X(static_cast(mxGetData(mxX)), k, n); + + // Wrap output Y in an mdspan of size m-by-n + spblas::mdspan_col_major Y(static_cast(mxGetData(mxY)), m, n); + + // Store and apply scaling factor alpha, if provided and not empty + T alpha = T(1); + if (mxAlpha != nullptr && !mxIsEmpty(mxAlpha)) { + // We don't use mxGetScalar as it doesn't work for complex + alpha = *(static_cast(mxGetData(mxAlpha))); + } + auto alpha_A = spblas::scaled(alpha, A); + + // Y = (alpha * A) * X + spblas::multiply(alpha_A, X, Y); +} + +// Y = 
(alpha *) A * X +// prhs[0] = A, prhs[1] = X (optional: prhs[2] = alpha) +// plhs[0] = Y +void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { + + // General input checking + if (nrhs < 2 || nrhs > 3) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfInputs", + "Function needs 2 or 3 inputs."); + } + if (nlhs > 1) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfOutputs", + "Function returns only 1 output."); + } + if (mxGetClassID(prhs[0]) != mxGetClassID(prhs[1]) || + ((nrhs == 3) && mxGetClassID(prhs[1]) != mxGetClassID(prhs[2]))) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:ClassMismatch", + "All inputs must have matching type."); + } + if (!mxIsDouble(prhs[0]) && !mxIsSingle(prhs[0])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:NonFloat", + "All inputs must be single or double."); + } + if (!mxIsSparse(prhs[0])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:FirstInputNotSparse", + "First input must be sparse."); + } + if (mxIsSparse(prhs[1])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:SecondInputNotDense", + "Second input must be dense."); + } + + // Reference SparseBLAS can handle inputs with mixed complexity, + // however, the vendor implementations need all inputs of the same + // complexity, hence, this example also insists on having matching + // complexity. 
+ if (mxIsComplex(prhs[0]) != mxIsComplex(prhs[1]) || + ((nrhs == 3) && mxIsComplex(prhs[1]) != mxIsComplex(prhs[2]))) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:ComplexityMismatch", + "All inputs must have matching complexity."); + } + + // Gather dimensions + mwIndex m = mxGetM(prhs[0]); + mwIndex k = mxGetN(prhs[0]); + mwIndex n = mxGetN(prhs[1]); + + // Check dimensions of second input + if (mxGetM(prhs[1]) != k) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:InnerDimWrong", + "Second input must be an array with k rows, " + "i.e., number of columns of first input."); + } + + // Calculate in complex (we check for matching complexity above) + bool isCmplx = mxIsComplex(prhs[0]); + + // Type dispatch for double or single, each as real or complex flavor + if (mxIsDouble(prhs[0])) { + if (isCmplx) { + plhs[0] = mxCreateNumericMatrix(m, n, mxDOUBLE_CLASS, mxCOMPLEX); + spmmDriver>(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } else { + plhs[0] = mxCreateNumericMatrix(m, n, mxDOUBLE_CLASS, mxREAL); + spmmDriver(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } + } else { + mxAssert(mxIsSingle(prhs[0]), "Invalid data type"); + if (isCmplx) { + plhs[0] = mxCreateNumericMatrix(m, n, mxSINGLE_CLASS, mxCOMPLEX); + spmmDriver>(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } else { + plhs[0] = mxCreateNumericMatrix(m, n, mxSINGLE_CLASS, mxREAL); + spmmDriver(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? 
prhs[2] : nullptr); + } + } +} + +// Compile from within MATLAB via: +// mex simple_spmm_mex.cpp -R2018a -I{PATH_TO_SparseBLAS_INCLUDE} +// 'CXXFLAGS=$CFLAGS -std=c++20' +// +// Add '-g' to build in Debug mode if needed (activates asserts) diff --git a/examples/MATLAB_MEX/simple_spmm_mex.m b/examples/MATLAB_MEX/simple_spmm_mex.m new file mode 100644 index 0000000..3a54059 --- /dev/null +++ b/examples/MATLAB_MEX/simple_spmm_mex.m @@ -0,0 +1,18 @@ +% simple_spmm_mex - Sparse matrix times dense matrix multiplication +% simple_spmm_mex.cpp - example in MATLAB External Interfaces +% +% Multiplies a (potentially scaled by a scalar alpha) sparse MxK matrix +% with a dense KxN matrix and outputs a dense MxN matrix: +% Y = A * X or Y = alpha * A * X +% +% The calling syntaxes are: +% Y = simple_spmm_mex(A, X) +% Y = simple_spmm_mex(A, X, alpha) +% +% The following restrictions apply: +% * A must be sparse +% * X must be dense +% * Number of columns in A and rows in X must match +% * All inputs must have the same data type and complexity +% +% This is a MEX-file for MATLAB. 
diff --git a/examples/MATLAB_MEX/simple_spmv_mex.cpp b/examples/MATLAB_MEX/simple_spmv_mex.cpp new file mode 100644 index 0000000..0b64a2c --- /dev/null +++ b/examples/MATLAB_MEX/simple_spmv_mex.cpp @@ -0,0 +1,131 @@ +// Includes from SparseBLAS +#include + +// Includes for MEX +#include +#include + +// General includes +#include // Support complex inputs + +template +void spmvDriver(mxArray* mxY, const mxArray* mxA, const mxArray* mxX, + const mxArray* mxAlpha) { + + // Gather dimensions + mwIndex m = mxGetM(mxA); + mwIndex n = mxGetN(mxA); + + // Fill csc_view with: + // - T* values + // - mwIndex* colptr + // - mwIndex* rowind + // - {mwIndex m, mwIndex n} (shape) + // - mwIndex nnz + spblas::csc_view A(static_cast(mxGetData(mxA)), + mxGetJc(mxA), mxGetIr(mxA), {m, n}, + mxGetJc(mxA)[n]); + // Wrap x in a span of length n + std::span x(static_cast(mxGetData(mxX)), n); + + // Wrap output y in a span of length m + std::span y(static_cast(mxGetData(mxY)), m); + + // Store and apply scaling factor alpha, if provided and not empty + T alpha = T(1); + if (mxAlpha != nullptr && !mxIsEmpty(mxAlpha)) { + // We don't use mxGetScalar as it doesn't work for complex + alpha = *(static_cast(mxGetData(mxAlpha))); + } + auto alpha_A = spblas::scaled(alpha, A); + + // y = (alpha * A) * x + spblas::multiply(alpha_A, x, y); +} + +// y = (alpha *) A * x +// prhs[0] = A, prhs[1] = x (optional: prhs[2] = alpha) +// plhs[0] = y +void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { + + // General input checking + if (nrhs < 2 || nrhs > 3) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfInputs", + "Function needs 2 or 3 inputs."); + } + if (nlhs > 1) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfOutputs", + "Function returns only 1 output."); + } + if (mxGetClassID(prhs[0]) != mxGetClassID(prhs[1]) || + ((nrhs == 3) && mxGetClassID(prhs[1]) != mxGetClassID(prhs[2]))) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:ClassMismatch", + "All inputs must have matching 
type."); + } + if (!mxIsDouble(prhs[0]) && !mxIsSingle(prhs[0])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:NonFloat", + "All inputs must be single or double."); + } + if (!mxIsSparse(prhs[0])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:FirstInputNotSparse", + "First input must be sparse."); + } + if (mxIsSparse(prhs[1])) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:SecondInputNotDense", + "Second input must be dense."); + } + + // Reference SparseBLAS can handle inputs with mixed complexity, + // however, the vendor implementations need all inputs of the same + // complexity, hence, this example also insists on having matching + // complexity. + if (mxIsComplex(prhs[0]) != mxIsComplex(prhs[1]) || + ((nrhs == 3) && mxIsComplex(prhs[1]) != mxIsComplex(prhs[2]))) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:ComplexityMismatch", + "All inputs must have matching complexity."); + } + + // Gather dimensions + mwIndex m = mxGetM(prhs[0]); + mwIndex n = mxGetN(prhs[0]); + // Check dimensions of second input: it has to be n-by-1, so reject it as + // soon as either the row count or the column count is off + if ((mxGetM(prhs[1]) != n) || (mxGetN(prhs[1]) != 1)) { + mexErrMsgIdAndTxt("SparseBLAS_Mex:InnerDimWrong", + "Second input must be column vector of length n, " + "i.e., number of columns of first input."); + } + + // Calculate in complex (we check for matching complexity above) + bool isCmplx = mxIsComplex(prhs[0]); + + // Type dispatch for double or single, each as real or complex flavor + if (mxIsDouble(prhs[0])) { + if (isCmplx) { + plhs[0] = mxCreateNumericMatrix(m, 1, mxDOUBLE_CLASS, mxCOMPLEX); + spmvDriver>(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } else { + plhs[0] = mxCreateNumericMatrix(m, 1, mxDOUBLE_CLASS, mxREAL); + spmvDriver(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? 
prhs[2] : nullptr); + } + } else { + mxAssert(mxIsSingle(prhs[0]), "Invalid data type"); + if (isCmplx) { + plhs[0] = mxCreateNumericMatrix(m, 1, mxSINGLE_CLASS, mxCOMPLEX); + spmvDriver>(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } else { + plhs[0] = mxCreateNumericMatrix(m, 1, mxSINGLE_CLASS, mxREAL); + spmvDriver(plhs[0], prhs[0], prhs[1], + nrhs == 3 ? prhs[2] : nullptr); + } + } +} + +// Compile from within MATLAB via: +// mex simple_spmv_mex.cpp -R2018a -I{PATH_TO_SparseBLAS_INCLUDE} +// 'CXXFLAGS=$CFLAGS -std=c++20' +// +// Add '-g' to build in Debug mode if needed (activates asserts) diff --git a/examples/MATLAB_MEX/simple_spmv_mex.m b/examples/MATLAB_MEX/simple_spmv_mex.m new file mode 100644 index 0000000..09e2196 --- /dev/null +++ b/examples/MATLAB_MEX/simple_spmv_mex.m @@ -0,0 +1,18 @@ +% simple_spmv_mex - Sparse matrix times dense vector multiplication +% simple_spmv_mex.cpp - example in MATLAB External Interfaces +% +% Multiplies a (potentially scaled by a scalar alpha) sparse MxN matrix +% with a dense Nx1 column vector and outputs a dense Mx1 column vector: +% y = A * x or y = alpha * A * x +% +% The calling syntaxes are: +% y = simple_spmv_mex(A, x) +% y = simple_spmv_mex(A, x, alpha) +% +% The following restrictions apply: +% * A must be sparse +% * x must be dense column vector +% * Number of columns in A and rows in x must match +% * All inputs must have the same data type and complexity +% +% This is a MEX-file for MATLAB. 
diff --git a/examples/MATLAB_MEX/testSparseBLASMexAPIs.m b/examples/MATLAB_MEX/testSparseBLASMexAPIs.m new file mode 100644 index 0000000..4e44f6e --- /dev/null +++ b/examples/MATLAB_MEX/testSparseBLASMexAPIs.m @@ -0,0 +1,155 @@ +classdef testSparseBLASMexAPIs < matlab.unittest.TestCase +% testSparseBLASMexAPIs Tests for SparseBLAS MEX APIs + + methods(TestMethodSetup) + function initializeRNG(~) + rng(0,'twister'); + end + end + + properties(TestParameter) + % Loop over different sizes + sizesToTest = struct(... + 'empty1', 0, ... + 'empty2', 1, ... + 'tiny1', 2, ... + 'tiny2', 5, ... + 'small1' , 1e1, ... + 'small2', 5e1, ... + 'medium1', 1e2, ... + 'medium2', 5e2, ... + 'large1', 1e3, ... + 'large2', 5e3); + % Loop over data types and complexities + complexity = {'real', 'complex'}; + datatypes = {'double', 'single'}; + % Loop over various shapes of sparse and dense inputs + shape = {'square', 'tall', 'wide'}; + numberOfRHS = struct(... + 'singleColumn', 1, ... + 'doubleColumn', 2, ... + 'manyColumns', 10, ... + 'veryManyColumns', 100); + % Loop over different scalar scalings + alpha = struct(... + 'none', [], ... + 'neutral', 1.0, ... + 'upScale', 2.3, ... + 'upScaleNeg', -4.3, ... + 'downScale', 0.23, ... + 'downScaleNeg', -0.43); + end + + methods (Test) + % Test each API + function simpleSPMV(testCase, sizesToTest, datatypes, complexity, shape, alpha) + %% Create data + nRhs = 1; + [A, x] = createData(sizesToTest, nRhs, datatypes, complexity, shape); + + %% Calculate reference solution, adapted to MATLAB's special + % case treatment + if isempty(alpha) + y_exp = A*x; + else + alpha = cast(alpha, datatypes); + if strcmp(complexity, 'complex') + alpha = complex(alpha, alpha); + end + y_exp = alpha*A*x; + end + + % If either input to '*' is scalar, MATLAB calls '.*' which + % returns sparse results for dense and sparse mixed inputs, + % hence, making the result dense as SparseBLAS doesn't special + % case these situations. 
+ if isscalar(A) || isscalar(x) + y_exp = full(y_exp); + end + + % MATLAB strips all-zero imaginary parts during '*'. SparseBLAS + % does not, hence, make results complex again if complexity is + % set to 'complex'. + if strcmp(complexity, 'complex') && isreal(y_exp) + y_exp = complex(y_exp); + end + + %% Calculate solution via SparseBLAS MEX APIs + if isempty(alpha) + y_act = simple_spmv_mex(A, x); + else + y_act = simple_spmv_mex(A, x, alpha); + end + + %% Verify results + testCase.verifyEqual(y_act, y_exp); + end + + function simpleSPMM(testCase, sizesToTest, numberOfRHS, datatypes, complexity, shape, alpha) + %% Create data + [A, X] = createData(sizesToTest, numberOfRHS, datatypes, complexity, shape); + + %% Calculate reference solution, adapted to MATLAB's special + % case treatment + if isempty(alpha) + y_exp = A*X; + else + alpha = cast(alpha, datatypes); + if strcmp(complexity, 'complex') + alpha = complex(alpha, alpha); + end + y_exp = alpha*A*X; + end + + % If either input to '*' is scalar, MATLAB calls '.*' which + % returns sparse results for dense and sparse mixed inputs, + % hence, making the result dense as SparseBLAS doesn't special + % case these situations. + if isscalar(A) || isscalar(X) + y_exp = full(y_exp); + end + + % MATLAB strips all-zero imaginary parts during '*'. SparseBLAS + % does not, hence, make results complex again if complexity is + % set to 'complex'. 
+ if strcmp(complexity, 'complex') && isreal(y_exp) + y_exp = complex(y_exp); + end + + %% Calculate solution via SparseBLAS MEX APIs + if isempty(alpha) + y_act = simple_spmm_mex(A, X); + else + y_act = simple_spmm_mex(A, X, alpha); + end + + %% Verify results + testCase.verifyEqual(y_act, y_exp); + end + end + +end + +function [A, X] = createData(n, nRhs, datatypes, complexity, shape) +% We use this routine to create sparse A and dense X +lesser_n = floor(n/2); +switch shape + case 'wide' + sz = [lesser_n, n]; + case 'tall' + sz = [n, lesser_n]; + case 'square' + sz = [n, n]; +end + +switch complexity + case 'real' + A = sprand(sz(1), sz(2), 0.01, datatypes); + X = rand(sz(2), nRhs, datatypes); + case 'complex' + A = complex(sprand(sz(1), sz(2), 0.01, datatypes), ... + sprand(sz(1), sz(2), 0.01, datatypes)); + X = complex(rand(sz(2), nRhs, datatypes), ... + rand(sz(2), nRhs, datatypes)); +end +end diff --git a/examples/matrix_opt_example.cpp b/examples/matrix_opt_example.cpp index 0585c93..2bf6d92 100644 --- a/examples/matrix_opt_example.cpp +++ b/examples/matrix_opt_example.cpp @@ -5,7 +5,6 @@ int main(int argc, char** argv) { using namespace spblas; - namespace md = spblas::__mdspan; using T = float; @@ -40,8 +39,8 @@ int main(int argc, char** argv) { std::vector x_values(k * n, 1); std::vector y_values(m * n, 0); - md::mdspan x(x_values.data(), k, n); - md::mdspan y(y_values.data(), m, n); + mdspan_row_major x(x_values.data(), k, n); + mdspan_row_major y(y_values.data(), m, n); auto a_view = scaled(2.f, a); diff --git a/examples/simple_spmm.cpp b/examples/simple_spmm.cpp index 9440e6c..436ae29 100644 --- a/examples/simple_spmm.cpp +++ b/examples/simple_spmm.cpp @@ -5,7 +5,6 @@ int main(int argc, char** argv) { using namespace spblas; - namespace md = spblas::__mdspan; using T = float; @@ -38,8 +37,8 @@ int main(int argc, char** argv) { std::vector x_values(k * n, 1); std::vector y_values(m * n, 0); - md::mdspan x(x_values.data(), k, n); - md::mdspan 
y(y_values.data(), m, n); + mdspan_row_major x(x_values.data(), k, n); + mdspan_row_major y(y_values.data(), m, n); auto a_view = scaled(2.f, a); diff --git a/examples/spmm_csc.cpp b/examples/spmm_csc.cpp index 6b7a69c..1e616e9 100644 --- a/examples/spmm_csc.cpp +++ b/examples/spmm_csc.cpp @@ -6,7 +6,6 @@ int main(int argc, char** argv) { using namespace spblas; - namespace md = spblas::__mdspan; using T = float; @@ -39,8 +38,8 @@ int main(int argc, char** argv) { std::vector x_values(k * n, 1); std::vector y_values(m * n, 0); - md::mdspan x(x_values.data(), k, n); - md::mdspan y(y_values.data(), m, n); + mdspan_row_major x(x_values.data(), k, n); + mdspan_row_major y(y_values.data(), m, n); // y = A * (alpha * x) multiply(a, scaled(2.f, x), y); diff --git a/include/spblas/detail/mdspan.hpp b/include/spblas/detail/mdspan.hpp index 8ae10b0..c8717e8 100644 --- a/include/spblas/detail/mdspan.hpp +++ b/include/spblas/detail/mdspan.hpp @@ -26,3 +26,17 @@ static_assert(false, "spblas requires mdspan. Compile with a C++23 compiler " "or download the std/experimental implementation."); #endif + +namespace spblas { +// Define templated aliases for col_major (layout_left) and row_major +// (layout_right) mdspan types. 
+template +using mdspan_col_major = __mdspan::mdspan< + T, __mdspan::extents, + __mdspan::layout_left>; + +template +using mdspan_row_major = __mdspan::mdspan< + T, __mdspan::extents, + __mdspan::layout_right>; +} // namespace spblas diff --git a/test/gtest/CMakeLists.txt b/test/gtest/CMakeLists.txt index 8d308e6..e59c42d 100644 --- a/test/gtest/CMakeLists.txt +++ b/test/gtest/CMakeLists.txt @@ -11,7 +11,8 @@ if (SPBLAS_CPU_BACKEND) spgemm_csr_csc.cpp add_test.cpp transpose_test.cpp - triangular_solve_test.cpp) + triangular_solve_test.cpp + mdspan_overlays.cpp) if (ENABLE_ONEMKL_SYCL OR SPBLAS_REFERENCE_BACKEND) list(APPEND TEST_SOURCES conjugate_test.cpp) diff --git a/test/gtest/conjugate_test.cpp b/test/gtest/conjugate_test.cpp index cd6ecd7..bde55f5 100644 --- a/test/gtest/conjugate_test.cpp +++ b/test/gtest/conjugate_test.cpp @@ -91,8 +91,6 @@ TEST(Conjugate, SpMV_VectorConjugated) { } TEST(Conjugate, SpMM_MatrixConjugated) { - namespace md = spblas::__mdspan; - for (auto&& [m, k, nnz] : util::dims) { for (auto n : {1, 8, 32}) { auto [values, rowptr, colind, shape, _] = @@ -102,8 +100,8 @@ TEST(Conjugate, SpMM_MatrixConjugated) { auto [b_values, b_shape] = spblas::generate_dense(k, n); std::vector c_values(m * n, T(0.0f, 0.0f)); - md::mdspan b(b_values.data(), k, n); - md::mdspan c(c_values.data(), m, n); + spblas::mdspan_row_major b(b_values.data(), k, n); + spblas::mdspan_row_major c(c_values.data(), m, n); spblas::multiply(spblas::conjugated(a), b, c); @@ -126,8 +124,6 @@ TEST(Conjugate, SpMM_MatrixConjugated) { } TEST(Conjugate, SpMM_DenseConjugated) { - namespace md = spblas::__mdspan; - for (auto&& [m, k, nnz] : util::dims) { for (auto n : {1, 8, 32}) { auto [values, rowptr, colind, shape, _] = @@ -137,8 +133,8 @@ TEST(Conjugate, SpMM_DenseConjugated) { auto [b_values, b_shape] = spblas::generate_dense(k, n); std::vector c_values(m * n, T(0.0f, 0.0f)); - md::mdspan b(b_values.data(), k, n); - md::mdspan c(c_values.data(), m, n); + spblas::mdspan_row_major 
b(b_values.data(), k, n); + spblas::mdspan_row_major c(c_values.data(), m, n); spblas::multiply(a, spblas::conjugated(b), c); @@ -238,8 +234,6 @@ TEST(Conjugate, SpMV_MatrixConjugated) { } TEST(Conjugate, SpMM_MatrixConjugated) { - namespace md = spblas::__mdspan; - for (auto&& [m, k, nnz] : util::dims) { for (auto n : {1, 8, 32}) { auto [values, colptr, rowind, shape, _] = @@ -249,8 +243,8 @@ TEST(Conjugate, SpMM_MatrixConjugated) { auto [b_values, b_shape] = spblas::generate_dense(k, n); std::vector c_values(m * n, T(0.0f, 0.0f)); - md::mdspan b(b_values.data(), k, n); - md::mdspan c(c_values.data(), m, n); + spblas::mdspan_row_major b(b_values.data(), k, n); + spblas::mdspan_row_major c(c_values.data(), m, n); spblas::multiply(spblas::conjugated(a), b, c); diff --git a/test/gtest/mdspan_overlays.cpp b/test/gtest/mdspan_overlays.cpp new file mode 100644 index 0000000..7581c4e --- /dev/null +++ b/test/gtest/mdspan_overlays.cpp @@ -0,0 +1,104 @@ +#include + +#include "util.hpp" +#include + +// Accessing the data inside mdspan differs between different mdspan +// implementations. The portable way is quite heavy and the following helper +// makes the tests themselves easier to read. 
+template +decltype(auto) md_at(T& m, typename T::index_type i, typename T::index_type j) { +#if defined(__cpp_multidimensional_subscript) + return m[i, j]; +#else + return m(i, j); +#endif +} + +TEST(Mdspan, positive_row_major) { + using T = float; + using I = spblas::index_t; + + for (auto m : {1, 2, 4}) { + for (auto n : {1, 2, 4}) { + auto [b_values, b_shape] = spblas::generate_dense(m, n); + spblas::mdspan_row_major b(b_values.data(), m, n); + + // Traverse by row in inner loop to imitate row-major + T* tmp = b_values.data(); + for (I i = 0; i < m; ++i) { + for (I j = 0; j < n; ++j) { + EXPECT_EQ(md_at(b, i, j), *(tmp++)); + } + } + } + } +} + +TEST(Mdspan, postive_col_major) { + using T = double; + using I = spblas::index_t; + + for (auto m : {1, 2, 4}) { + for (auto n : {1, 2, 4}) { + auto [b_values, b_shape] = spblas::generate_dense(m, n); + spblas::mdspan_col_major b(b_values.data(), m, n); + + // Traverse by column in inner loop to imitate col-major + T* tmp = b_values.data(); + for (I j = 0; j < n; ++j) { + for (I i = 0; i < m; ++i) { + EXPECT_EQ(md_at(b, i, j), *(tmp++)); + } + } + } + } +} + +TEST(Mdspan, negative_row_major) { + using T = double; + using I = int32_t; + + for (auto [m, n] : {std::pair{2, 4}, std::pair{4, 2}}) { + + auto [b_values, b_shape] = spblas::generate_dense(m, n); + spblas::mdspan_row_major b(b_values.data(), m, n); + + // Traverse by column in inner loop to not imitate row-major + T* tmp = b_values.data(); + for (I j = 0; j < n; ++j) { + for (I i = 0; i < m; ++i) { + // Skip first and last element + if ((i == 0 && j == 0) || (i == m - 1 && j == n - 1)) { + tmp++; + continue; + } + EXPECT_NE(md_at(b, i, j), *(tmp++)); + } + } + } +} + +TEST(Mdspan, negative_col_major) { + using T = int32_t; + using I = int64_t; + + for (auto [m, n] : {std::pair{2, 4}, std::pair{4, 2}}) { + + auto [b_values, b_shape] = spblas::generate_dense(m, n); + spblas::mdspan_col_major b(b_values.data(), m, n); + + // Traverse by row in inner loop to not 
imitate col-major + T* tmp = b_values.data(); + for (I i = 0; i < m; ++i) { + for (I j = 0; j < n; ++j) { + // Skip first and last element + if ((i == 0 && j == 0) || (i == m - 1 && j == n - 1)) { + tmp++; + continue; + } + EXPECT_NE(md_at(b, i, j), *(tmp++)); + } + } + } +} diff --git a/test/gtest/spmm_test.cpp b/test/gtest/spmm_test.cpp index 554a192..118a4d0 100644 --- a/test/gtest/spmm_test.cpp +++ b/test/gtest/spmm_test.cpp @@ -4,8 +4,6 @@ #include TEST(CsrView, SpMM) { - namespace md = spblas::__mdspan; - using T = float; using I = spblas::index_t; @@ -20,8 +18,8 @@ TEST(CsrView, SpMM) { std::vector c_values(m * n, 0); - md::mdspan b(b_values.data(), k, n); - md::mdspan c(c_values.data(), m, n); + spblas::mdspan_row_major b(b_values.data(), k, n); + spblas::mdspan_row_major c(c_values.data(), m, n); spblas::multiply(a, b, c);