Skip to content

Commit

Permalink
Small optimizations (#549)
Browse files Browse the repository at this point in the history
- Add a proper "unreachable" code function/macro
- Optimize `init_strides` by maintaining a sorted list of dimensions
- Compute buffer size in `init_strides`, so allocation uses don't need
to call both functions
- Add some guesses to likely nodes types in evaluate
- Lift growing the context "memory" out of loops
- Tweak inlining of evaluate implementations
- Remove checks later, so the second `simplify` can see them
- Add `canonicalize_nodes`
- Tweaks to `for_each_element`
  • Loading branch information
dsharlet authored Jan 13, 2025
1 parent 1a4fcba commit f7f6b58
Show file tree
Hide file tree
Showing 21 changed files with 443 additions and 231 deletions.
43 changes: 43 additions & 0 deletions base/util.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef SLINKY_BASE_UTIL_H
#define SLINKY_BASE_UTIL_H

#include <iostream>

namespace slinky {

// Some functions are templates that are usually unique specializations, which are beneficial to inline. The compiler
Expand All @@ -23,6 +25,12 @@ namespace slinky {
#define SLINKY_TRIVIAL_ABI
#endif

#if SLINKY_HAS_ATTRIBUTE(pure)
#define SLINKY_PURE __attribute__((pure))
#else
#define SLINKY_PURE
#endif

#ifdef NDEBUG
// alloca() will cause stack-smashing code to be inserted;
// while laudable, we use alloca() in time-critical code
Expand All @@ -32,6 +40,41 @@ namespace slinky {
#define SLINKY_NO_STACK_PROTECTOR /* nothing */
#endif

#if defined(__GNUC__)
#define SLINKY_LIKELY(condition) (__builtin_expect(!!(condition), 1))
#define SLINKY_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
#else
#define SLINKY_LIKELY(condition) (!!(condition))
#define SLINKY_UNLIKELY(condition) (!!(condition))
#endif

class unreachable {
public:
unreachable() = default;
[[noreturn]] ~unreachable() {
#ifndef NDEBUG
std::abort();
#else
// https://en.cppreference.com/w/cpp/utility/unreachable
#if defined(_MSC_VER) && !defined(__clang__)
__assume(false);
#else
__builtin_unreachable();
#endif
#endif
}

template <typename T>
unreachable& operator<<(const T& x) {
#ifndef NDEBUG
std::cerr << x;
#endif
return *this;
}
};

#define SLINKY_UNREACHABLE unreachable() << "unreachable executed at " << __FILE__ << ", " << __LINE__ << ": "

} // namespace slinky

#endif // SLINKY_BASE_UTIL_H
31 changes: 30 additions & 1 deletion builder/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ class copy_aliaser : public stmt_mutator {

void visit(const transpose*) override {
// TODO: We should be able to handle this.
std::abort();
SLINKY_UNREACHABLE << "transpose not handled by buffer_aliaser";
}
};

Expand Down Expand Up @@ -1061,4 +1061,33 @@ stmt optimize_symbols(const stmt& s, node_context& ctx) {
return reuse_shadows().mutate(s);
}

namespace {

class node_canonicalizer : public node_mutator {
std::map<expr, expr, node_less> exprs;
std::map<stmt, stmt, node_less> stmts;

public:
using node_mutator::mutate;

stmt mutate(const stmt& s) override {
stmt& result = stmts[s];
if (!result.defined()) result = node_mutator::mutate(s);
return result;
}

expr mutate(const expr& e) override {
expr& result = exprs[e];
if (!result.defined()) result = node_mutator::mutate(e);
return result;
}
};

} // namespace

stmt canonicalize_nodes(const stmt& s) {
scoped_trace trace("canonicalize_nodes");
return node_canonicalizer().mutate(s);
}

} // namespace slinky
4 changes: 4 additions & 0 deletions builder/optimizations.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ stmt deshadow(const stmt& s, span<var> external_symbols, node_context& ctx);
// - Symbols are indexed such that there are no unused symbol indices.
stmt optimize_symbols(const stmt& s, node_context& ctx);

// Guarantees that if match(a, b) is true, then a.same_as(b) is true, i.e. it rewrites matching nodes to be the same
// object.
stmt canonicalize_nodes(const stmt& s);

} // namespace slinky

#endif // SLINKY_BUILDER_OPTIMIZATIONS_H
12 changes: 8 additions & 4 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1020,14 +1020,16 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
// `evaluate` currently can't handle `copy_stmt`, so this is required.
result = implement_copies(result, ctx);

if (options.no_checks) {
result = recursive_mutate<check>(result, [](const check* op) { return stmt(); });
}

// `implement_copies` adds shadowed declarations, remove them before simplifying.
result = deshadow(result, builder.external_symbols(), ctx);
result = simplify(result);

if (options.no_checks) {
result = recursive_mutate<check>(result, [](const check* op) { return stmt(); });
// Simplify again, in case there are lets that the checks used that are now dead.
result = simplify(result);
}

result = optimize_symbols(result, ctx);

result = insert_early_free(result);
Expand All @@ -1036,6 +1038,8 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
result = inject_traces(result, ctx, constants);
}

result = canonicalize_nodes(result);

if (is_verbose()) {
std::cout << result << std::endl;
}
Expand Down
10 changes: 3 additions & 7 deletions builder/replica_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,6 @@ class pipeline_replicator : public expr_visitor {
public:
explicit pipeline_replicator(node_context& ctx) : ctx_(ctx) {}

void fail(const char* msg) {
std::cerr << "Unimplemented/TODO: " << msg << "\n";
std::abort();
}

void visit(const variable* op) override {
const std::string& name = ctx_.name(op->sym);

Expand All @@ -82,7 +77,7 @@ class pipeline_replicator : public expr_visitor {
}

void visit(const constant* op) override { name_ = to_string(op->value); }
void visit(const let* op) override { fail("unimplemented let"); }
void visit(const let* op) override { SLINKY_UNREACHABLE; }
void visit(const add* op) override { visit_binary_op(op, "+"); }
void visit(const sub* op) override { visit_binary_op(op, "-"); }
void visit(const mul* op) override { visit_binary_op(op, "*"); }
Expand Down Expand Up @@ -616,7 +611,8 @@ struct rph_handler {
case 0x82: DO_XOR(uint64_t, uint16_t); break;
case 0x84: DO_XOR(uint64_t, uint32_t); break;
case 0x88: DO_XOR(uint64_t, uint64_t); break;
default: std::cerr << "Unsupported elem_size combination\n"; std::abort();

default: SLINKY_UNREACHABLE << "Unsupported elem_size combination";
}

#undef DO_XOR
Expand Down
4 changes: 2 additions & 2 deletions builder/rewrite.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ std::ostream& operator<<(std::ostream& os, const pattern_binary<T, A, B>& p) {
case not_equal::static_type: return os << '(' << p.a << " != " << p.b << ')';
case logical_and::static_type: return os << '(' << p.a << " && " << p.b << ')';
case logical_or::static_type: return os << '(' << p.a << " || " << p.b << ')';
default: std::abort();
default: SLINKY_UNREACHABLE << "unknown binary operator " << to_string(T::static_type);
}
}

Expand Down Expand Up @@ -279,7 +279,7 @@ template <typename T, typename A>
SLINKY_UNIQUE std::ostream& operator<<(std::ostream& os, const pattern_unary<T, A>& p) {
switch (T::static_type) {
case logical_not::static_type: return os << '!' << p.a;
default: std::abort();
default: SLINKY_UNREACHABLE << "unknown unary operator " << to_string(T::static_type);
}
}

Expand Down
2 changes: 1 addition & 1 deletion builder/slide_and_fold_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class slide_and_fold : public stmt_mutator {
set_result(clone_with(op, std::move(body)));
}
}
void visit(const transpose*) override { std::abort(); }
void visit(const transpose*) override { SLINKY_UNREACHABLE << "transpose not handled by slide_and_fold_storage"; }
void visit(const clone_buffer* op) override {
auto set_alias = set_value_in_scope(aliases, op->sym, op->src);
stmt_mutator::visit(op);
Expand Down
7 changes: 3 additions & 4 deletions builder/substitute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,7 @@ class buffer_substitutor : public substitutor {
var enter_decl(var x) override { return x != target ? x : var(); }

stmt mutate(const stmt& s) override {
// We don't support substituting buffers into stmts.
std::abort();
SLINKY_UNREACHABLE << "can't substitute buffer into stmt";
}
dim_expr mutate(const dim_expr& e) { return {mutate(e.bounds), mutate(e.stride), mutate(e.fold_factor)}; }
using substitutor::mutate;
Expand All @@ -733,7 +732,7 @@ class buffer_substitutor : public substitutor {
case buffer_field::stride:
case buffer_field::fold_factor:
return dim < static_cast<index_t>(dims.size()) ? dims[dim].get_field(field) : expr(op);
case buffer_field::none: std::abort();
default: SLINKY_UNREACHABLE << "got scalar var instead of buffer";
}
return expr(op);
}
Expand Down Expand Up @@ -800,7 +799,7 @@ class expr_substitutor : public node_mutator {
}
stmt mutate(const stmt& op) override {
// We don't support substituting exprs into stmts.
std::abort();
SLINKY_UNREACHABLE << "can't substitute expr into stmt";
}
using node_mutator::mutate;
};
Expand Down
12 changes: 6 additions & 6 deletions builder/test/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ class elementwise_pipeline_builder : public expr_visitor {
result_funcs.push_back(std::move(r));
}

void visit(const let*) override { std::abort(); }
void visit(const call*) override { std::abort(); }
void visit(const logical_not*) override { std::abort(); }
void visit(const let*) override { SLINKY_UNREACHABLE; }
void visit(const call*) override { SLINKY_UNREACHABLE; }
void visit(const logical_not*) override { SLINKY_UNREACHABLE; }
};

template <typename T, std::size_t Rank>
Expand Down Expand Up @@ -203,9 +203,9 @@ class elementwise_pipeline_evaluator : public expr_visitor {
for_each_element([&](T* result, const T* c, const T* t) { *result = *c ? *t : *result; }, result, c_buf, t_buf);
}

void visit(const let*) override { std::abort(); }
void visit(const call*) override { std::abort(); }
void visit(const logical_not*) override { std::abort(); }
void visit(const let*) override { SLINKY_UNREACHABLE; }
void visit(const call*) override { SLINKY_UNREACHABLE; }
void visit(const logical_not*) override { SLINKY_UNREACHABLE; }
};

template <typename T, std::size_t Rank>
Expand Down
4 changes: 2 additions & 2 deletions builder/test/simplify/expr_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class expr_generator {
case 4: return rng_() % 8 != 0 ? ac() && bc() : and_then(ac(), bc());
case 5: return rng_() % 8 != 0 ? ac() || bc() : or_else(ac(), bc());
case 6: return !random_condition(depth - 1);
default: std::abort();
default: SLINKY_UNREACHABLE;
}
}

Expand All @@ -88,7 +88,7 @@ class expr_generator {
case 8: return random_constant();
case 9: return random_variable();
case 10: return random_condition(depth);
default: std::abort();
default: SLINKY_UNREACHABLE;
}
}
}
Expand Down
Loading

0 comments on commit f7f6b58

Please sign in to comment.