Skip to content

Commit

Permalink
Add clone_buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 11, 2024
1 parent 66b350b commit 350bb69
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 4 deletions.
15 changes: 15 additions & 0 deletions src/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,21 @@ class evaluator : public node_visitor {
visit(op->body);
}

void visit(const clone_buffer* op) override {
raw_buffer* src = reinterpret_cast<raw_buffer*>(*context.lookup(op->sym));
char* storage = reinterpret_cast<char*>(alloca(sizeof(raw_buffer) + sizeof(dim) * src->rank));

raw_buffer* buffer = reinterpret_cast<raw_buffer*>(&storage[0]);
buffer->dims = reinterpret_cast<dim*>(&storage[sizeof(raw_buffer)]);
buffer->elem_size = src->elem_size;
buffer->base = src->base;
buffer->rank = src->rank;
memcpy(buffer->dims, src->dims, sizeof(dim) * src->rank);

auto set_buffer = set_value_in_scope(context, op->sym, reinterpret_cast<index_t>(buffer));
visit(op->body);
}

void visit(const crop_buffer* op) override {
raw_buffer* buffer = reinterpret_cast<raw_buffer*>(*context.lookup(op->sym));
assert(buffer);
Expand Down
8 changes: 8 additions & 0 deletions src/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,14 @@ stmt make_buffer::make(symbol_id sym, expr base, expr elem_size, std::vector<dim
return n;
}

stmt clone_buffer::make(symbol_id sym, symbol_id src, stmt body) {
auto n = new clone_buffer();
n->sym = sym;
n->src = src;
n->body = std::move(body);
return n;
}

stmt crop_buffer::make(symbol_id sym, std::vector<interval_expr> bounds, stmt body) {
auto n = new crop_buffer();
n->sym = sym;
Expand Down
22 changes: 22 additions & 0 deletions src/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ enum class node_type {
if_then_else,
allocate,
make_buffer,
clone_buffer,
crop_buffer,
crop_dim,
slice_buffer,
Expand Down Expand Up @@ -607,6 +608,22 @@ class make_buffer : public stmt_node<make_buffer> {
static constexpr node_type static_type = node_type::make_buffer;
};

// Makes a clone of an existing buffer.
// TODO: This basically only exists because we cannot use `make_buffer` to clone a buffer of unknown rank. Maybe there's
// a better way to do this.
class clone_buffer : public stmt_node<clone_buffer> {
public:
symbol_id sym;
symbol_id src;
stmt body;

void accept(node_visitor* v) const;

static stmt make(symbol_id sym, symbol_id src, stmt body);

static constexpr node_type static_type = node_type::clone_buffer;
};

// For the `body` scope, crops the buffer `sym` to `bounds`. If the expressions in `bounds` are undefined, they default
// to their original values in the existing buffer. The rank of the buffer is unchanged. If the size of `bounds` is less
// than the rank, the missing values are considered undefined.
Expand Down Expand Up @@ -730,6 +747,7 @@ class node_visitor {
virtual void visit(const copy_stmt*) = 0;
virtual void visit(const allocate*) = 0;
virtual void visit(const make_buffer*) = 0;
virtual void visit(const clone_buffer*) = 0;
virtual void visit(const crop_buffer*) = 0;
virtual void visit(const crop_dim*) = 0;
virtual void visit(const slice_buffer*) = 0;
Expand Down Expand Up @@ -824,6 +842,9 @@ class recursive_node_visitor : public node_visitor {
}
op->body.accept(this);
}
virtual void visit(const clone_buffer* op) override {
op->body.accept(this);
}
virtual void visit(const crop_buffer* op) override {
for (const interval_expr& i : op->bounds) {
if (i.min.defined()) i.min.accept(this);
Expand Down Expand Up @@ -879,6 +900,7 @@ inline void call_stmt::accept(node_visitor* v) const { v->visit(this); }
inline void copy_stmt::accept(node_visitor* v) const { v->visit(this); }
inline void allocate::accept(node_visitor* v) const { v->visit(this); }
inline void make_buffer::accept(node_visitor* v) const { v->visit(this); }
inline void clone_buffer::accept(node_visitor* v) const { v->visit(this); }
inline void crop_buffer::accept(node_visitor* v) const { v->visit(this); }
inline void crop_dim::accept(node_visitor* v) const { v->visit(this); }
inline void slice_buffer::accept(node_visitor* v) const { v->visit(this); }
Expand Down
5 changes: 5 additions & 0 deletions src/infer_bounds.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,11 @@ class slide_and_fold_storage : public node_mutator {
void visit(const truncate_rank*) override { std::abort(); }

void visit(const loop* op) override {
if (op->mode == loop_mode::parallel) {
// Don't try sliding window or storage folding on parallel loops.
node_mutator::visit(op);
return;
}
var orig_min(ctx, ctx.name(op->sym) + "_min.orig");

loops.emplace_back(op->sym, orig_min, bounds(orig_min, op->bounds.max), op->step);
Expand Down
11 changes: 11 additions & 0 deletions src/node_mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ stmt clone_with_new_body(const allocate* op, stmt new_body) {
stmt clone_with_new_body(const make_buffer* op, stmt new_body) {
return make_buffer::make(op->sym, op->base, op->elem_size, op->dims, std::move(new_body));
}
stmt clone_with_new_body(const clone_buffer* op, stmt new_body) {
return clone_buffer::make(op->sym, op->src, std::move(new_body));
}
stmt clone_with_new_body(const crop_buffer* op, stmt new_body) {
return crop_buffer::make(op->sym, op->bounds, std::move(new_body));
}
Expand Down Expand Up @@ -187,6 +190,14 @@ void node_mutator::visit(const make_buffer* op) {
set_result(make_buffer::make(op->sym, std::move(base), std::move(elem_size), std::move(dims), std::move(body)));
}
}
void node_mutator::visit(const clone_buffer* op) {
stmt body = mutate(op->body);
if (body.same_as(op->body)) {
set_result(op);
} else {
set_result(clone_buffer::make(op->sym, op->src, std::move(body)));
}
}
void node_mutator::visit(const crop_buffer* op) {
std::vector<interval_expr> bounds;
bounds.reserve(op->bounds.size());
Expand Down
2 changes: 2 additions & 0 deletions src/node_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class node_mutator : public node_visitor {
virtual void visit(const copy_stmt*) override;
virtual void visit(const allocate*) override;
virtual void visit(const make_buffer*) override;
virtual void visit(const clone_buffer*) override;
virtual void visit(const crop_buffer*) override;
virtual void visit(const crop_dim*) override;
virtual void visit(const slice_buffer*) override;
Expand All @@ -82,6 +83,7 @@ class node_mutator : public node_visitor {
stmt clone_with_new_body(const let_stmt* op, stmt new_body);
stmt clone_with_new_body(const allocate* op, stmt new_body);
stmt clone_with_new_body(const make_buffer* op, stmt new_body);
stmt clone_with_new_body(const clone_buffer* op, stmt new_body);
stmt clone_with_new_body(const crop_buffer* op, stmt new_body);
stmt clone_with_new_body(const crop_dim* op, stmt new_body);
stmt clone_with_new_body(const slice_buffer* op, stmt new_body);
Expand Down
1 change: 1 addition & 0 deletions src/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ class scope_reducer : public node_mutator {
void visit(const let_stmt* op) override { visit_stmt(op); }
void visit(const allocate* op) override { visit_stmt(op); }
void visit(const make_buffer* op) override { visit_stmt(op); }
void visit(const clone_buffer* op) override { visit_stmt(op); }
void visit(const crop_buffer* op) override { visit_stmt(op); }
void visit(const crop_dim* op) override { visit_stmt(op); }
void visit(const slice_buffer* op) override { visit_stmt(op); }
Expand Down
6 changes: 6 additions & 0 deletions src/print.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,12 @@ class printer : public node_visitor {
*this << indent() << "}\n";
}

void visit(const clone_buffer* n) override {
*this << indent() << n->sym << " = clone_buffer(" << n->src << ") {\n";
*this << n->body;
*this << indent() << "}\n";
}

void visit(const crop_buffer* n) override {
*this << indent() << "crop_buffer(" << n->sym << ", {";
if (!n->bounds.empty()) {
Expand Down
17 changes: 13 additions & 4 deletions src/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class commute_variants : public node_visitor {
void visit(const copy_stmt* op) override { std::abort(); }
void visit(const allocate* op) override { std::abort(); }
void visit(const make_buffer* op) override { std::abort(); }
void visit(const clone_buffer* op) override { std::abort(); }
void visit(const crop_buffer* op) override { std::abort(); }
void visit(const crop_dim* op) override { std::abort(); }
void visit(const slice_buffer* op) override { std::abort(); }
Expand Down Expand Up @@ -273,6 +274,10 @@ bool is_buffer_mutated(symbol_id sym, const stmt& s) {
if (op->sym == sym) return;
recursive_node_visitor::visit(op);
}
void visit(const clone_buffer* op) override {
if (op->sym == sym) return;
recursive_node_visitor::visit(op);
}
void visit(const let_stmt* op) override {
if (op->sym == sym) return;
recursive_node_visitor::visit(op);
Expand Down Expand Up @@ -1413,13 +1418,17 @@ class simplifier : public node_mutator {
if (*src_buf == op->sym) {
set_result(mutate(truncate_rank::make(op->sym, dims.size(), std::move(body))));
return;
} else if (!is_buffer_mutated(op->sym, body) && !is_buffer_mutated(*src_buf, body)) {
const std::optional<box_expr>& src_bounds = buffer_bounds[*src_buf];
if (src_bounds && src_bounds->size() == dims.size()) {
}
const std::optional<box_expr>& src_bounds = buffer_bounds[*src_buf];
if (src_bounds && src_bounds->size() == dims.size()) {
if (!is_buffer_mutated(op->sym, body) && !is_buffer_mutated(*src_buf, body)) {
// This is a clone of src_buf, and we never mutate either buffer, we can just re-use it.
set_result(let_stmt::make(op->sym, buf, std::move(body)));
return;
} else {
// This is a clone of src_buf, but we've mutated one of them. Use clone_buffer instead.
set_result(clone_buffer::make(op->sym, *src_buf, std::move(body)));
}
return;
}
}
}
Expand Down
19 changes: 19 additions & 0 deletions src/substitute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ class matcher : public node_visitor {
if (!try_match(mbs->body, op->body)) return;
}

void visit(const clone_buffer* op) override {
if (match) return;
const clone_buffer* mbs = match_self_as(op);
if (!mbs) return;

if (!try_match(mbs->sym, op->sym)) return;
if (!try_match(mbs->src, op->src)) return;
if (!try_match(mbs->body, op->body)) return;
}

void visit(const crop_buffer* op) override {
if (match) return;
const crop_buffer* cbs = match_self_as(op);
Expand Down Expand Up @@ -517,6 +527,15 @@ class substitutor : public node_mutator {
set_result(make_buffer::make(op->sym, std::move(base), std::move(elem_size), std::move(dims), std::move(body)));
}
}
void visit(const clone_buffer* op) override {
auto s = set_value_in_scope(shadowed, op->sym, true);
stmt body = mutate_decl_body(op->sym, op->body);
if (body.same_as(op->body)) {
set_result(op);
} else {
set_result(clone_buffer::make(op->sym, op->src, std::move(body)));
}
}
void visit(const slice_buffer* op) override {
std::vector<expr> at;
at.reserve(op->at.size());
Expand Down
2 changes: 2 additions & 0 deletions test/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class elementwise_pipeline_builder : public node_visitor {
void visit(const copy_stmt*) override { std::abort(); }
void visit(const allocate*) override { std::abort(); }
void visit(const make_buffer*) override { std::abort(); }
void visit(const clone_buffer*) override { std::abort(); }
void visit(const crop_buffer*) override { std::abort(); }
void visit(const crop_dim*) override { std::abort(); }
void visit(const slice_buffer*) override { std::abort(); }
Expand Down Expand Up @@ -226,6 +227,7 @@ class elementwise_pipeline_evaluator : public node_visitor {
void visit(const copy_stmt*) override { std::abort(); }
void visit(const allocate*) override { std::abort(); }
void visit(const make_buffer*) override { std::abort(); }
void visit(const clone_buffer*) override { std::abort(); }
void visit(const crop_buffer*) override { std::abort(); }
void visit(const crop_dim*) override { std::abort(); }
void visit(const slice_buffer*) override { std::abort(); }
Expand Down

0 comments on commit 350bb69

Please sign in to comment.