-
Notifications
You must be signed in to change notification settings - Fork 88
/
pybind_matrix.h
51 lines (37 loc) · 1.21 KB
/
pybind_matrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "matrix.h"
namespace py = pybind11;
// type caster: Matrix <-> NumPy-array
namespace pybind11 { namespace detail {

/// pybind11 type caster converting between NumPy arrays and Matrix<T>.
///
/// Python -> C++ (load): accepts an array with 1 to 3 dimensions and builds a
/// Matrix<T> from a C-contiguous view of its elements.
/// C++ -> Python (cast): returns a new NumPy array with the Matrix's shape,
/// strides and data.
template <typename T> struct type_caster<Matrix<T>>
{
public:
    // The descriptor is the type name pybind11 prints in generated signatures
    // and docstrings. A string literal is not template-substituted, so the
    // former "Matrix<T>" appeared verbatim; name the Python-side type instead.
    PYBIND11_TYPE_CASTER(Matrix<T>, _("numpy.ndarray"));

    /// Conversion part 1 (Python -> C++).
    /// @param src     Python object to convert.
    /// @param convert When false, only objects that already pass
    ///                py::array_t<T>::check_ (NumPy arrays of dtype T) are
    ///                accepted; when true, NumPy may convert/cast the input
    ///                (forcecast permits lossy dtype casts).
    /// @return true and sets `value` on success; false so pybind11 can try
    ///         the next overload (non-convertible input, or ndim outside 1..3).
    bool load(py::handle src, bool convert)
    {
        // Without implicit conversion, reject anything that is not already an array of T.
        if (!convert && !py::array_t<T>::check_(src))
            return false;

        // Obtain a C-contiguous array of T (copying/casting if necessary) so
        // buf.data() can be read as one contiguous row-major block.
        auto buf = py::array_t<T, py::array::c_style | py::array::forcecast>::ensure(src);
        if (!buf)
            return false;

        const auto ndim = buf.ndim();
        if (ndim < 1 || ndim > 3)
            return false;

        // NumPy reports extents as signed ssize_t; Matrix takes size_t.
        const auto* extents = buf.shape();
        std::vector<size_t> shape(extents, extents + ndim);

        // NOTE(review): buf is released when load() returns, so this presumably
        // relies on Matrix(shape, ptr) copying the data -- confirm in matrix.h.
        value = Matrix<T>(shape, buf.data());
        return true;
    }

    /// Conversion part 2 (C++ -> Python).
    /// Builds a new NumPy array from the Matrix's shape, strides and data
    /// pointer. No base handle is passed to the py::array constructor, so
    /// pybind11 copies the data and the result does not reference `src`; the
    /// return-value policy and parent are therefore not needed.
    /// @return an owned (released) handle to the new array.
    static py::handle cast(const Matrix<T>& src,
                           py::return_value_policy /*policy*/, py::handle /*parent*/)
    {
        // NOTE(review): strides(true) presumably yields byte strides, which is
        // what py::array expects -- confirm in matrix.h.
        py::array a(src.shape(), src.strides(true), src.data());
        return a.release();
    }
};

}}