Skip to content

Commit

Permalink
Refactor Type.hpp
Browse files Browse the repository at this point in the history
Refactor Type.hpp to automatically deduce properties of the types at compile time.

Type_class now contains only static members, and all except Type_class::getname() and Type_class::c_typename_to_id() are constexpr.

The global variable Type still exists, but it is constexpr.  Calls to Type.<something> are now equivalent to Type_class::<something>.
But I do not suggest using Type_class directly, but instead consider removing the Type global variable and renaming Type_class to Type.
It could also be a namespace rather than a class, which could be useful in some circumstances.

I wasn't able to make Type_class::getname() constexpr because it returns a std::string.  It *could* return a const char*,
since all of the strings that it can return are string literals, but this causes a lot of compilation failures because of code like:

    cytnx_error_msg(dtype != Type.Float,
                    "[ERROR] type mismatch. try to get <float> type from raw data of type %s",
                    Type.getname(dtype).c_str());

where it is calling getname().c_str(), which obviously fails if getname() is returning a const char*.

The proper fix is to improve cytnx_error_msg() so that it can print values in a type-safe way.  After that is done, Type_class::getname()
could be changed to constexpr and return a const char*.

I wasn't able to make Type_class::c_typename_to_id() constexpr because typeid(T).name() is not constexpr.

A possible extension would be to add a wrapper around typeid(T).name() that returns Type.getname() if T is a cytnx type, with
a fallback if it isn't a cytnx type.  It could also use some compiler-specific functions if available, for example both gcc
and clang have abi::__cxa_demangle() function.  MSVC users are out of luck ;)
  • Loading branch information
ianmccul committed Nov 21, 2024
1 parent fa4596e commit 4ad3cbe
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 223 deletions.
272 changes: 171 additions & 101 deletions include/Type.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#ifndef _H_TYPE_
#define _H_TYPE_
#ifndef INCLUDE_TYPE_H_
#define INCLUDE_TYPE_H_

#include <string>
#include <complex>
#include <cstdint>
#include <string>
#include <type_traits>
#include <tuple>
#include <array>
#include <utility>
#include <vector>
#include <stdint.h>
#include <climits>
#include <typeinfo>
#include <unordered_map>
#include <typeindex>

#include "cytnx_error.hpp"

#define MKL_Complex8 std::complex<float>
Expand Down Expand Up @@ -63,6 +64,63 @@ namespace cytnx {
typedef std::complex<double> cytnx_complex128;
typedef bool cytnx_bool;

namespace internal {
template <class>
struct is_complex_impl : std::false_type {};

template <class T>
struct is_complex_impl<std::complex<T>> : std::true_type {};

template <typename>
struct is_complex_floating_point_impl : std::false_type {};

template <typename T>
struct is_complex_floating_point_impl<std::complex<T>> : std::is_floating_point<T> {};

template <std::size_t I, typename T, typename Tuple>
constexpr std::size_t index_in_tuple_helper() {
static_assert(I < std::tuple_size_v<Tuple>, "Type not found!");
if constexpr(std::is_same_v<T, std::tuple_element_t<I, Tuple>>) {
return I;
} else {
return index_in_tuple_helper<I+1, T, Tuple>();
}
}

} // namespace internal

template <typename T>
using is_complex = internal::is_complex_impl<std::remove_cv_t<T>>;

template <typename T>
using is_complex_floating_point = internal::is_complex_floating_point_impl<std::remove_cv_t<T>>;

// is_complex_v checks if a data type is of type std::complex
// usage: is_complex_v<T> returns true or false for a data type T
template <typename T>
constexpr bool is_complex_v = is_complex<T>::value;

// is_complex_floating_point_v<T> is a template constant that is true if T is of type complex<U> where
// U is a floating point type, and false otherwise.
template <typename T>
constexpr bool is_complex_floating_point_v = is_complex_floating_point<T>::value;

// tuple_element_index<T, Tuple> returns the index of type T in the Tuple, or compile error if not found
template <typename T, typename Tuple>
struct tuple_element_index : std::integral_constant<std::size_t, internal::index_in_tuple_helper<0, T, Tuple>()> {};

template <typename T, typename Tuple>
constexpr int tuple_element_index_v = tuple_element_index<T, Tuple>::value;

namespace internal {
// type_size returns the sizeof(T) for the supported types. This is the same as
// sizeof(T), except that size_type<void> is 0.
template <typename T>
constexpr int type_size = sizeof(T);
template <>
constexpr int type_size<void> = 0;
} // namespace internal

/// @cond
struct __type {
enum __pybind_type {
Expand All @@ -81,107 +139,119 @@ namespace cytnx {
};
};

struct Type_struct {
std::string name;
// char name[35];
bool is_unsigned;
bool is_complex;
bool is_float;
bool is_int;
unsigned int typeSize;
};
constexpr int N_Type = 12;
constexpr int N_fType = 5;

// the list of supported types. The dtype() of an object is an index into this list.
// This **MUST** match the ordering of __type::__pybind_type
using Type_list = std::tuple<
void,
cytnx_complex128,
cytnx_complex64,
cytnx_double,
cytnx_float,
cytnx_int64,
cytnx_uint64,
cytnx_int32,
cytnx_uint32,
cytnx_int16,
cytnx_uint16,
cytnx_bool
>;

// The friendly name of each type
template <typename T> constexpr char* Type_names;
template <> constexpr const char* Type_names<void> = "Void";
template <> constexpr const char* Type_names<cytnx_complex128> = "Complex Double (Complex Float64)";
template <> constexpr const char* Type_names<cytnx_complex64> = "Complex Float (Complex Float32)";
template <> constexpr const char* Type_names<cytnx_double> = "Double (Float64)";
template <> constexpr const char* Type_names<cytnx_float> = "Float (Float32)";
template <> constexpr const char* Type_names<cytnx_int64> = "Int64";
template <> constexpr const char* Type_names<cytnx_uint64> = "Uint64";
template <> constexpr const char* Type_names<cytnx_int32> = "Int32";
template <> constexpr const char* Type_names<cytnx_uint32> = "Uint32";
template <> constexpr const char* Type_names<cytnx_int16> = "Int16";
template <> constexpr const char* Type_names<cytnx_uint16> = "Uint16";
template <> constexpr const char* Type_names<cytnx_bool> = "Bool";

struct Type_struct {
const char* name; // char* is OK here, it is only ever initialized from a string literal
bool is_unsigned;
bool is_complex;
bool is_float;
bool is_int;
unsigned int typeSize;
};

template <typename T>
struct Type_struct_t {
static constexpr unsigned int cy_typeid = tuple_element_index_v<T, Type_list>;
static constexpr const char* name = Type_names<T>;
static constexpr bool is_unsigned = std::is_unsigned_v<T>;
static constexpr bool is_complex = is_complex_v<T>;
static constexpr bool is_float = std::is_floating_point_v<T> || is_complex_floating_point_v<T>;
static constexpr bool is_int = std::is_integral_v<T> && !std::is_same_v<T, bool>;
static constexpr std::size_t typeSize = internal::type_size<T>;

static constexpr Type_struct construct() { return {name, is_unsigned, is_complex, is_float, is_int, typeSize}; }
};

namespace internal {
template <typename Tuple, std::size_t... Indices>
constexpr auto make_type_array_helper(std::index_sequence<Indices...>) {
return std::array<Type_struct, sizeof...(Indices)>{Type_struct_t<std::tuple_element_t<Indices, Tuple>>::construct()...};
}
template <typename Tuple>
constexpr auto make_type_array() {
return make_type_array_helper<Tuple>(std::make_index_sequence<std::tuple_size_v<Tuple>>());
}
} // namespace internal

static const int N_Type = 12;
const int N_fType = 5;
// Typeinfos is a std::array<Type_struct> for each type in Type_list
constexpr auto Typeinfos = internal::make_type_array<Type_list>();

template <typename T>
constexpr unsigned int cy_typeid = tuple_element_index_v<T, Type_list>;

class Type_class {
private:
public:
enum : unsigned int {
Void,
ComplexDouble,
ComplexFloat,
Double,
Float,
Int64,
Uint64,
Int32,
Uint32,
Int16,
Uint16,
Bool
Void = cy_typeid<void>,
ComplexDouble = cy_typeid<cytnx_complex128>,
ComplexFloat = cy_typeid<cytnx_complex64>,
Double = cy_typeid<cytnx_double>,
Float = cy_typeid<cytnx_float>,
Int64 = cy_typeid<cytnx_int64>,
Uint64 = cy_typeid<cytnx_uint64>,
Int32 = cy_typeid<cytnx_int32>,
Uint32 = cy_typeid<cytnx_uint32>,
Int16 = cy_typeid<cytnx_int16>,
Uint16 = cy_typeid<cytnx_uint16>,
Bool = cy_typeid<cytnx_bool>
};
// std::vector<Type_struct> Typeinfos;
inline static Type_struct Typeinfos[N_Type];
inline static bool inited = false;
Type_class &operator=(const Type_class &rhs) {
for (int i = 0; i < N_Type; i++) this->Typeinfos[i] = rhs.Typeinfos[i];
return *this;
}

Type_class() {
// #ifdef DEBUG
// std::cout << "[DEBUG] Type constructor call. " << std::endl;
// #endif
if (!inited) {
Typeinfos[this->Void] = (Type_struct){"Void", true, false, false, false, 0};
Typeinfos[this->ComplexDouble] = (Type_struct){
"Complex Double (Complex Float64)", false, true, true, false, sizeof(cytnx_complex128)};
Typeinfos[this->ComplexFloat] = (Type_struct){
"Complex Float (Complex Float32)", false, true, true, false, sizeof(cytnx_complex64)};
Typeinfos[this->Double] =
(Type_struct){"Double (Float64)", false, false, true, false, sizeof(cytnx_double)};
Typeinfos[this->Float] =
(Type_struct){"Float (Float32)", false, false, true, false, sizeof(cytnx_float)};
Typeinfos[this->Int64] =
(Type_struct){"Int64", false, false, false, true, sizeof(cytnx_int64)};
Typeinfos[this->Uint64] =
(Type_struct){"Uint64", true, false, false, true, sizeof(cytnx_uint64)};
Typeinfos[this->Int32] =
(Type_struct){"Int32", false, false, false, true, sizeof(cytnx_int32)};
Typeinfos[this->Uint32] =
(Type_struct){"Uint32", true, false, false, true, sizeof(cytnx_uint32)};
Typeinfos[this->Int16] =
(Type_struct){"Int16", false, false, false, true, sizeof(cytnx_int16)};
Typeinfos[this->Uint16] =
(Type_struct){"Uint16", true, false, false, true, sizeof(cytnx_uint16)};
Typeinfos[this->Bool] =
(Type_struct){"Bool", true, false, false, false, sizeof(cytnx_bool)};

inited = true;
}
static constexpr void check_type(unsigned int type_id) {
cytnx_error_msg(type_id >= N_Type, "[ERROR] invalid type_id %s", type_id);
}
const std::string &getname(const unsigned int &type_id) const;
unsigned int c_typename_to_id(const std::string &c_name) const;
unsigned int typeSize(const unsigned int &type_id) const;
bool is_unsigned(const unsigned int &type_id) const;
bool is_complex(const unsigned int &type_id) const;
bool is_float(const unsigned int &type_id) const;
bool is_int(const unsigned int &type_id) const;
// int c_typeindex_to_id(const std::type_index &type_idx);

static std::string getname(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].name; } // cannot be constexpr
static unsigned int c_typename_to_id(const std::string &c_name); // cannot be constexpr, defined in .cpp file
static constexpr unsigned int typeSize(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].typeSize; }
static constexpr bool is_unsigned(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_unsigned; }
static constexpr bool is_complex(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_complex; }
static constexpr bool is_float(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_float; }
static constexpr bool is_int(unsigned int type_id) { check_type(type_id); return Typeinfos[type_id].is_int; }

template <class T>
unsigned int cy_typeid(const T &rc) const {
cytnx_error_msg(true, "[ERROR] invalid type%s", "\n");
return 0;
}
static unsigned int cy_typeid(const cytnx_complex128 &rc) { return Type_class::ComplexDouble; }
static unsigned int cy_typeid(const cytnx_complex64 &rc) { return Type_class::ComplexFloat; }
static unsigned int cy_typeid(const cytnx_double &rc) { return Type_class::Double; }
static unsigned int cy_typeid(const cytnx_float &rc) { return Type_class::Float; }
static unsigned int cy_typeid(const cytnx_uint64 &rc) { return Type_class::Uint64; }
static unsigned int cy_typeid(const cytnx_int64 &rc) { return Type_class::Int64; }
static unsigned int cy_typeid(const cytnx_uint32 &rc) { return Type_class::Uint32; }
static unsigned int cy_typeid(const cytnx_int32 &rc) { return Type_class::Int32; }
static unsigned int cy_typeid(const cytnx_uint16 &rc) { return Type_class::Uint16; }
static unsigned int cy_typeid(const cytnx_int16 &rc) { return Type_class::Int16; }
static unsigned int cy_typeid(const cytnx_bool &rc) { return Type_class::Bool; }

unsigned int type_promote(const unsigned int &typeL, const unsigned int &typeR);
};
/// @endcond
static constexpr unsigned int cy_typeid(const T &rc) { return Type_struct_t<T>::cy_typeid; }

/// @cond
int type_promote(const int &typeL, const int &typeR);
template <typename T>
static constexpr unsigned int cy_typeid_v = typeid(T{});

static unsigned int type_promote(unsigned int typeL, unsigned int typeR);

}; // Type_class
/// @endcond

/**
Expand Down Expand Up @@ -210,13 +280,13 @@ namespace cytnx {
* Uint16 | undigned short integer with 16 bits
* Bool | boolean type
*/
extern Type_class Type; // move to cytnx.hpp and guarded
// static const Type_class Type = Type_class();

constexpr Type_class Type;

extern int __blasINTsize__;

extern bool User_debug;

} // namespace cytnx

#endif
#endif // INCLUDE_TYPE_H_
Loading

0 comments on commit 4ad3cbe

Please sign in to comment.