From 219e68d3d7a795661a6949c114479dc60bd1ce5f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 1 Jul 2021 13:36:52 -0400 Subject: [PATCH] [Arith] Inverse affine map (#402) * [Arith] Inverse affine map * Update iter_affine_map.h * Update iter_affine_map.h * Update iter_affine_map.py * Topology order visit * doc * fix * address comments --- include/tvm/arith/iter_affine_map.h | 21 +++ python/tvm/arith/__init__.py | 3 +- python/tvm/arith/iter_affine_map.py | 27 ++++ src/arith/iter_affine_map.cc | 142 ++++++++++++++++++ .../unittest/test_arith_iter_affine_map.py | 62 ++++++++ 5 files changed, 254 insertions(+), 1 deletion(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 9a3f08487d..203f4224ac 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -283,6 +283,27 @@ Array DetectIterMap(const Array& indices, const Map InverseAffineIterMap(const Array& iter_map, + const Array outputs); + /*! * \brief Use IterVarMap detector to rewrite and simplify the bindings * diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index a4cdb9839b..64318eee3c 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,5 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide +from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide, \ + inverse_affine_iter_map diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 1d020288d0..3dbbfa7628 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi Empty 
array if no match can be found. """ return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + + +def inverse_affine_iter_map(iter_map, outputs): + """ Apply the inverse of the affine transformation to the outputs. + Similar to back-propagation, starting from the outputs, it visits the DAG of the expressions + in reverse topological order and applies the inverse of the affine transformation until it + reaches the input. The affine iter map is required to be bijective. + + For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1], + the inverse of the affine transformation specified by `iter_map` will be applied to `outputs` + and the result will be {l0: ((output_0*16) + output_1)}. + + See also :any:`detect_iter_map`. + + Parameters + ---------- + iter_map : List[IterSumExpr] + The bijective affine iter map. + outputs : List[PrimExpr] + The outputs of the affine transformation. + + Returns + ------- + results : Map[Var, PrimExpr] + The map from the input to the transformed result. 
+ """ + return _ffi_api.InverseAffineIterMap(iter_map, outputs) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index a82e19b8f6..bea078e57b 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1403,5 +1403,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); }); +class InverseAffineIterMapTransformer { + public: + explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} + + Map operator()(const Array& iter_map, + const Array& outputs) { + ICHECK(iter_map.size() == outputs.size()); + std::vector post_dfs_order = ReverseTopologyOrder(iter_map); + + // initialize back propagation accumulator + for (const IterMapExprNode* node: post_dfs_order) { + backprop_.Set(GetRef(node), Integer(0)); + } + for (size_t i = 0; i < iter_map.size(); i++) { + backprop_.Set(iter_map[i], outputs[i]); + } + + // run back propagation + for (const IterMapExprNode* node: post_dfs_order) { + if (node->IsInstance()) { + Visit_(Downcast(GetRef(node))); + } else { + ICHECK(node->IsInstance()); + Visit_(Downcast(GetRef(node))); + } + } + return std::move(inverse_); + } + + private: + void Visit_(const IterSumExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base; + + // Case 1: Propagate to the input node directly when the sum expression has only one component + if (iter_map_expr->args.size() == 1) { + const auto& source = iter_map_expr->args[0]; + backprop_.Set(source, backprop_.at(source) + input); + return; + } + + // Case 2: If the sum expression has multiple components, match the fuse pattern and then split + // the sum expression for each component. + // For example, consider the iterators i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2 + // we will have i1_i2_fused[dom = (0, 128)]. 
During back propagation, we need to split the + // propagated value to get the corresponding components of i1 and i2, which are + // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively. + Array splits = MatchFusePattern(iter_map_expr); + ICHECK(!splits.empty()); + + for (const IterSplitExpr& split : splits) { + backprop_.Set(split, backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); + } + } + + + std::vector ReverseTopologyOrder(const Array& iter_map) { + std::vector post_dfs_order; + std::unordered_map visited; + + std::function fvisit = [&](const IterMapExpr& expr) { + if (visited[expr]) { + return; + } + visited[expr] = true; + if (const auto* sum_expr = expr.as()) { + for (const IterSplitExpr& child : sum_expr->args) { + fvisit(child); + } + } else { + const auto* split_expr = expr.as(); + ICHECK(split_expr); + if (const auto* source = split_expr->source->source.as()) { + fvisit(GetRef(source)); + } + } + post_dfs_order.push_back(expr.get()); + }; + for (const IterSumExpr& expr : iter_map) { + fvisit(expr); + } + std::reverse(post_dfs_order.begin(), post_dfs_order.end()); + return post_dfs_order; + } + + void Visit_(const IterSplitExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor; + const IterMark& source = iter_map_expr->source; + if (source->source.as()) { + IterSumExpr source_expr = Downcast(source->source); + backprop_.Set(source_expr, backprop_.at(source_expr) + input); + } else { + Var source_var = Downcast(source->source); + if (inverse_.count(source_var)) { + inverse_.Set(source_var, inverse_.at(source_var) + input); + } else { + inverse_.Set(source_var, input); + } + } + } + + Array MatchFusePattern(const IterSumExpr sum_expr) { + IntImm base_scale(nullptr); + size_t base_index = 0; + for (size_t i = 0; i < sum_expr->args.size(); ++i) { + if (const auto* op = sum_expr->args[i]->scale.as()) { + if (!base_scale.defined() || op->value < base_scale->value) { + 
+ base_scale = GetRef(op); + base_index = i; + } + } + } + ICHECK(base_scale.defined()); + std::vector iters; + std::vector visited(sum_expr->args.size(), false); + PrimExpr expected_scale = base_scale; + for (size_t i = 0; i < sum_expr->args.size(); i++) { + size_t j = i == 0 ? base_index : 0; + for (; j < sum_expr->args.size(); ++j) { + if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale)) + break; + } + ICHECK(j != sum_expr->args.size()); + visited[j] = true; + iters.push_back(sum_expr->args[j]); + expected_scale *= sum_expr->args[j]->extent; + } + return iters; + } + + Analyzer* analyzer_; + Map backprop_; // the accumulator of backpropagation + Map inverse_; // the result of the inverse transformation +}; + +Map InverseAffineIterMap(const Array& iter_map, + const Array outputs) { + Analyzer analyzer; + return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); +} + +TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 835140a6a5..bbf38b4bf3 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -18,6 +18,7 @@ import tvm.testing from tvm import te from tvm.tir import floormod, floordiv +from tvm.topi.image.resize import resize_bilinear def convert_division(divisions): @@ -716,6 +717,66 @@ def test_normalize_iter_map_to_expr(): tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) +def test_inverse_affine_iter_map(): + analyzer = tvm.arith.Analyzer() + l0 = create_iter("l0", 64) + l1 = create_iter("l1", 64) + l2 = create_iter("l2", 64) + + # simple case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l0_1_l1_1_fused = ifuse([l0_1, l1_1]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], 
l1_0[0]], + var_dom([l0, l1])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + print(res) + assert len(res) == 2 + l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16 + l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + + # compound case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l2_1, l2_2 = isplit(l2, 4) + l2_0, l2_1 = isplit(l2_1, 4) + + l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], + var_dom([l0, l1, l2])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 3 + l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16 + l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 + l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 \ + + outputs[2] + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0 + + # diamond-shape DAG + l0_0, l0_1 = isplit(l0, 16) + l1 = ifuse([l0_1, l0_0]) + l1_0, l1_1 = isplit(l1, 8) + l2 = ifuse([l1_1, l1_0]) + + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 1 + l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8) + l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16) + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + + if __name__ == 
"__main__": test_split() test_trivial() @@ -725,3 +786,4 @@ def test_normalize_iter_map_to_expr(): test_normalize_iter_map_to_expr() test_subspace_division() test_complex() + test_inverse_affine_iter_map()