-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathnumpy_mod.cpp
More file actions
108 lines (83 loc) · 3.49 KB
/
numpy_mod.cpp
File metadata and controls
108 lines (83 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

#include <cstddef>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <vector>
namespace py = pybind11;
/// Compute the dot product of two 1-D NumPy arrays of doubles.
///
/// @param a  first vector; any memory layout is accepted (strides may be
///           negative or non-unit, e.g. views such as `x[::-1]` or `x[::2]`)
/// @param b  second vector, must have the same length as `a`
/// @return   sum over i of a[i] * b[i]; 0.0 for empty vectors
/// @throws std::runtime_error if either argument is not 1-D or the lengths differ
double dot_product(py::array_t<double> a, py::array_t<double> b)
{
    py::buffer_info a_info = a.request();
    py::buffer_info b_info = b.request();

    // Check if both arguments are vectors of the same length
    if(a_info.ndim != 1)
        throw std::runtime_error("a is not a vector!");
    if(b_info.ndim != 1)
        throw std::runtime_error("b is not a vector!");
    if(a_info.shape[0] != b_info.shape[0])
        throw std::runtime_error("a and b are vectors of different lengths");

    // NumPy strides are *signed byte* offsets. They are negative for reversed
    // views and need not be a multiple of sizeof(double) (e.g. a field of a
    // structured array). Keep them in bytes and in a signed type: dividing by
    // sizeof(double) into size_t turned a negative stride into a huge
    // unsigned step, producing garbage indices and out-of-bounds reads.
    const char *a_ptr = static_cast<const char *>(a_info.ptr);
    const char *b_ptr = static_cast<const char *>(b_info.ptr);
    const std::ptrdiff_t a_step = static_cast<std::ptrdiff_t>(a_info.strides[0]);
    const std::ptrdiff_t b_step = static_cast<std::ptrdiff_t>(b_info.strides[0]);
    const std::ptrdiff_t len = static_cast<std::ptrdiff_t>(a_info.shape[0]);

    // memcpy keeps the read well-defined even if an element address is not
    // aligned to alignof(double); compilers lower it to a plain load.
    const auto load = [](const char *p) {
        double v;
        std::memcpy(&v, p, sizeof v);
        return v;
    };

    double dot_prod = 0.0;
    for(std::ptrdiff_t i = 0; i < len; ++i)
        dot_prod += load(a_ptr + i * a_step) * load(b_ptr + i * b_step);
    return dot_prod;
}
/// Compute C = alpha * (A @ B) for two 2-D NumPy arrays of doubles.
///
/// @param alpha  scalar applied to every entry of the product
/// @param a      left operand, shape (m, k); any memory layout is accepted
///               (transposed, sliced or reversed views work)
/// @param b      right operand, shape (k, n); any memory layout is accepted
/// @return       a new C-contiguous (row-major) array of shape (m, n)
/// @throws std::runtime_error if either argument is not 2-D or the inner
///         dimensions of `a` and `b` do not match
py::array_t<double> dgemm(double alpha, py::array_t<double> a, py::array_t<double> b)
{
    py::buffer_info a_info = a.request();
    py::buffer_info b_info = b.request();

    // Check if both arguments are matrices, and that the number
    // of columns of 'a' is the same as the number of rows of 'b'
    if(a_info.ndim != 2)
        throw std::runtime_error("a is not a matrix!");
    if(b_info.ndim != 2)
        throw std::runtime_error("b is not a matrix!");
    if(a_info.shape[1] != b_info.shape[0])
        throw std::runtime_error("incompatible matrix dimensions");

    const std::ptrdiff_t c_nrows = static_cast<std::ptrdiff_t>(a_info.shape[0]);
    const std::ptrdiff_t c_ncols = static_cast<std::ptrdiff_t>(b_info.shape[1]);
    const std::ptrdiff_t n_k = static_cast<std::ptrdiff_t>(a_info.shape[1]);

    // Data may not be stored in strict row-major order, so every element is
    // addressed through the buffer's strides. NumPy strides are *signed byte*
    // offsets (negative for reversed views, not necessarily a multiple of
    // sizeof(double)). Keep them in bytes and in a signed type; converting
    // them to an unsigned element count gave wrong indices and out-of-bounds
    // reads for views such as a[::-1].
    const char *a_ptr = static_cast<const char *>(a_info.ptr);
    const char *b_ptr = static_cast<const char *>(b_info.ptr);
    const std::ptrdiff_t a_step_row = static_cast<std::ptrdiff_t>(a_info.strides[0]);
    const std::ptrdiff_t a_step_col = static_cast<std::ptrdiff_t>(a_info.strides[1]);
    const std::ptrdiff_t b_step_row = static_cast<std::ptrdiff_t>(b_info.strides[0]);
    const std::ptrdiff_t b_step_col = static_cast<std::ptrdiff_t>(b_info.strides[1]);

    // memcpy keeps the read well-defined even if an element address is not
    // aligned to alignof(double); compilers lower it to a plain load.
    const auto load = [](const char *p) {
        double v;
        std::memcpy(&v, p, sizeof v);
        return v;
    };

    // Allocate the result directly as a C-contiguous (row-major) NumPy array
    // and fill it in place. Building a std::vector and wrapping it in a
    // buffer_info made py::array copy every element a second time.
    std::vector<std::size_t> c_shape{static_cast<std::size_t>(c_nrows),
                                     static_cast<std::size_t>(c_ncols)};
    py::array_t<double> c(c_shape);
    double *c_data = c.mutable_data();

    // perform the gemm; byte offsets are formed only for elements that exist,
    // so empty operands never produce out-of-range pointers
    for(std::ptrdiff_t i = 0; i < c_nrows; ++i)
        for(std::ptrdiff_t j = 0; j < c_ncols; ++j)
        {
            double value = 0.0;
            for(std::ptrdiff_t k = 0; k < n_k; ++k)
                value += load(a_ptr + (i * a_step_row + k * a_step_col))
                       * load(b_ptr + (k * b_step_row + j * b_step_col));
            c_data[i * c_ncols + j] = value * alpha;
        }
    return c;
}
// Python module entry point: exposes numpy_mod.dot_product(a, b) and
// numpy_mod.dgemm(alpha, a, b). PYBIND11_MODULE replaces the PYBIND11_PLUGIN
// macro and the py::module(name, doc) constructor, both deprecated since
// pybind11 2.2. Naming the arguments with py::arg keeps positional calls
// working and additionally allows keyword calls.
PYBIND11_MODULE(numpy_mod, m)
{
    m.doc() = "Ben's basic numpy module";
    m.def("dot_product", dot_product,
          "Dot product of two equal-length 1-D arrays (converted to float64).",
          py::arg("a"), py::arg("b"));
    m.def("dgemm", dgemm,
          "Return alpha * (a @ b) for 2-D arrays a (m, k) and b (k, n), "
          "converted to float64; the result is a new C-contiguous array.",
          py::arg("alpha"), py::arg("a"), py::arg("b"));
}