-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert vector of basic type to numpy array instead of list
-Port numpy_proxy class from nda -Add macros.hpp file
- Loading branch information
Showing
5 changed files
with
355 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/******************************************************************************* | ||
* | ||
* TRIQS: a Toolbox for Research in Interacting Quantum Systems | ||
* | ||
* Copyright (C) 2011 by O. Parcollet | ||
* | ||
* TRIQS is free software: you can redistribute it and/or modify it under the | ||
* terms of the GNU General Public License as published by the Free Software | ||
* Foundation, either version 3 of the License, or (at your option) any later | ||
* version. | ||
* | ||
* TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY | ||
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS | ||
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more | ||
* details. | ||
* | ||
* You should have received a copy of the GNU General Public License along with | ||
* TRIQS. If not, see <http://www.gnu.org/licenses/>. | ||
* | ||
******************************************************************************/ | ||
#pragma once | ||
|
||
#include <iostream> | ||
|
||
#define AS_STRING(...) AS_STRING2(__VA_ARGS__) | ||
#define AS_STRING2(...) #__VA_ARGS__ | ||
|
||
#ifdef __clang__ | ||
#define REQUIRES(...) __attribute__((enable_if(__VA_ARGS__, AS_STRING(__VA_ARGS__)))) | ||
#elif __GNUC__ | ||
#define REQUIRES(...) requires(__VA_ARGS__) | ||
#endif | ||
|
||
#define PRINT(X) std::cerr << AS_STRING(X) << " = " << X << " at " << __FILE__ << ":" << __LINE__ << '\n' | ||
|
||
#define FORCEINLINE __inline__ __attribute__((always_inline)) | ||
|
||
#define EXPECTS(X) \ | ||
if (!(X)) { \ | ||
std::cerr << "Precondition " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::terminate(); \ | ||
} | ||
#define ASSERT(X) \ | ||
if (!(X)) { \ | ||
std::cerr << "Assertion " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::terminate(); \ | ||
} | ||
#define ENSURES(X) \ | ||
if (!(X)) { \ | ||
std::cerr << "Postcondition " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::terminate(); \ | ||
} | ||
|
||
#define EXPECTS_WITH_MESSAGE(X, ...) \ | ||
if (!(X)) { \ | ||
std::cerr << "Precondition " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::cerr << "Error message : " << __VA_ARGS__ << std::endl; \ | ||
std::terminate(); \ | ||
} | ||
#define ASSERT_WITH_MESSAGE(X, ...) \ | ||
if (!(X)) { \ | ||
std::cerr << "Assertion " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::cerr << "Error message : " << __VA_ARGS__ << std::endl; \ | ||
std::terminate(); \ | ||
} | ||
#define ENSURES_WITH_MESSAGE(X, ...) \ | ||
if (!(X)) { \ | ||
std::cerr << "Postcondition " << AS_STRING(X) << " violated at " << __FILE__ << ":" << __LINE__ << "\n"; \ | ||
std::cerr << "Error message : " << __VA_ARGS__ << std::endl; \ | ||
std::terminate(); \ | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#include "numpy_proxy.hpp" | ||
|
||
namespace cpp2py { | ||
|
||
// Make a new view_info | ||
PyObject *numpy_proxy::to_python() { | ||
|
||
// Apparently we can not get rid of this | ||
_import_array(); | ||
|
||
#ifdef PYTHON_NUMPY_VERSION_LT_17 | ||
int flags = NPY_BEHAVED & ~NPY_OWNDATA; | ||
#else | ||
int flags = NPY_ARRAY_BEHAVED & ~NPY_ARRAY_OWNDATA; | ||
#endif | ||
// make the array read only | ||
if (is_const) flags &= ~NPY_ARRAY_WRITEABLE; | ||
PyObject *result = | ||
PyArray_NewFromDescr(&PyArray_Type, PyArray_DescrFromType(element_type), rank, extents.data(), strides.data(), data, flags, NULL); | ||
if (not result) return nullptr; // the Python error is set | ||
|
||
if (!PyArray_Check(result)) { | ||
PyErr_SetString(PyExc_RuntimeError, "The python object is not a numpy array"); | ||
return nullptr; | ||
} | ||
|
||
PyArrayObject *arr = (PyArrayObject *)(result); | ||
#ifdef PYTHON_NUMPY_VERSION_LT_17 | ||
arr->base = base; | ||
assert(arr->flags == (arr->flags & ~NPY_OWNDATA)); | ||
#else | ||
int r = PyArray_SetBaseObject(arr, base); | ||
//EXPECTS(r == 0); | ||
//EXPECTS(PyArray_FLAGS(arr) == (PyArray_FLAGS(arr) & ~NPY_ARRAY_OWNDATA)); | ||
#endif | ||
base = nullptr; // ref is stolen by the new object | ||
|
||
return result; | ||
} | ||
|
||
// ---------------------------------------------------------- | ||
|
||
// Extract a view_info from python | ||
numpy_proxy make_numpy_proxy(PyObject *obj) { | ||
|
||
// Apparently we can not get rid of this | ||
_import_array(); | ||
|
||
if (obj == NULL) return {}; | ||
if (not PyArray_Check(obj)) return {}; | ||
|
||
numpy_proxy result; | ||
|
||
// extract strides and lengths | ||
PyArrayObject *arr = (PyArrayObject *)(obj); | ||
|
||
#ifdef PYTHON_NUMPY_VERSION_LT_17 | ||
result.rank = arr->nd; | ||
#else | ||
result.rank = PyArray_NDIM(arr); | ||
#endif | ||
|
||
result.element_type = PyArray_TYPE(arr); | ||
result.extents.resize(result.rank); | ||
result.strides.resize(result.rank); | ||
result.data = PyArray_DATA(arr); | ||
// base is ignored, stays at nullptr | ||
|
||
#ifdef PYTHON_NUMPY_VERSION_LT_17 | ||
for (long i = 0; i < result.rank; ++i) { | ||
result.extents[i] = size_t(arr->dimensions[i]); | ||
result.strides[i] = std::ptrdiff_t(arr->strides[i]); | ||
} | ||
#else | ||
for (size_t i = 0; i < result.rank; ++i) { | ||
result.extents[i] = size_t(PyArray_DIMS(arr)[i]); | ||
result.strides[i] = std::ptrdiff_t(PyArray_STRIDES(arr)[i]); | ||
} | ||
#endif | ||
|
||
//PRINT(result.rank); | ||
//PRINT(result.element_type); | ||
//PRINT(result.data); | ||
|
||
return result; | ||
} | ||
|
||
// ---------------------------------------------------------- | ||
|
||
PyObject *make_numpy_copy(PyObject *obj, int rank, long element_type) { | ||
|
||
if (obj == nullptr) return nullptr; | ||
|
||
// From obj, we ask the numpy library to make a numpy, and of the correct type. | ||
// This handles automatically the cases where : | ||
// - we have list, or list of list/tuple | ||
// - the numpy type is not the one we want. | ||
// - adjust the dimension if needed | ||
// If obj is an array : | ||
// - if Order is same, don't change it | ||
// - else impose it (may provoque a copy). | ||
// if obj is not array : | ||
// - Order = FortranOrder or SameOrder - > Fortran order otherwise C | ||
|
||
int flags = 0; //(ForceCast ? NPY_FORCECAST : 0) ;// do NOT force a copy | (make_copy ? NPY_ENSURECOPY : 0); | ||
//if (!(PyArray_Check(obj) )) | ||
//flags |= ( IndexMapType::traversal_order == indexmaps::mem_layout::c_order(rank) ? NPY_C_CONTIGUOUS : NPY_F_CONTIGUOUS); //impose mem order | ||
#ifdef PYTHON_NUMPY_VERSION_LT_17 | ||
flags |= (NPY_C_CONTIGUOUS); //impose mem order | ||
flags |= (NPY_ENSURECOPY); | ||
#else | ||
flags |= (NPY_ARRAY_C_CONTIGUOUS); // impose mem order | ||
flags |= (NPY_ARRAY_ENSURECOPY); | ||
#endif | ||
return PyArray_FromAny(obj, PyArray_DescrFromType(element_type), rank, rank, flags, NULL); // new ref | ||
} | ||
|
||
} // namespace nda::python |
Oops, something went wrong.