diff --git a/include/Type.hpp b/include/Type.hpp index 73bd92c0..72d25ba5 100644 --- a/include/Type.hpp +++ b/include/Type.hpp @@ -1,14 +1,15 @@ -#ifndef _H_TYPE_ -#define _H_TYPE_ +#ifndef INCLUDE_TYPE_H_ +#define INCLUDE_TYPE_H_ -#include #include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include + #include "cytnx_error.hpp" #define MKL_Complex8 std::complex @@ -63,6 +64,63 @@ namespace cytnx { typedef std::complex cytnx_complex128; typedef bool cytnx_bool; + namespace internal { + template + struct is_complex_impl : std::false_type {}; + + template + struct is_complex_impl> : std::true_type {}; + + template + struct is_complex_floating_point_impl : std::false_type {}; + + template + struct is_complex_floating_point_impl> : std::is_floating_point {}; + + template + constexpr std::size_t index_in_tuple_helper() { + static_assert(I < std::tuple_size_v, "Type not found!"); + if constexpr(std::is_same_v>) { + return I; + } else { + return index_in_tuple_helper(); + } + } + + } // namespace internal + + template + using is_complex = internal::is_complex_impl>; + + template + using is_complex_floating_point = internal::is_complex_floating_point_impl>; + + // is_complex_v checks if a data type is of type std::complex + // usage: is_complex_v returns true or false for a data type T + template + constexpr bool is_complex_v = is_complex::value; + + // is_complex_floating_point_v is a template constant that is true if T is of type complex where + // U is a floating point type, and false otherwise. + template + constexpr bool is_complex_floating_point_v = is_complex_floating_point::value; + + // tuple_element_index returns the index of type T in the Tuple, or compile error if not found + template + struct tuple_element_index : std::integral_constant()> {}; + + template + constexpr int tuple_element_index_v = tuple_element_index::value; + + namespace internal { + // type_size returns the sizeof(T) for the supported types. This is the same as + // sizeof(T), except that size_type is 0. + template + constexpr int type_size = sizeof(T); + template <> + constexpr int type_size = 0; + } // namespace internal + /// @cond struct __type { enum __pybind_type { @@ -81,107 +139,119 @@ namespace cytnx { }; }; - struct Type_struct { - std::string name; - // char name[35]; - bool is_unsigned; - bool is_complex; - bool is_float; - bool is_int; - unsigned int typeSize; - }; + constexpr int N_Type = 12; + constexpr int N_fType = 5; + + // the list of supported types. The dtype() of an object is an index into this list. + // This **MUST** match the ordering of __type::__pybind_type + using Type_list = std::tuple< + void, + cytnx_complex128, + cytnx_complex64, + cytnx_double, + cytnx_float, + cytnx_int64, + cytnx_uint64, + cytnx_int32, + cytnx_uint32, + cytnx_int16, + cytnx_uint16, + cytnx_bool + >; + + // The friendly name of each type + template constexpr char* Type_names; + template <> constexpr const char* Type_names = "Void"; + template <> constexpr const char* Type_names = "Complex Double (Complex Float64)"; + template <> constexpr const char* Type_names = "Complex Float (Complex Float32)"; + template <> constexpr const char* Type_names = "Double (Float64)"; + template <> constexpr const char* Type_names = "Float (Float32)"; + template <> constexpr const char* Type_names = "Int64"; + template <> constexpr const char* Type_names = "Uint64"; + template <> constexpr const char* Type_names = "Int32"; + template <> constexpr const char* Type_names = "Uint32"; + template <> constexpr const char* Type_names = "Int16"; + template <> constexpr const char* Type_names = "Uint16"; + template <> constexpr const char* Type_names = "Bool"; + + struct Type_struct { + const char* name; // char* is OK here, it is only ever initialized from a string literal + bool is_unsigned; + bool is_complex; + bool is_float; + bool is_int; + unsigned int typeSize; + }; + + template + struct Type_struct_t { + static constexpr unsigned int cy_typeid = tuple_element_index_v; + static constexpr const char* name = Type_names; + static constexpr bool is_unsigned = std::is_unsigned_v; + static constexpr bool is_complex = is_complex_v; + static constexpr bool is_float = std::is_floating_point_v || is_complex_floating_point_v; + static constexpr bool is_int = std::is_integral_v && !std::is_same_v; + static constexpr std::size_t typeSize = internal::type_size; + + static constexpr Type_struct construct() { return {name, is_unsigned, is_complex, is_float, is_int, typeSize}; } + }; + + namespace internal { + template + constexpr auto make_type_array_helper(std::index_sequence) { + return std::array{Type_struct_t>::construct()...}; + } + template + constexpr auto make_type_array() { + return make_type_array_helper(std::make_index_sequence>()); + } + } // namespace internal - static const int N_Type = 12; - const int N_fType = 5; + // Typeinfos is a std::array for each type in Type_list + constexpr auto Typeinfos = internal::make_type_array(); + + template + constexpr unsigned int cy_typeid = tuple_element_index_v; class Type_class { private: public: enum : unsigned int { - Void, - ComplexDouble, - ComplexFloat, - Double, - Float, - Int64, - Uint64, - Int32, - Uint32, - Int16, - Uint16, - Bool + Void = cy_typeid, + ComplexDouble = cy_typeid, + ComplexFloat = cy_typeid, + Double = cy_typeid, + Float = cy_typeid, + Int64 = cy_typeid, + Uint64 = cy_typeid, + Int32 = cy_typeid, + Uint32 = cy_typeid, + Int16 = cy_typeid, + Uint16 = cy_typeid, + Bool = cy_typeid }; - // std::vector Typeinfos; - inline static Type_struct Typeinfos[N_Type]; - inline static bool inited = false; - Type_class &operator=(const Type_class &rhs) { - for (int i = 0; i < N_Type; i++) this->Typeinfos[i] = rhs.Typeinfos[i]; - return *this; - } - Type_class() { - // #ifdef DEBUG - // std::cout << "[DEBUG] Type constructor call. " << std::endl; - // #endif - if (!inited) { - Typeinfos[this->Void] = (Type_struct){"Void", true, false, false, false, 0}; - Typeinfos[this->ComplexDouble] = (Type_struct){ - "Complex Double (Complex Float64)", false, true, true, false, sizeof(cytnx_complex128)}; - Typeinfos[this->ComplexFloat] = (Type_struct){ - "Complex Float (Complex Float32)", false, true, true, false, sizeof(cytnx_complex64)}; - Typeinfos[this->Double] = - (Type_struct){"Double (Float64)", false, false, true, false, sizeof(cytnx_double)}; - Typeinfos[this->Float] = - (Type_struct){"Float (Float32)", false, false, true, false, sizeof(cytnx_float)}; - Typeinfos[this->Int64] = - (Type_struct){"Int64", false, false, false, true, sizeof(cytnx_int64)}; - Typeinfos[this->Uint64] = - (Type_struct){"Uint64", true, false, false, true, sizeof(cytnx_uint64)}; - Typeinfos[this->Int32] = - (Type_struct){"Int32", false, false, false, true, sizeof(cytnx_int32)}; - Typeinfos[this->Uint32] = - (Type_struct){"Uint32", true, false, false, true, sizeof(cytnx_uint32)}; - Typeinfos[this->Int16] = - (Type_struct){"Int16", false, false, false, true, sizeof(cytnx_int16)}; - Typeinfos[this->Uint16] = - (Type_struct){"Uint16", true, false, false, true, sizeof(cytnx_uint16)}; - Typeinfos[this->Bool] = - (Type_struct){"Bool", true, false, false, false, sizeof(cytnx_bool)}; - - inited = true; - } + static constexpr void check_type(unsigned int type_id) { + cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id %s", type_id); } - const std::string &getname(const unsigned int &type_id) const; - unsigned int c_typename_to_id(const std::string &c_name) const; - unsigned int typeSize(const unsigned int &type_id) const; - bool is_unsigned(const unsigned int &type_id) const; - bool is_complex(const unsigned int &type_id) const; - bool is_float(const unsigned int &type_id) const; - bool is_int(const unsigned int &type_id) const; - // int c_typeindex_to_id(const std::type_index &type_idx); + + static std::string getname(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].name; } // cannot be constexpr + static unsigned int c_typename_to_id(const std::string &c_name); // cannot be constexpr, defined in .cpp file + static constexpr unsigned int typeSize(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].typeSize; } + static constexpr bool is_unsigned(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_unsigned; } + static constexpr bool is_complex(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_complex; } + static constexpr bool is_float(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_float; } + static constexpr bool is_int(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_int; } + template - unsigned int cy_typeid(const T &rc) const { - cytnx_error_msg(true, "[ERROR] invalid type%s", "\n"); - return 0; - } - static unsigned int cy_typeid(const cytnx_complex128 &rc) { return Type_class::ComplexDouble; } - static unsigned int cy_typeid(const cytnx_complex64 &rc) { return Type_class::ComplexFloat; } - static unsigned int cy_typeid(const cytnx_double &rc) { return Type_class::Double; } - static unsigned int cy_typeid(const cytnx_float &rc) { return Type_class::Float; } - static unsigned int cy_typeid(const cytnx_uint64 &rc) { return Type_class::Uint64; } - static unsigned int cy_typeid(const cytnx_int64 &rc) { return Type_class::Int64; } - static unsigned int cy_typeid(const cytnx_uint32 &rc) { return Type_class::Uint32; } - static unsigned int cy_typeid(const cytnx_int32 &rc) { return Type_class::Int32; } - static unsigned int cy_typeid(const cytnx_uint16 &rc) { return Type_class::Uint16; } - static unsigned int cy_typeid(const cytnx_int16 &rc) { return Type_class::Int16; } - static unsigned int cy_typeid(const cytnx_bool &rc) { return Type_class::Bool; } - - unsigned int type_promote(const unsigned int &typeL, const unsigned int &typeR); - }; - /// @endcond + static constexpr unsigned int cy_typeid(const T &rc) { return Type_struct_t::cy_typeid; } - /// @cond - int type_promote(const int &typeL, const int &typeR); + template + static constexpr unsigned int cy_typeid_v = typeid(T{}); + + static unsigned int type_promote(unsigned int typeL, unsigned int typeR); + + }; // Type_class /// @endcond /** @@ -210,8 +280,8 @@ namespace cytnx { * Uint16 | undigned short integer with 16 bits * Bool | boolean type */ - extern Type_class Type; // move to cytnx.hpp and guarded - // static const Type_class Type = Type_class(); + + constexpr Type_class Type; extern int __blasINTsize__; @@ -219,4 +289,4 @@ namespace cytnx { } // namespace cytnx -#endif +#endif // INCLUDE_TYPE_H_ diff --git a/src/Type.cpp b/src/Type.cpp index c3f36a8e..a78b32ab 100644 --- a/src/Type.cpp +++ b/src/Type.cpp @@ -1,5 +1,6 @@ #include "Type.hpp" #include "cytnx_error.hpp" +#include #ifdef BACKEND_TORCH namespace cytnx { @@ -29,19 +30,18 @@ namespace cytnx { } namespace cytnx { - Type_class Type; - unsigned int Type_class::type_promote(const unsigned int &typeL, const unsigned int &typeR) { + unsigned int Type_class::type_promote(unsigned int typeL, unsigned int typeR) { if (typeL < typeR) { if (typeL == 0) return 0; - if (!this->is_unsigned(typeR) && this->is_unsigned(typeL)) { + if (!is_unsigned(typeR) && is_unsigned(typeL)) { return typeL - 1; } else { return typeL; } } else { if (typeR == 0) return 0; - if (!this->is_unsigned(typeL) && this->is_unsigned(typeR)) { + if (!is_unsigned(typeL) && is_unsigned(typeR)) { return typeR - 1; } else { return typeR; @@ -49,127 +49,35 @@ namespace cytnx { } } -} // namespace cytnx -/* -Type_class::Type_class(){ - this->c_typeid2_cy_typeid[type_index(typeid(void))] = this->Void; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_complex128))] = this->ComplexDouble; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_complex64 ))] = this->ComplexFloat ; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_double ))] = this->Double; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_float ))] = this->Float ; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_uint64 ))] = this->Uint64; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_int64 ))] = this->Int64 ; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_uint32 ))] = this->Uint32; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_int32 ))] = this->Int32 ; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_uint16 ))] = this->Uint16; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_int16 ))] = this->Int16 ; - this->c_typeid2_cy_typeid[type_index(typeid(cytnx_bool ))] = this->Bool ; - -} - -int c_typeindex_to_id(const std::type_index &type_idx){ - unordered_map::iterator it; - it = this->c_typeid2_cy_typeid.find(type_idx); - - if(it==this->c_typeid2_cy_typeid.end()){ - cytnx_error_msg(true,"[ERROR] invalid type!%s","\n"); + // Construct an array of typeid(T).name() for each type in Type_list. + // This is complicated by Type_list containing 'void', which means we can't use an ordinary lambda, but + // instead we need a metafunction and a template template parameter. + template + struct c_typename { + static const char* get() { + return typeid(T).name(); } + }; - return it->second; - -} -*/ - -// constexpr cytnx::Type_class::Type_class() { -// Typeinfos.resize(N_Type); -// //{name,unsigned,complex,float,int,typesize} -// Typeinfos[this->Void] = (Type_struct){std::string("Void"), true, false, false, false, 0}; -// Typeinfos[this->ComplexDouble] = (Type_struct){std::string("Complex Double (Complex Float64)"), -// false, -// true, -// true, -// false, -// sizeof(cytnx_complex128)}; -// Typeinfos[this->ComplexFloat] = (Type_struct){std::string("Complex Float (Complex Float32)"), -// false, -// true, -// true, -// false, -// sizeof(cytnx_complex64)}; -// Typeinfos[this->Double] = -// (Type_struct){std::string("Double (Float64)"), false, false, true, false, -// sizeof(cytnx_double)}; -// Typeinfos[this->Float] = -// (Type_struct){std::string("Float (Float32)"), false, false, true, false, sizeof(cytnx_float)}; -// Typeinfos[this->Int64] = -// (Type_struct){std::string("Int64"), false, false, false, true, sizeof(cytnx_int64)}; -// Typeinfos[this->Uint64] = -// (Type_struct){std::string("Uint64"), true, false, false, true, sizeof(cytnx_uint64)}; -// Typeinfos[this->Int32] = -// (Type_struct){std::string("Int32"), false, false, false, true, sizeof(cytnx_int32)}; -// Typeinfos[this->Uint32] = -// (Type_struct){std::string("Uint32"), true, false, false, true, sizeof(cytnx_uint32)}; -// Typeinfos[this->Int16] = -// (Type_struct){std::string("Int16"), false, false, false, true, sizeof(cytnx_int16)}; -// Typeinfos[this->Uint16] = -// (Type_struct){std::string("Uint16"), true, false, false, true, sizeof(cytnx_uint16)}; -// Typeinfos[this->Bool] = -// (Type_struct){std::string("Bool"), true, false, false, false, sizeof(cytnx_bool)}; -// } - -bool cytnx::Type_class::is_float(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].is_float; -} - -bool cytnx::Type_class::is_int(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].is_int; -} + template class Func, std::size_t... Indices> + auto make_type_array_from_func_helper(std::index_sequence) { + return std::array::get()), sizeof...(Indices)>{Func>::get()...}; + } -bool cytnx::Type_class::is_complex(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].is_complex; -} + template class Func> + auto make_type_array_from_func() { + return make_type_array_from_func_helper(std::make_index_sequence>()); + } -bool cytnx::Type_class::is_unsigned(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].is_unsigned; -} + unsigned int Type_class::c_typename_to_id(const std::string &c_name) { + static auto c_typenames = make_type_array_from_func(); -const std::string &cytnx::Type_class::getname(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].name; -} -unsigned int cytnx::Type_class::typeSize(const unsigned int &type_id) const { - cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id%s", "\n"); - return Typeinfos[type_id].typeSize; -} -unsigned int cytnx::Type_class::c_typename_to_id(const std::string &c_name) const { - if (c_name == typeid(cytnx_complex128).name()) { - return Type_class::ComplexDouble; - } else if (c_name == typeid(cytnx_complex64).name()) { - return Type_class::ComplexFloat; - } else if (c_name == typeid(cytnx_double).name()) { - return Type_class::Double; - } else if (c_name == typeid(cytnx_float).name()) { - return Type_class::Float; - } else if (c_name == typeid(cytnx_int64).name()) { - return Type_class::Int64; - } else if (c_name == typeid(cytnx_uint64).name()) { - return Type_class::Uint64; - } else if (c_name == typeid(cytnx_int32).name()) { - return Type_class::Int32; - } else if (c_name == typeid(cytnx_uint32).name()) { - return Type_class::Uint32; - } else if (c_name == typeid(cytnx_int16).name()) { - return Type_class::Int16; - } else if (c_name == typeid(cytnx_uint16).name()) { - return Type_class::Uint16; - } else if (c_name == typeid(cytnx_bool).name()) { - return Type_class::Bool; - } else { - cytnx_error_msg(1, "%s", "[ERROR] invalid type"); - return 0; + auto i = std::find(c_typenames.begin(), c_typenames.end(), c_name); + if (i == c_typenames.end()) { + cytnx_error_msg(true, "[ERROR] typename is not a cytnx type: %s", c_name.c_str()); + return 0; + } + return i - c_typenames.begin(); } -} + +} // namespace cytnx