-
Notifications
You must be signed in to change notification settings - Fork 0
元编程
- 简化代码编写
- 提高性能
template <typename T>
struct Fun {
using type = std::vector<T>; /* f(T) */
};
T
是输入,Fun<T>::type
是输出
Fun<int>::type vec_int = {1, 2, 3};
-
元函数的特点:
- 无副作用
- 编译期执行(通过类型推导执行)
-
可以充当元函数的元素:类型(包括模板类)、常数、函数(函数对象,也是类型的一种)
-
元函数的例子
template <class T, T v>
struct integral_constant
{
static const T value = v;
typedef T value_type;
typedef integral_constant type;
};
static_assert(integral_constant<int, 1>::value == 1);
static_assert(std::is_same_v<integral_constant<int, 1>::value_type, int>);
后面还有对integral_constant
使用的例子。
template <bool B> // C++14
using bool_constant = integral_constant<bool, B>; // C++14
typedef bool_constant<true> true_type; // C++14
typedef bool_constant<false> false_type; // C++14
template <class T, T v>
struct integral_constant
{
static const T value = v;
typedef T value_type;
typedef integral_constant type;
};
template <typename T, T v>
using IC = integral_constant<T, v>;
template <typename IC1, typename IC2>
struct plus {
using value_type = decltype(IC1::value + IC2::value);
using type = IC<value_type, IC1::value + IC2::value>;
};
auto foo() {
using one = IC<int, 1>;
using two = IC<int, 2>;
using three = plus<one, two>::type; // IC<int, 3>
static_assert(three::value == 3);
static_assert(std::is_same_v<three, IC<int, 3> >);
}
如果不习惯模板类,也可以写模板函数,借助类型推导,可以写出更像一般函数风格的元函数
template <typename IC1, typename IC2>
auto plus2(IC1 const& ic1, IC2 const& ic2) {
using value_type = decltype(IC1::value + IC2::value);
using type = IC<value_type, IC1::value + IC2::value>;
return type();
}
auto foo2() {
auto one = IC<int, 1>();
auto two = IC<int, 2>();
auto three = plus2(one, two);
static_assert(three.value == 3);
static_assert(std::is_same_v<decltype(three), IC<int, 3> >);
}
借助c++11引入的literal operator
,可以将literal解析成编译期常量,写法更精炼:
#include <boost/hana.hpp>
using namespace boost::hana::literals;
static_assert(1_c + 2_c == 3_c);
其中的1_c
实际是hana::integral_constant<long long, 1>
的一个值(也是该类型唯一的值)。+
和==
都已经被重载。
-
顺序
template <typename T> struct Fun { using t1 = std::remove_reference_t<T>; using t2 = std::vector<t1>; }
t1
,t2
不能颠倒顺序。 -
分支
方式一:特例化
template <class _Tp> struct __libcpp_is_integral : public false_type {}; template <> struct __libcpp_is_integral<bool> : public true_type {}; template <> struct __libcpp_is_integral<char> : public true_type {}; template <> struct __libcpp_is_integral<short> : public true_type {}; template <> struct __libcpp_is_integral<int> : public true_type {}; template <> struct __libcpp_is_integral<long> : public true_type {}; ...
另外的例子:
vector<int>
和vector<bool>
可以是完全不同的实现。方式二:使用c++17引入的
if constexpr
template <typename T> auto zero(T const& v) { if constexpr (std::is_integral<T>::value) { return 0; } else { return 0.0; } }
注意编译器会对返回值类型进行推导,即便两个分支的返回值类型不同,不妨碍编译。两个分支相当于两个版本,每个版本中相当于只有一个
return
语句。 -
循环
遍历variadic template的变长参数包,如何做到?
方式一:递归
#include <typeinfo> #include <type_traits> #include <iostream> #include <cassert> template <typename ...T> struct Fun1 { static auto print() { assert(false && "should never be called"); } }; template <typename T> struct Fun1<T> { static auto print() { std::cout << typeid(T).name() << std::endl; } }; template <typename Head, typename ...Tails> struct Fun1<Head, Tails...> { static auto print() { std::cout << typeid(Head).name() << std::endl; Fun1<Tails...>::print(); } }; int main() { Fun1<char, short, int, long>::print(); }
方式二:借用boost::hana库
#include <typeinfo> #include <iostream> #include <boost/hana.hpp> namespace hana = boost::hana; template <typename ...T> auto Fun2() { hana::for_each(hana::tuple_t<T...>, [](auto t) { using type = typename decltype(t)::type; std::cout << typeid(type).name() << std::endl; } ); } int main() { Fun2<char, short, int, long>(); }
注意:不管哪种方法,生成的代码中没有递归或循环,只有4条打印的语句——编译期计算
-
函数调用
元函数里可以调用其它元函数; 注意函数(lambda)也可以作为模板参数或者返回值,从而可以实现函数式编程风格
#include <typeinfo> #include <cstdio> template <typename T, typename Fn> auto foo(T const& t, Fn const& fn) { return fn(t); } int main() { auto fn = [](auto i) { printf("type: %s\n", typeid(i).name()); }; foo(1, fn); foo(1.0, fn); }
-
tuple
- std::tuple<T...>
- boost::hana::tuple<T...>
- 异质列表,可以存储任意类型、任意数量的元素。
-
围绕tuple有各种操作:front(), back(), at(index), ...
-
class/struct仍然适用
-
algorithm
- for_each(tuple, fn)
- transform(tuple, fn)
- fold(tuple, init, fn)
- zip(tuple1, tuple2)
- filter(tuple, pred)
- ...
- 算术系统(数值、运算)
- 赋值
- 控制流
- 顺序
- 分支
- 循环
- 函数调用
- 数据结构
- 算法
构成了一套完整的编程体系。
- 提供了一套元编程的基础库
- boost::hana的编程思想(https://www.boost.org/doc/libs/1_61_0/libs/hana/doc/html/index.html)
- Compile-time number
- Compile-time arithmetic
- Compile-time branching
- Type computations - Types as objects
- Generalities on containers
- Generalities on algorithms
- 既适用于编译期计算,也适用于运行时计算
template<typename Shape, typename Layout, typename LayoutProvider>
struct TensorFormat {
Shape shape_;
Layout layout_;
constexpr TensorFormat(Shape const &shape, Layout const &layout) :
shape_(shape), layout_(layout) {}
...
};
template<typename ElemType, typename Format, typename Space, typename Addr>
struct Tensor {
const Format format_;
const Addr addr_;
constexpr Tensor(ElemType const &type, Format const &format, Space const &space, Addr const &addr) :
format_(format), addr_(addr) {}
...
};
void Tester1() {
auto format1 = make_format(Dim2(2_c, 4_c), RowMajorLayout());
auto tensor1 = Tensor(float(), format1, MemSpace::GM(), 0x1000);
auto tensor2 = Tensor(float(), format1, MemSpace::GM(), 0x2000);
tadd1(tensor1, tensor2);
}
-
tadd1
函数里会取出tensor1
和tensor2
(类型)里存取的shape,layout等信息,在编译期计算出两者相加需要的intrinsic参数,调用intrinsic完成计算。 - 兼顾了高效和抽象(tadd1在不同的平台上可以生成不同的代码)
-
构建自定义的表达式,并对表达式加以解释
- 表达式被lazy_evaluation
- 方便自定义语法,构建EDSL
-
表达式示例
int foo(int x) {
return x + 1;
}
int main ()
{
auto foo_ = yap::make_terminal(foo);
auto expr = foo_(1) + 2; // 不同于auto x = foo(1) + 2, x的值将会是4,而expr的值是一个yap::expression
yap::print(std::cout, expr);
auto result = yap::evaluate(expr); // 此时才对expr求值,即编译器会看到还原出来的foo(1) + 2
std::cout << "result = " << result << std::endl;
}
得到的结果如下
expr<+>
expr<()>
term<int (*)(int)>[=1] &
term<int>[=1]
term<int>[=2]
result = 4
可以看到,yap::expression打印出来类似一个AST,其中有操作符‘+’和'()'(函数调用),还有操作数(即终结符):'int (*)(int)'类型的函数foo,以及int类型的数值1和2。
本例中expr打印后立即被evaluate,其实也可以不进行evaluate,可以对这个表达式进行自由的变换。
- yap::expression的原理
其实很简单,就是记录了操作符和操作数列表的一个struct
template <boost::yap::expr_kind Kind, typename Tuple>
struct minimal_expr
{
static const boost::yap::expr_kind kind = Kind;
Tuple elements;
};
其中expr_kind的定义:
21 enum class expr_kind {
22 expr_ref =
23 0, ///< A (possibly \c const) reference to another expression.
24
25 terminal = 1, ///< A terminal expression.
26
27 // unary
28 unary_plus = 2, ///< \c +
29 negate = 3, ///< \c -
30 dereference = 4, ///< \c *
...
38
39 // binary
40 shift_left = 12, ///< \c <<
41 shift_right = 13, ///< \c >>
42 multiplies = 14, ///< \c *
43 divides = 15, ///< \c /
...
76 // n-ary
77 call = 45 ///< \c ()
78 };
-
Pipable algorithms (https://www.boost.org/doc/libs/1_71_0/doc/html/boost_yap/manual.html#boost_yap.manual.examples.pipable_algorithms)
std::vector<int> v1 = {0, 2, 2, 7, 1, 3, 8}; std::vector<int> const v2 = sort(v1) | unique; assert(v2 == std::vector<int>({0, 1, 2, 3, 7, 8}));
注意这里的sort不是std::sort,是yap封装后的sort(只是一个terminal)
-
Phoenix-style let (https://www.boost.org/doc/libs/1_71_0/doc/html/boost_yap/manual.html#boost_yap.manual.examples.boost_phoenix_style__let___)
{ auto expr = let(_a = 2)[_a + 1]; assert(boost::yap::evaluate(expr) == 3); } { auto expr = let(_a = 123, _b = 456)[_a + _b]; assert(boost::yap::evaluate(expr) == 123 + 456); } // Prints "Hello, World" due to let()'s scoping rules. { boost::yap::evaluate( let(_a = 1_p, _b = 2_p) [ // _a here is an int: 1 let(_a = 3_p) // hides the outer _a [ cout << _a << _b // prints "Hello, World" ] ], 1, " World", "Hello," ); } std::cout << "\n";
其中trick是:let是个元函数,()里是let的参数,接收一系列assign形式的表达式。let返回一个对象,它重载了[]运算符,operator 接受yap表达式做参数,并对这些表达式做yap::transform或yap::evaluate。
// Takes N > 0 expressions of the form 'placeholder = expr', and returns an // object with an overloaded operator[](). template<typename Expr, typename... Exprs> auto let(Expr && expr, Exprs &&... exprs) { return let_impl( boost::hana::make_map(), std::forward<Expr>(expr), std::forward<Exprs>(exprs)...); }
TEST(TestExprCompiler, Test4) {
auto format1 = make_format(Dim2(2, 4), RowMajorLayout());
auto tensor1 = Tensor(float(), format1, MemSpace::GM(), 0x10);
auto tensor2 = Tensor(float(), format1, MemSpace::GM(), 0x20);
auto tensor3 = Tensor(float(), format1, MemSpace::GM(), 0x30);
auto add_mul = [](auto &&... args) {
using namespace boost::yap::literals;
auto expr = 1_p + 2_p * 3_p;
auto kernel = ECompiler(expr).compile(args...);
launch(kernel);
};
add_mul(tensor1, tensor2, tensor3);
}
其中compile()函数里会对tensor1 + tensor2 * tensor3这个表达式做变换和code_gen(),最终会生成乘法和加法的指令(调用intrinsic),效果等同于:
// 为中间结果申请临时变量
auto temp1 = Tensor(float(), format1, MemSpace::GM(), 0x40);
auto temp2 = Tensor(float(), format1, MemSpace::GM(), 0x50);
tmul(temp1, tensor2, tensor3);
tadd(temp2, tensor1, temp1);
注意,是编译期的变换。