Skip to content

Commit

Permalink
Skip the visitor pattern for a few heavily visited node types
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Dec 31, 2023
1 parent 351fd54 commit 67b3182
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
4 changes: 2 additions & 2 deletions apps/performance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ pipeline make_pipeline(bool explicit_y) {
auto in = buffer_expr::make(ctx, "in", sizeof(char), 2);
auto out = buffer_expr::make(ctx, "out", sizeof(char), 2);

expr x = make_variable(ctx, "x");
expr y = make_variable(ctx, "y");
var x(ctx, "x");
var y(ctx, "y");

func copy = func::make<const char, char>(::copy<char>, {in, {point(x), point(y)}}, {out, {x, y}});

Expand Down
11 changes: 10 additions & 1 deletion src/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,18 @@ class evaluator : public node_visitor {

evaluator(eval_context& context) : context(context) {}

// Skip the visitor pattern (two virtual function calls) for some frequently used node types.
void visit(const expr& x) {
switch (x.type()) {
case node_type::variable: visit(reinterpret_cast<const variable*>(x.get())); return;
case node_type::constant: visit(reinterpret_cast<const constant*>(x.get())); return;
default: x.accept(this);
}
}

// Assume `e` is defined, evaluate it and return the result.
index_t eval_expr(const expr& e) {
e.accept(this);
visit(e);
index_t r = result;
result = 0;
return r;
Expand Down
20 changes: 13 additions & 7 deletions src/substitute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ class matcher : public node_visitor {
return match == 0;
}

// Skip the visitor pattern (two virtual function calls) for a few node types that are very frequently visited.
void visit(const expr& x) {
switch (x.type()) {
case node_type::add: visit(reinterpret_cast<const add*>(x.get())); return;
case node_type::min: visit(reinterpret_cast<const class min*>(x.get())); return;
case node_type::max: visit(reinterpret_cast<const class max*>(x.get())); return;
default: x.accept(this);
}
}

bool try_match(const expr& e, const expr& x) {
if (!e.defined() && !x.defined()) {
match = 0;
Expand All @@ -48,7 +58,7 @@ class matcher : public node_visitor {
match = 1;
} else {
self = e.get();
x.accept(this);
visit(x);
}
return match == 0;
}
Expand Down Expand Up @@ -408,12 +418,8 @@ class substitutor : public node_mutator {

} // namespace

expr substitute(const expr& e, const symbol_map<expr>& replacements) {
return substitutor(replacements).mutate(e);
}
stmt substitute(const stmt& s, const symbol_map<expr>& replacements) {
return substitutor(replacements).mutate(s);
}
expr substitute(const expr& e, const symbol_map<expr>& replacements) { return substitutor(replacements).mutate(e); }
stmt substitute(const stmt& s, const symbol_map<expr>& replacements) { return substitutor(replacements).mutate(s); }

expr substitute(const expr& e, symbol_id target, const expr& replacement) {
return substitutor(target, replacement).mutate(e);
Expand Down

0 comments on commit 67b3182

Please sign in to comment.