Skip to content

Commit

Permalink
Python: Complex Support
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Sep 1, 2020
1 parent b1fe35c commit 6713bb3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
30 changes: 28 additions & 2 deletions include/openPMD/binding/python/Numpy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ namespace openPMD
return Datatype::ULONG;
else if( dt.is(pybind11::dtype("ulonglong")) )
return Datatype::ULONGLONG;
else if( dt.is(pybind11::dtype("clongdouble")) )
return Datatype::CLONG_DOUBLE;
else if( dt.is(pybind11::dtype("cdouble")) )
return Datatype::CDOUBLE;
else if( dt.is(pybind11::dtype("csingle")) )
return Datatype::CFLOAT;
else if( dt.is(pybind11::dtype("longdouble")) )
return Datatype::LONG_DOUBLE;
else if( dt.is(pybind11::dtype("double")) )
Expand All @@ -65,8 +71,10 @@ namespace openPMD
return Datatype::FLOAT;
else if( dt.is(pybind11::dtype("bool")) )
return Datatype::BOOL;
else
else {
pybind11::print(dt);
throw std::runtime_error("Datatype '...' not known in 'dtype_from_numpy'!"); // _s.format(dt)
}
}

/** Return openPMD::Datatype from py::buffer_info::format
Expand All @@ -79,7 +87,7 @@ namespace openPMD
// refs:
// https://docs.scipy.org/doc/numpy-1.15.0/reference/arrays.interface.html
// https://docs.python.org/3/library/struct.html#format-characters
// std::cout << " scalar type '" << buf.format << "'" << std::endl;
// std::cout << " scalar type '" << fmt << "'" << std::endl;
// typestring: encoding + type + number of bytes
if( fmt.find("?") != std::string::npos )
return DT::BOOL;
Expand All @@ -103,6 +111,12 @@ namespace openPMD
return DT::ULONG;
else if( fmt.find("Q") != std::string::npos )
return DT::ULONGLONG;
else if( fmt.find("Zf") != std::string::npos )
return DT::CFLOAT;
else if( fmt.find("Zd") != std::string::npos )
return DT::CDOUBLE;
else if( fmt.find("Zg") != std::string::npos )
return DT::CLONG_DOUBLE;
else if( fmt.find("f") != std::string::npos )
return DT::FLOAT;
else if( fmt.find("d") != std::string::npos )
Expand Down Expand Up @@ -179,6 +193,18 @@ namespace openPMD
case DT::VEC_LONG_DOUBLE:
return pybind11::dtype("longdouble");
break;
case DT::CFLOAT:
case DT::VEC_CFLOAT:
return pybind11::dtype("csingle");
break;
case DT::CDOUBLE:
case DT::VEC_CDOUBLE:
return pybind11::dtype("cdouble");
break;
case DT::CLONG_DOUBLE:
case DT::VEC_CLONG_DOUBLE:
return pybind11::dtype("clongdouble");
break;
case DT::BOOL:
return pybind11::dtype("bool"); // also "?"
break;
Expand Down
33 changes: 31 additions & 2 deletions src/binding/python/Attributable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
#include "openPMD/auxiliary/Variant.hpp"
#include "openPMD/binding/python/Numpy.hpp"

#include <array>
#include <complex>
#include <string>
#include <vector>
#include <array>


// std::variant
Expand Down Expand Up @@ -113,6 +114,15 @@ bool setAttributeFromBufferInfo(
case DT::LONG_DOUBLE:
return attr.setAttribute( key, *static_cast<long double*>(buf.ptr) );
break;
case DT::CFLOAT:
return attr.setAttribute( key, *static_cast<std::complex<float>*>(buf.ptr) );
break;
case DT::CDOUBLE:
return attr.setAttribute( key, *static_cast<std::complex<double>*>(buf.ptr) );
break;
case DT::CLONG_DOUBLE:
return attr.setAttribute( key, *static_cast<std::complex<long double>*>(buf.ptr) );
break;
default:
throw std::runtime_error("set_attribute: Unknown "
"Python type '" + buf.format +
Expand Down Expand Up @@ -155,6 +165,7 @@ bool setAttributeFromBufferInfo(
static_cast<bool*>(buf.ptr) + buf.size
) );
else */
// std::cout << "+++++++++++ BUFFER: " << buf.format << std::endl;
if( buf.format.find("b") != std::string::npos )
return attr.setAttribute( key,
std::vector<char>(
Expand Down Expand Up @@ -215,6 +226,24 @@ bool setAttributeFromBufferInfo(
static_cast<unsigned long long*>(buf.ptr),
static_cast<unsigned long long*>(buf.ptr) + buf.size
) );
else if( buf.format.find("Zf") != std::string::npos )
return attr.setAttribute( key,
std::vector<std::complex<float>>(
static_cast<std::complex<float>*>(buf.ptr),
static_cast<std::complex<float>*>(buf.ptr) + buf.size
) );
else if( buf.format.find("Zd") != std::string::npos )
return attr.setAttribute( key,
std::vector<std::complex<double>>(
static_cast<std::complex<double>*>(buf.ptr),
static_cast<std::complex<double>*>(buf.ptr) + buf.size
) );
else if( buf.format.find("Zg") != std::string::npos )
return attr.setAttribute( key,
std::vector<std::complex<long double>>(
static_cast<std::complex<long double>*>(buf.ptr),
static_cast<std::complex<long double>*>(buf.ptr) + buf.size
) );
else if( buf.format.find("f") != std::string::npos )
return attr.setAttribute( key,
std::vector<float>(
Expand Down Expand Up @@ -307,7 +336,7 @@ void init_Attributable(py::module &m) {
// .def("set_attribute", &Attributable::setAttribute< std::vector< char > >)
.def("set_attribute", &Attributable::setAttribute< std::vector< unsigned char > >)
.def("set_attribute", &Attributable::setAttribute< std::vector< long > >)
.def("set_attribute", &Attributable::setAttribute< std::vector< double > >)
.def("set_attribute", &Attributable::setAttribute< std::vector< double > >) // TODO: this implicitly casts list of complex
// probably affected by bug https://github.com/pybind/pybind11/issues/1258
.def("set_attribute", []( Attributable & attr, std::string const& key, std::vector< std::string > const& value ) {
return attr.setAttribute( key, value );
Expand Down
16 changes: 16 additions & 0 deletions src/binding/python/RecordComponent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "openPMD/auxiliary/ShareRaw.hpp"
#include "openPMD/binding/python/Numpy.hpp"

#include <complex>
#include <string>
#include <algorithm>
#include <tuple>
Expand Down Expand Up @@ -395,6 +396,12 @@ load_chunk(RecordComponent & r, Offset const & offset, Extent const & extent, st
r.loadChunk<double>(shareRaw((double*) a.mutable_data()), offset, extent);
else if( r.getDatatype() == Datatype::FLOAT )
r.loadChunk<float>(shareRaw((float*) a.mutable_data()), offset, extent);
else if( r.getDatatype() == Datatype::CLONG_DOUBLE )
r.loadChunk<std::complex<long double>>(shareRaw((std::complex<long double>*) a.mutable_data()), offset, extent);
else if( r.getDatatype() == Datatype::CDOUBLE )
r.loadChunk<std::complex<double>>(shareRaw((std::complex<double>*) a.mutable_data()), offset, extent);
else if( r.getDatatype() == Datatype::CFLOAT )
r.loadChunk<std::complex<float>>(shareRaw((std::complex<float>*) a.mutable_data()), offset, extent);
else if( r.getDatatype() == Datatype::BOOL )
r.loadChunk<bool>(shareRaw((bool*) a.mutable_data()), offset, extent);
else
Expand Down Expand Up @@ -508,6 +515,15 @@ void init_RecordComponent(py::module &m) {
case DT::LONG_DOUBLE:
return rc.makeConstant( *static_cast<long double*>(buf.ptr) );
break;
case DT::CFLOAT:
return rc.makeConstant( *static_cast<std::complex<float>*>(buf.ptr) );
break;
case DT::CDOUBLE:
return rc.makeConstant( *static_cast<std::complex<double>*>(buf.ptr) );
break;
case DT::CLONG_DOUBLE:
return rc.makeConstant( *static_cast<std::complex<long double>*>(buf.ptr) );
break;
default:
throw std::runtime_error("make_constant: "
"Unknown Datatype!");
Expand Down

0 comments on commit 6713bb3

Please sign in to comment.