Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,33 @@ jobs:
run: |
./build/test/gtest/spblas-tests

MATLAB_mex_checks:
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v4
- name: CMake Build and Install
run: |
cmake -B build -D CMAKE_INSTALL_PREFIX=install
make -C build -j `nproc`
cmake --install build/.
- name: Set up MATLAB
uses: matlab-actions/setup-matlab@v2.5.0
with:
release: R2025b
- name: Build MEX APIs
uses: matlab-actions/run-command@v2.2.1
with:
command: |
cd("examples/MATLAB_MEX");
build_mex_APIs("../../install/include")
- name: Run MEX API Tests
uses: matlab-actions/run-command@v2.2.1
with:
command: |
cd("examples/MATLAB_MEX");
results = runtests("testSparseBLASMexAPIs", "Verbosity", 0)
assertSuccess(results);

aocl:
runs-on: 'cpu_amd'
steps:
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ repos:
rev: v16.0.6
hooks:
- id: clang-format
types_or: [c, c++, cuda, inc]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand Down
29 changes: 29 additions & 0 deletions examples/MATLAB_MEX/build_mex_APIs.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
function build_mex_APIs(incl_path, verbose, debug)
%BUILD_MEX_APIS - Function to build all available mex APIs
%
% First input must be the path to the SparseBLAS INCLUDE folder.
% Second and third are optional logical inputs to activate VERBOSE or and
% DEBUG mode.
%
% The calling syntaxes are:
% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE")
% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE", true)
% build_mex_APIs("PATH_TO_SparseBLAS_INCLUDE", false, true)

% Set default options
opts = {['-I' char(incl_path)], "-O", "-R2018a", "CXXFLAGS=$CFLAGS -std=c++23"};

% Parse optional VERBOSE option
if nargin > 1 && verbose
opts = [opts, "-v"];
end

% Parse optional DEBUG option
if nargin > 2 && debug
opts = [opts, "-g"];
end

% Compile all APIs
mex("simple_spmv_mex.cpp", opts{:});
mex("simple_spmm_mex.cpp", opts{:});
end
133 changes: 133 additions & 0 deletions examples/MATLAB_MEX/simple_spmm_mex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Includes from SparseBLAS
#include <spblas/spblas.hpp>

// Includes for MEX
#include <matrix.h>
#include <mex.h>

// General includes
#include <complex> // Support complex inputs

template <typename T>
void spmmDriver(mxArray* mxY, const mxArray* mxA, const mxArray* mxX,
const mxArray* mxAlpha) {
// Gather dimensions
mwIndex m = mxGetM(mxA);
mwIndex k = mxGetN(mxA);
mwIndex n = mxGetN(mxX);

// Fill csc_view with:
// - T* values
// - mwIndex* colptr
// - mwIndex* rowind
// - {mwIndex m, mwIndex k} (shape)
// - mwIndex nnz
spblas::csc_view<const T, mwIndex> A(static_cast<const T*>(mxGetData(mxA)),
mxGetJc(mxA), mxGetIr(mxA), {m, k},
mxGetJc(mxA)[k]);

// Wrap X in an mdspan of size k-by-n
spblas::mdspan_col_major<T, mwIndex> X(static_cast<T*>(mxGetData(mxX)), k, n);

// Wrap output Y in an mdspan of size m-by-n
spblas::mdspan_col_major<T, mwIndex> Y(static_cast<T*>(mxGetData(mxY)), m, n);

// Store and apply scaling factor alpha, if provided and not empty
T alpha = T(1);
if (mxAlpha != nullptr && !mxIsEmpty(mxAlpha)) {
// We don't use mxGetScalar as it doesn't work for complex
alpha = *(static_cast<T*>(mxGetData(mxAlpha)));
}
auto alpha_A = spblas::scaled(alpha, A);

// Y = (alpha * A) * X
spblas::multiply(alpha_A, X, Y);
}

// Y = (alpha *) A * X
// prhs[0] = A, prhs[1] = X (optional: prhs[2] = alpha)
// plhs[0] = Y
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {

// General input checking
if (nrhs < 2 || nrhs > 3) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfInputs",
"Function needs 2 or 3 inputs.");
}
if (nlhs > 1) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfOutputs",
"Function returns only 1 output.");
}
if (mxGetClassID(prhs[0]) != mxGetClassID(prhs[1]) ||
((nrhs == 3) && mxGetClassID(prhs[1]) != mxGetClassID(prhs[2]))) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:ClassMismatch",
"All inputs must have matching type.");
}
if (!mxIsDouble(prhs[0]) && !mxIsSingle(prhs[0])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:NonFloat",
"All inputs must be single or double.");
}
if (!mxIsSparse(prhs[0])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:FirstInputNotSparse",
"First input must be sparse.");
}
if (mxIsSparse(prhs[1])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:SecondInputNotDense",
"Second input must be dense.");
}

// Reference SparseBLAS can handle inputs with mixed complexity,
// however, the vendor implementations need all inputs of the same
// complexity, hence, this example also insists on having matching
// complexity.
if (mxIsComplex(prhs[0]) != mxIsComplex(prhs[1]) ||
((nrhs == 3) && mxIsComplex(prhs[1]) != mxIsComplex(prhs[2]))) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:ComplexityMismatch",
"All inputs must have matching complexity.");
}

// Gather dimensions
mwIndex m = mxGetM(prhs[0]);
mwIndex k = mxGetN(prhs[0]);
mwIndex n = mxGetN(prhs[1]);

// Check dimensions of second input
if (mxGetM(prhs[1]) != k) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:InnerDimWrong",
"Second input must be an array with k rows, "
"i.e., number of columns of first input.");
}

// Calculate in complex (we check for matching complexity above)
bool isCmplx = mxIsComplex(prhs[0]);

// Type dispatch for double or single, each as real or complex flavor
if (mxIsDouble(prhs[0])) {
if (isCmplx) {
plhs[0] = mxCreateNumericMatrix(m, n, mxDOUBLE_CLASS, mxCOMPLEX);
spmmDriver<std::complex<double>>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
} else {
plhs[0] = mxCreateNumericMatrix(m, n, mxDOUBLE_CLASS, mxREAL);
spmmDriver<double>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
}
} else {
mxAssert(mxIsSingle(prhs[0]), "Invalid data type");
if (isCmplx) {
plhs[0] = mxCreateNumericMatrix(m, n, mxSINGLE_CLASS, mxCOMPLEX);
spmmDriver<std::complex<float>>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
} else {
plhs[0] = mxCreateNumericMatrix(m, n, mxSINGLE_CLASS, mxREAL);
spmmDriver<float>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
}
}
}

// Compile from within MATLAB via:
// mex simple_spmm_mex.cpp -R2018a -I{PATH_TO_SparseBLAS_INCLUDE}
// 'CXXFLAGS=$CFLAGS -std=c++20'
//
// Add '-g' to build in Debug mode if needed (activates asserts)
18 changes: 18 additions & 0 deletions examples/MATLAB_MEX/simple_spmm_mex.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
% simple_spmm_mex - Sparse matrix times dense matrix multiplication
% simple_smpm_mex.c - example in MATLAB External Interfaces
%
% Multiplies a (potentially scaled by a scalar alpha) sparse MxK matrix
% with a dense KxN matrix and outputs a dense MxN matrix:
% Y = A * X or Y = alpha * A * X
%
% The calling syntaxes are:
% Y = simple_smpv_mex(A, X)
% Y = simple_smpv_mex(A, X, alpha)
%
% The following restrictions apply:
% * A must be sparse
% * X must be dense
% * Number of columns in A and rows in X must match
% * All inputs must have the same data type and complexity
%
% This is a MEX-file for MATLAB.
134 changes: 134 additions & 0 deletions examples/MATLAB_MEX/simple_spmv_mex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Includes from SparseBLAS
#include <spblas/spblas.hpp>

// Includes for MEX
#include <matrix.h>
#include <mex.h>

// General includes
#include <complex> // Support complex inputs

template <typename T>
void spmvDriver(mxArray* mxY, const mxArray* mxA, const mxArray* mxX,
const mxArray* mxAlpha) {

// Gather dimensions
mwIndex m = mxGetM(mxA);
mwIndex n = mxGetN(mxA);

// Fill csc_view with:
// - T* values
// - mwIndex* colptr
// - mwIndex* rowind
// - {mwIndex m, mwIndex n} (shape)
// - mwIndex nnz
spblas::csc_view<const T, mwIndex> A(static_cast<const T*>(mxGetData(mxA)),
mxGetJc(mxA), mxGetIr(mxA), {m, n},
mxGetJc(mxA)[n]);
// Wrap x in a span of length n
std::span<const T> x(static_cast<const T*>(mxGetData(mxX)), n);

// Wrap output y in a span of length m
std::span<T> y(static_cast<T*>(mxGetData(mxY)), m);

// Store and apply scaling factor alpha, if provided and not empty
T alpha = T(1);
if (mxAlpha != nullptr && !mxIsEmpty(mxAlpha)) {
// We don't use mxGetScalar as it doesn't work for complex
alpha = *(static_cast<T*>(mxGetData(mxAlpha)));
}
auto alpha_A = spblas::scaled(alpha, A);

// y = (alpha * A) * x
spblas::multiply(alpha_A, x, y);
}

// y = (alpha *) A * x
// prhs[0] = A, prhs[1] = x (optional: prhs[2] = alpha)
// plhs[0] = y
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {

// General input checking
if (nrhs < 2 || nrhs > 3) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfInputs",
"Function needs 2 or 3 inputs.");
}
if (nlhs > 1) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:WrongNumberOfOutputs",
"Function returns only 1 output.");
}
if (mxGetClassID(prhs[0]) != mxGetClassID(prhs[1]) ||
((nrhs == 3) && mxGetClassID(prhs[1]) != mxGetClassID(prhs[2]))) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:ClassMismatch",
"All inputs must have matching type.");
}
if (!mxIsDouble(prhs[0]) && !mxIsSingle(prhs[0])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:NonFloat",
"All inputs must be single or double.");
}
if (!mxIsDouble(prhs[0]) && !mxIsSingle(prhs[0])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:NonFloat",
"All inputs must be single or double.");
}
if (!mxIsSparse(prhs[0])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:FirstInputNotSparse",
"First input must be sparse.");
}
if (mxIsSparse(prhs[1])) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:SecondInputNotDense",
"Second input must be dense.");
}

// Reference SparseBLAS can handle inputs with mixed complexity,
// however, the vendor implementations need all inputs of the same
// complexity, hence, this example also insists on having matching
// complexity.
if (mxIsComplex(prhs[0]) != mxIsComplex(prhs[1]) ||
((nrhs == 3) && mxIsComplex(prhs[1]) != mxIsComplex(prhs[2]))) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:ComplexityMismatch",
"All inputs must have matching complexity.");
}

// Gather dimensions
mwIndex m = mxGetM(prhs[0]);
mwIndex n = mxGetN(prhs[0]);
// Check dimensions of second input
if ((mxGetM(prhs[1]) != n) && (mxGetN(prhs[1]) != 1)) {
mexErrMsgIdAndTxt("SparseBLAS_Mex:InnerDimWrong",
"Second input must be column vector of length n, "
"i.e., number of columns of first input.");
}

// Calculate in complex (we check for matching complexity above)
bool isCmplx = mxIsComplex(prhs[0]);

// Type dispatch for double or single, each as real or complex flavor
if (mxIsDouble(prhs[0])) {
if (isCmplx) {
plhs[0] = mxCreateNumericMatrix(m, 1, mxDOUBLE_CLASS, mxCOMPLEX);
spmvDriver<std::complex<double>>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
} else {
plhs[0] = mxCreateNumericMatrix(m, 1, mxDOUBLE_CLASS, mxREAL);
spmvDriver<double>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
}
} else {
mxAssert(mxIsSingle(prhs[0]), "Invalid data type");
if (isCmplx) {
plhs[0] = mxCreateNumericMatrix(m, 1, mxSINGLE_CLASS, mxCOMPLEX);
spmvDriver<std::complex<float>>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
} else {
plhs[0] = mxCreateNumericMatrix(m, 1, mxSINGLE_CLASS, mxREAL);
spmvDriver<float>(plhs[0], prhs[0], prhs[1],
nrhs == 3 ? prhs[2] : nullptr);
}
}
}

// Compile from within MATLAB via:
// mex simple_spmv_mex.cpp -R2018a -I{PATH_TO_SparseBLAS_INCLUDE}
// 'CXXFLAGS=$CFLAGS -std=c++20'
//
// Add '-g' to build in Debug mode if needed (activates asserts)
18 changes: 18 additions & 0 deletions examples/MATLAB_MEX/simple_spmv_mex.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
% simple_spmv_mex - Sparse matrix times dense vector multiplication
% simple_smpv_mex.c - example in MATLAB External Interfaces
%
% Multiplies a (potentially scaled by a scalar alpha) sparse MxN matrix
% with a dense Nx1 column vector and outputs a dense Mx1 column vector:
% y = A * x or y = alpha * A * x
%
% The calling syntaxes are:
% y = simple_smpv_mex(A, x)
% y = simple_smpv_mex(A, x, alpha)
%
% The following restrictions apply:
% * A must be sparse
% * x must be dense column vector
% * Number of columns in A and rows in x must match
% * All inputs must have the same data type and complexity
%
% This is a MEX-file for MATLAB.
Loading
Loading