Skip to content

元编程

chlict edited this page Apr 28, 2020 · 14 revisions

元编程:通过程序控制(源)代码

  • 简化代码编写
  • 提高性能

c++的元函数

template <typename T>
struct Fun {
    using type = std::vector<T>; /* f(T) */
};

T是输入,Fun<T>::type是输出

Fun<int>::type vec_int = {1, 2, 3};
  • 元函数的特点:

    • 无副作用
    • 编译期执行(通过类型推导执行)
  • 可以充当元函数的元素:类型(包括模板类)、常数、函数(函数对象,也是类型的一种)

  • 元函数的例子

template <class T, T v>
struct integral_constant
{   
    static const T      value = v;
    typedef T           value_type;
    typedef integral_constant type;
};

static_assert(integral_constant<int, 1>::value == 1);
static_assert(std::is_same_v<integral_constant<int, 1>::value_type, int>);

后面还有对integral_constant使用的例子。

c++元编程的基本操作

赋值

    template <bool B>                                   // C++14
    using bool_constant = integral_constant<bool, B>;   // C++14

    typedef bool_constant<true> true_type;              // C++14
    typedef bool_constant<false> false_type;            // C++14

运算

template <class T, T v>
struct integral_constant
{
    static const T      value = v;
    typedef T           value_type;
    typedef integral_constant type;
};

template <typename T, T v>
using IC = integral_constant<T, v>; 

template <typename IC1, typename IC2> 
struct plus {
    using value_type = decltype(IC1::value + IC2::value);
    using type = IC<value_type, IC1::value + IC2::value>;
};

auto foo() {
    using one = IC<int, 1>; 
    using two = IC<int, 2>; 
    using three = plus<one, two>::type; // IC<int, 3>

    static_assert(three::value == 3); 
    static_assert(std::is_same_v<three, IC<int, 3> >); 
}

如果不习惯模板类,也可以写模板函数,借助类型推导,可以写出更像一般函数风格的元函数

template <typename IC1, typename IC2> 
auto plus2(IC1 const& ic1, IC2 const& ic2) {
    using value_type = decltype(IC1::value + IC2::value);
    using type = IC<value_type, IC1::value + IC2::value>;
    return type();
}

auto foo2() {
    auto one = IC<int, 1>();
    auto two = IC<int, 2>();
    auto three = plus2(one, two);

    static_assert(three.value == 3); 
    static_assert(std::is_same_v<decltype(three), IC<int, 3> >); 
}

借助c++11引入的literal operator,可以将literal解析成编译期常量,写法更精炼:

#include <boost/hana.hpp>
using namespace boost::hana::literals;

static_assert(1_c + 2_c == 3_c);

其中的1_c实际是hana::integral_constant<long long, 1>的一个值(也是该类型唯一的值)。+==都已经被重载。

控制流

  • 顺序

    template <typename T>
    struct Fun {
        using t1 = std::remove_reference_t<T>;
        using t2 = std::vector<t1>;
    }

    t1, t2不能颠倒顺序。

  • 分支

    方式一:特例化

    template <class _Tp> struct __libcpp_is_integral                     : public false_type {};
    template <>          struct __libcpp_is_integral<bool>               : public true_type {};
    template <>          struct __libcpp_is_integral<char>               : public true_type {};
    
    template <>          struct __libcpp_is_integral<short>              : public true_type {};
    template <>          struct __libcpp_is_integral<int>                : public true_type {};
    template <>          struct __libcpp_is_integral<long>               : public true_type {};
    ...

    另外的例子:vector<int>vector<bool>可以是完全不同的实现。

    方式二:使用c++17引入的if constexpr

    template <typename T>
    auto zero(T const& v) {
        if constexpr (std::is_integral<T>::value) {
            return 0;
        } else {
            return 0.0;
        }
    }

    注意编译器会对返回值类型进行推导,即便两个分支的返回值类型不同,不妨碍编译。两个分支相当于两个版本,每个版本中相当于只有一个return语句。

  • 循环

    遍历variadic template的变长参数包,如何做到?

    方式一:递归

    #include <typeinfo>
    #include <type_traits>
    #include <iostream>
    #include <cassert>
    
    template <typename ...T>
    struct Fun1 {
        static auto print() {
            assert(false && "should never be called");
        }
    };
    
    template <typename T>
    struct Fun1<T> {
        static auto print() {
            std::cout << typeid(T).name() << std::endl;
        }
    };
    
    template <typename Head, typename ...Tails>
    struct Fun1<Head, Tails...> {
        static auto print() {
            std::cout << typeid(Head).name() << std::endl;
            Fun1<Tails...>::print();
        }
    };
    
    int main() {
        Fun1<char, short, int, long>::print();
    }

    方式二:借用boost::hana库

    #include <typeinfo>
    #include <iostream>
    #include <boost/hana.hpp>
    
    namespace hana = boost::hana;
    
    template <typename ...T>
    auto Fun2() {
        hana::for_each(hana::tuple_t<T...>,
            [](auto t) {
                using type = typename decltype(t)::type;
                std::cout << typeid(type).name() << std::endl;
            }
        );
    }
    
    int main() {
        Fun2<char, short, int, long>();
    }

    注意:不管哪种方法,生成的代码中没有递归或循环,只有4条打印的语句——编译期计算

  • 函数调用

    元函数里可以调用其它元函数; 注意函数(lambda)也可以作为模板参数或者返回值,从而可以实现函数式编程风格

    #include <typeinfo>
    #include <cstdio>
    
    template <typename T, typename Fn>
    auto foo(T const& t, Fn const& fn) {
        return fn(t);
    }
    
    int main() {
        auto fn = [](auto i) { printf("type: %s\n", typeid(i).name()); };
        foo(1, fn);
        foo(1.0, fn);
    }

数据结构及算法

  • tuple

    • std::tuple<T...>
    • boost::hana::tuple<T...>
    • 异质列表,可以存储任意类型、任意数量的元素。
  • 围绕tuple有各种操作:front(), back(), at(index), ...

  • class/struct仍然适用

  • algorithm

    • for_each(tuple, fn)
    • transform(tuple, fn)
    • fold(tuple, init, fn)
    • zip(tuple1, tuple2)
    • filter(tuple, pred)
    • ...

总结

  • 算术系统(数值、运算)
  • 赋值
  • 控制流
    • 顺序
    • 分支
    • 循环
    • 函数调用
  • 数据结构
  • 算法

构成了一套完整的编程体系。

boost::hana库

  • 提供了一套元编程的基础库
  • boost::hana的编程思想(https://www.boost.org/doc/libs/1_61_0/libs/hana/doc/html/index.html)
    • Compile-time number
    • Compile-time arithmetic
    • Compile-time branching
    • Type computations - Types as objects
    • Generalities on containers
    • Generalities on algorithms
  • 既适用于编译期计算,也适用于运行时计算

编程实践 - Tensor

template<typename Shape, typename Layout, typename LayoutProvider>
struct TensorFormat {
    Shape shape_;
    Layout layout_;

    constexpr TensorFormat(Shape const &shape, Layout const &layout) :
            shape_(shape), layout_(layout) {}
    ...
};

template<typename ElemType, typename Format, typename Space, typename Addr>
struct Tensor {
    const Format format_;
    const Addr addr_;

    constexpr Tensor(ElemType const &type, Format const &format, Space const &space, Addr const &addr) :
            format_(format), addr_(addr) {}
    ...
};

void Tester1() {
    auto format1 = make_format(Dim2(2_c, 4_c), RowMajorLayout());
    auto tensor1 = Tensor(float(), format1, MemSpace::GM(), 0x1000);
    auto tensor2 = Tensor(float(), format1, MemSpace::GM(), 0x2000);
    tadd1(tensor1, tensor2);
}
  • tadd1函数里会取出tensor1tensor2(类型)里存取的shape,layout等信息,在编译期计算出两者相加需要的intrinsic参数,调用intrinsic完成计算。
  • 兼顾了高效和抽象(tadd1在不同的平台上可以生成不同的代码)

表达式模板及boost::yap库介绍

  • 构建自定义的表达式,并对表达式加以解释

    • 表达式被lazy_evaluation
    • 方便自定义语法,构建EDSL
  • 表达式示例

int foo(int x) {
    return x + 1;
}

int main ()
{
    auto foo_ = yap::make_terminal(foo);
    auto expr = foo_(1) + 2; // 不同于auto x = foo(1) + 2, x的值将会是4,而expr的值是一个yap::expression

    yap::print(std::cout, expr);

    auto result = yap::evaluate(expr);  // 此时才对expr求值,即编译器会看到还原出来的foo(1) + 2
    std::cout << "result = " << result << std::endl;
}

得到的结果如下

expr<+>
    expr<()>
        term<int (*)(int)>[=1] &
        term<int>[=1]
    term<int>[=2]
result = 4

可以看到,yap::expression打印出来类似一个AST,其中有操作符‘+’和'()'(函数调用),还有操作数(即终结符):'int (*)(int)'类型的函数foo,以及int类型的数值1和2。

本例中expr打印后立即被evaluate,其实也可以不进行evaluate,可以对这个表达式进行自由的变换。

  • yap::expression的原理

其实很简单,就是记录了操作符和操作数列表的一个struct

template <boost::yap::expr_kind Kind, typename Tuple>
struct minimal_expr
{
    static const boost::yap::expr_kind kind = Kind;

    Tuple elements;
};

其中expr_kind的定义:

 21     enum class expr_kind {
 22         expr_ref =
 23             0, ///< A (possibly \c const) reference to another expression.
 24 
 25         terminal = 1, ///< A terminal expression.
 26 
 27         // unary
 28         unary_plus = 2,  ///< \c +
 29         negate = 3,      ///< \c -
 30         dereference = 4, ///< \c *
            ...
 38 
 39         // binary
 40         shift_left = 12,         ///< \c <<
 41         shift_right = 13,        ///< \c >>
 42         multiplies = 14,         ///< \c *
 43         divides = 15,            ///< \c /
            ...

 76         // n-ary
 77         call = 45 ///< \c ()
 78     };

Yap examples

  • Pipable algorithms (https://www.boost.org/doc/libs/1_71_0/doc/html/boost_yap/manual.html#boost_yap.manual.examples.pipable_algorithms)

    std::vector<int> v1 = {0, 2, 2, 7, 1, 3, 8};
    std::vector<int> const v2 = sort(v1) | unique;
    assert(v2 == std::vector<int>({0, 1, 2, 3, 7, 8}));

    注意这里的sort不是std::sort,是yap封装后的sort(只是一个terminal)

  • Phoenix-style let (https://www.boost.org/doc/libs/1_71_0/doc/html/boost_yap/manual.html#boost_yap.manual.examples.boost_phoenix_style__let___)

        {
            auto expr = let(_a = 2)[_a + 1];
            assert(boost::yap::evaluate(expr) == 3);
        }
    
        {
            auto expr = let(_a = 123, _b = 456)[_a + _b];
            assert(boost::yap::evaluate(expr) == 123 + 456);
        }
    
        // Prints "Hello, World" due to let()'s scoping rules.
        {
            boost::yap::evaluate(
                let(_a = 1_p, _b = 2_p)
                [
                    // _a here is an int: 1
    
                    let(_a = 3_p) // hides the outer _a
                    [
                        cout << _a << _b // prints "Hello, World"
                    ]
                ],
                1, " World", "Hello,"
            );
        }
    
        std::cout << "\n";
    

    其中trick是:let是个元函数,()里是let的参数,接收一系列assign形式的表达式。let返回一个对象,它重载了[]运算符,operator 接受yap表达式做参数,并对这些表达式做yap::transform或yap::evaluate。

    // Takes N > 0 expressions of the form 'placeholder = expr', and returns an
    // object with an overloaded operator[]().
    template<typename Expr, typename... Exprs>
    auto let(Expr && expr, Exprs &&... exprs)
    {
        return let_impl(
            boost::hana::make_map(),
            std::forward<Expr>(expr),
            std::forward<Exprs>(exprs)...);
    }

编程实践 - Tensor表达式

TEST(TestExprCompiler, Test4) {
    auto format1 = make_format(Dim2(2, 4), RowMajorLayout());
    auto tensor1 = Tensor(float(), format1, MemSpace::GM(), 0x10);
    auto tensor2 = Tensor(float(), format1, MemSpace::GM(), 0x20);
    auto tensor3 = Tensor(float(), format1, MemSpace::GM(), 0x30);

    auto add_mul = [](auto &&... args) {
        using namespace boost::yap::literals;
        auto expr = 1_p + 2_p * 3_p;
        auto kernel = ECompiler(expr).compile(args...);
        launch(kernel);
    };

    add_mul(tensor1, tensor2, tensor3);
}

其中compile()函数里会对tensor1 + tensor2 * tensor3这个表达式做变换和code_gen(),最终会生成乘法和加法的指令(调用intrinsic),效果等同于:

// 为中间结果申请临时变量
auto temp1 = Tensor(float(), format1, MemSpace::GM(), 0x40);
auto temp2 = Tensor(float(), format1, MemSpace::GM(), 0x50);
tmul(temp1, tensor2, tensor3);
tadd(temp2, tensor1, temp1);

注意,是编译期的变换。