Revert "Remove pipeline parallelism (#555)" #556

Merged: 1 commit, Jan 17, 2025
17 changes: 16 additions & 1 deletion base/thread_pool.cc
@@ -35,9 +35,19 @@ void thread_pool_impl::run_worker(const predicate& condition) {
--worker_count_;
}

namespace {

thread_local std::vector<thread_pool::task_id> task_stack;

} // namespace

thread_pool::task_id thread_pool_impl::dequeue(task& t) {
for (auto i = task_queue_.begin(); i != task_queue_.end(); ++i) {
task_id id = std::get<2>(*i);
if (id != unique_task_id && std::find(task_stack.begin(), task_stack.end(), id) != task_stack.end()) {
// Don't run the same task multiple times on the same thread.
continue;
}
int& task_count = std::get<0>(*i);
if (task_count == 1) {
t = std::move(std::get<1>(*i));
@@ -61,9 +71,11 @@ void thread_pool_impl::wait_for(const thread_pool::predicate& condition, std::co
std::unique_lock l(mutex_);
while (!condition()) {
task t;
if (dequeue(t)) {
if (task_id id = dequeue(t)) {
l.unlock();
task_stack.push_back(id);
t();
task_stack.pop_back();
l.lock();
// Notify the helper CV; helpers might be waiting for a condition whose value the task changed.
cv_helper_.notify_all();
@@ -102,7 +114,10 @@ void thread_pool_impl::enqueue(task t, task_id id) {
}

void thread_pool_impl::run(const task& t, task_id id) {
assert(id == unique_task_id || std::find(task_stack.begin(), task_stack.end(), id) == task_stack.end());
task_stack.push_back(id);
t();
task_stack.pop_back();
}

void thread_pool_impl::cancel(task_id id) {
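The task_stack re-added above is a per-thread guard: a worker that is already running a task (and re-enters the scheduler, for example by blocking in wait_for) will skip that task's id in dequeue rather than running it again recursively. A minimal standalone sketch of the idea, using illustrative names rather than slinky's actual thread_pool types:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative sketch, not slinky code: each worker thread remembers which task
// ids it is currently executing.
using task_id = std::uintptr_t;
thread_local std::vector<task_id> task_stack;

// A worker may only pick up a queued task if it is not already running it on this
// thread; this mirrors the check added to thread_pool_impl::dequeue above.
bool can_run_here(task_id id) {
  return std::find(task_stack.begin(), task_stack.end(), id) == task_stack.end();
}

// Mirrors the push/pop added around t() in wait_for and run above.
template <typename Task>
void run_guarded(task_id id, const Task& t) {
  task_stack.push_back(id);
  t();  // if the task re-enters the scheduler, can_run_here(id) is now false on this thread
  task_stack.pop_back();
}
```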
16 changes: 15 additions & 1 deletion builder/simplify_exprs.cc
@@ -330,7 +330,21 @@ expr simplify(const call* op, intrinsic fn, std::vector<expr> args) {
changed = changed || !args[i].same_as(op->args[i]);
}

if (fn == intrinsic::buffer_at) {
if (fn == intrinsic::semaphore_init || fn == intrinsic::semaphore_wait || fn == intrinsic::semaphore_signal) {
assert(args.size() % 2 == 0);
for (std::size_t i = 0; i < args.size();) {
// Remove calls to undefined semaphores.
if (!args[i].defined()) {
args.erase(args.begin() + i, args.begin() + i + 2);
changed = true;
} else {
i += 2;
}
}
if (args.empty()) {
return expr();
}
} else if (fn == intrinsic::buffer_at) {
for (index_t d = 1; d < static_cast<index_t>(args.size()); ++d) {
auto buf = as_variable(args[0]);
assert(buf);
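The new simplification above drops (semaphore, count) argument pairs whose semaphore expression is undefined, and folds the whole call away if nothing is left. A self-contained sketch of the same pair-pruning idiom on a plain vector, assuming the even-size invariant asserted above:

```cpp
#include <optional>
#include <vector>

// Illustrative sketch: args holds (sem, count) pairs, so its size is even. Pairs
// whose "semaphore" is absent are erased two elements at a time; everything else
// is kept in order.
void prune_undefined_pairs(std::vector<std::optional<int>>& args) {
  for (std::size_t i = 0; i < args.size();) {
    if (!args[i].has_value()) {
      args.erase(args.begin() + i, args.begin() + i + 2);  // drop this (sem, count) pair
    } else {
      i += 2;  // keep this pair and move on to the next one
    }
  }
}
```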
138 changes: 133 additions & 5 deletions builder/slide_and_fold_storage.cc
@@ -195,12 +195,41 @@ class slide_and_fold : public stmt_mutator {
std::unique_ptr<symbol_map<modulus_remainder<index_t>>> expr_alignment =
std::make_unique<symbol_map<modulus_remainder<index_t>>>();

// The next few fields relate to implementing synchronization in pipelined loops. In a pipelined loop, we
// treat a sequence of stmts as "stages" in the pipeline, and add synchronization so the loop appears to
// execute serially from the perspective of each stage: a stage can assume that the same stage has
// completed for the previous iteration, and that all previous stages have completed for the same iteration.
var semaphores;
var worker_count;

// How many stages we've added synchronization for in total so far.
int sync_stages = 0;
// We only track the stage we're currently working on. This optional being present indicates the current stage needs
// synchronization, and the value indicates which stage it is.
std::optional<int> stage;

// Unique loop ID.
std::size_t loop_id = -1;

bool add_synchronization() {
if (sync_stages + 1 >= max_workers) {
// It's pointless to add more stages to the loop, because we can't run them in parallel anyway; it would
// just add more synchronization overhead.
return false;
}

// We need synchronization, but we might already have it.
if (!stage) {
stage = sync_stages++;
}
return true;
}

loop_info(node_context& ctx, var sym, std::size_t loop_id, expr orig_min, interval_expr bounds, expr step,
int max_workers)
: sym(sym), orig_min(orig_min), bounds(bounds), step(step), max_workers(max_workers), loop_id(loop_id) {}
: sym(sym), orig_min(orig_min), bounds(bounds), step(step), max_workers(max_workers),
semaphores(ctx, ctx.name(sym) + "_semaphores"), worker_count(ctx, ctx.name(sym) + "_worker_count"),
loop_id(loop_id) {}
};
std::vector<loop_info> loops;

@@ -217,6 +246,29 @@ class slide_and_fold : public stmt_mutator {
loops.emplace_back(ctx, var(), loop_counter++, expr(), interval_expr::none(), expr(), loop::serial);
}

stmt mutate(const stmt& s) override {
stmt result = stmt_mutator::mutate(s);

// The loop at the back of the loops vector is the immediately containing loop. So, we know there are no
// intervening loops, and we can add any synchronization that has been requested. Doing so completes the current
// pipeline stage.
loop_info& l = loops.back();
if (l.stage) {
result = block::make({
// Wait for the previous iteration of this stage to complete.
// The l.sym here is equal to l.min + x * l.step, so dividing l.sym by l.step gives floor_div(l.min, l.step) + x.
// This works even if l.min is not divisible by l.step, because that term remains constant w.r.t. the loop index.
check::make(semaphore_wait(buffer_at(l.semaphores, *l.stage, floor_div(expr(l.sym), l.step) - 1))),
result,
// Signal we've done this iteration.
check::make(semaphore_signal(buffer_at(l.semaphores, *l.stage, floor_div(expr(l.sym), l.step)))),
});
l.stage = std::nullopt;
}

return result;
}

void visit(const let_stmt* op) override {
auto& bounds = current_expr_bounds();
std::vector<scoped_value_in_symbol_map<interval_expr>> values;
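The wait/signal pair that mutate() wraps around each stage above implements the protocol described in the loop_info comment: stage s of iteration x may only start once stage s of iteration x - 1 has finished, so each stage runs serially across iterations while different stages of different iterations may overlap on different workers. A standalone sketch of that protocol with made-up stage and iteration counts, using one thread per iteration and busy-waited flags in place of slinky's semaphore intrinsics:

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

constexpr int stages = 2;
constexpr int iterations = 4;

// done[s][x + 1] is "signaled" when stage s finishes iteration x. Slot 0 stands in
// for iteration -1 and starts signaled, like the init_semaphores call added further
// down in this file, so the first iteration can proceed.
std::atomic<int> done[stages][iterations + 1];

void run_iteration(int x) {
  for (int s = 0; s < stages; ++s) {
    while (done[s][x].load() == 0) std::this_thread::yield();  // semaphore_wait on (s, x - 1)
    std::printf("iteration %d, stage %d\n", x, s);             // the stage body
    done[s][x + 1].store(1);                                   // semaphore_signal on (s, x)
  }
}

int main() {
  for (int s = 0; s < stages; ++s) done[s][0].store(1);
  std::vector<std::thread> workers;
  for (int x = 0; x < iterations; ++x) workers.emplace_back(run_iteration, x);
  for (std::thread& t : workers) t.join();
  return 0;
}
```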
@@ -321,9 +373,6 @@
fold_factor = simplify(constant_upper_bound(fold_factor), *loop.expr_bounds, *loop.expr_alignment);
if (is_finite(fold_factor) && !depends_on(fold_factor, loop.sym).any()) {
vector_at(fold_factors[output], d) = {fold_factor, fold_factor, loops.back().loop_id};

// This loop has a dependency between loop iterations, mark it as not data parallel.
loop.data_parallel = false;
} else {
// The fold factor didn't simplify to something that doesn't depend on the loop variable.
}
@@ -412,6 +461,7 @@
for (var output : outputs) {
for (loop_info& loop : loops) {
if (!fold_factors[output]) continue;
loop.add_synchronization();

expr loop_var = variable::make(loop.sym);
for (int d = 0; d < static_cast<int>(fold_factors[output]->size()); ++d) {
@@ -424,6 +474,24 @@
if (!is_finite(fold_factor)) {
continue;
}

if (!depends_on(fold_factor, loop.sym).any()) {
// We need an extra fold per worker when parallelizing the loop.
// TODO: This extra folding seems excessive; it allows all workers to execute any stage.
// If we can figure out how to add some synchronization to limit the number of workers that
// work on a single stage at a time, we should be able to reduce this extra folding.
// TODO: In this case, we currently need synchronization, but we should find a way to eliminate it.
// This synchronization will cause the loop to run only as fast as the slowest stage, which is
// unnecessary in the case of a fully data parallel loop. In order to avoid this, we need to avoid race
// conditions. The synchronization avoids the race condition by only allowing a window of max_workers to
// run at once, so the storage folding here works as intended. If we could instead find a way to give
// each worker its own slice of this buffer, we could avoid this synchronization. I think this might be
// doable by making the worker index available to the loop body, and using that to grab a slice of this
// buffer, so each worker can get its own fold.

fold_factor += (loop.worker_count - 1) * (*fold_factors[output])[d].overlap;
vector_at(fold_factors[output], d).factor = simplify(fold_factor);
}
}
}
}
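To make the adjustment above concrete, a small worked example with made-up numbers (not values taken from the tests): a fold of factor lines whose adjacent iterations overlap by overlap lines grows by one overlap per extra worker that can be in flight.

```cpp
// Illustrative arithmetic only; factor, overlap, and workers are made up.
constexpr int factor = 4;   // fold factor of the dimension in the serial schedule
constexpr int overlap = 2;  // lines shared between adjacent loop iterations
constexpr int workers = 3;  // loop.worker_count after substitution
constexpr int parallel_factor = factor + (workers - 1) * overlap;  // 4 + 2 * 2 = 8
static_assert(parallel_factor == 8, "one extra overlap per extra worker");
```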
@@ -584,9 +652,69 @@
}

const loop_info& l = loops.back();
const int max_workers = l.data_parallel ? op->max_workers : 1;
const int stage_count = l.sync_stages;
const int max_workers = l.data_parallel ? op->max_workers : std::max(1, stage_count);
stmt result = loop::make(op->sym, max_workers, loop_bounds, op->step, std::move(body));

// Substitute the placeholder worker_count.
result = substitute(result, l.worker_count, max_workers);
// We need to do this in the fold factors too.
for (std::optional<std::vector<dim_fold_info>>& i : fold_factors) {
if (!i) continue;
for (dim_fold_info& j : *i) {
if (!depends_on(j.factor, l.worker_count).any()) continue;

if (l.data_parallel && max_workers == loop::parallel) {
// This is a data parallel loop, remove the folding.
// TODO: We have other options that would be better:
// - Move the allocation into the loop.
// - Rewrite accesses to this dimension to be a function of a thread ID (and rewrite the fold factor to the
// max thread ID).
j.factor = expr();
} else {
// This is a serial or pipelined loop, we can still fold.
j.factor = substitute(j.factor, l.worker_count, max_workers);
}
}
}

if (!l.data_parallel && stage_count > 1) {
// We added synchronization in the loop, so we need to allocate a buffer for the semaphores.
interval_expr sem_bounds = {0, stage_count - 1};

index_t sem_size = sizeof(index_t);
call_stmt::attributes init_sems_attrs;
init_sems_attrs.name = "init_semaphores";
stmt init_sems = call_stmt::make(
[stage_count](const call_stmt* s, eval_context& ctx) -> index_t {
const buffer<index_t>& sems = *ctx.lookup_buffer<index_t>(s->outputs[0]);
assert(sems.rank == 2);
assert(sems.dim(0).min() == 0);
assert(sems.dim(0).extent() == stage_count);
memset(sems.base(), 0, sems.size_bytes());
// Initialize the first semaphore for each stage (the one before the loop min) to 1,
// unblocking the first iteration.
assert(sems.dim(0).stride() == sizeof(index_t));
std::fill_n(&sems(0, sems.dim(1).min()), stage_count, 1);
return 0;
},
{}, {l.semaphores}, std::move(init_sems_attrs));
// We can fold the semaphores array by the number of threads we'll use.
// TODO: Use the loop index and not the loop variable directly for semaphores so we don't need to do this.
expr sem_fold_factor = stage_count;
std::vector<dim_expr> sem_dims = {
{sem_bounds, sem_size},
// TODO: We should just let dimensions like this have undefined bounds.
{{floor_div(loop_bounds.min, op->step) - 1, floor_div(loop_bounds.max, op->step)},
sem_size * sem_bounds.extent(), sem_fold_factor},
};
result = allocate::make(
l.semaphores, memory_type::stack, sem_size, std::move(sem_dims), block::make({init_sems, result}));
} else {
// We only have one stage, so there's no need for semaphores.
result = substitute(result, l.semaphores, expr());
}

if (!is_variable(loop_bounds.min, orig_min) || depends_on(result, orig_min).any()) {
// We rewrote or used the loop min.
result = let_stmt::make(orig_min, op->bounds.min, result);
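For orientation, a rough picture of the semaphore buffer allocated above, with made-up values: dim 0 indexes the pipeline stage and dim 1 the loop iteration, and dim 1 is folded by stage_count, so only a window of stage_count iterations' worth of flags needs to be live at once.

```cpp
// Illustrative sizes only; the real values depend on stage_count and sizeof(index_t).
constexpr long long stage_count = 3;
constexpr long long sem_size = sizeof(long long);                       // stands in for sizeof(index_t)
constexpr long long live_bytes = stage_count * stage_count * sem_size;  // 3 stages * 3 live iterations * sem_size; 72 with an 8-byte index_t
```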
8 changes: 5 additions & 3 deletions builder/test/pipeline.cc
@@ -464,7 +464,8 @@ TEST_P(stencil, pipeline) {
}

if (split > 0) {
const int intm_size = (W + 2) * (split + 2) * sizeof(short);
const int parallel_extra = max_workers != loop::serial ? split : 0;
const int intm_size = (W + 2) * (split + parallel_extra + 2) * sizeof(short);
ASSERT_THAT(eval_ctx.heap.allocs, testing::UnorderedElementsAre(intm_size));
} else {
ASSERT_EQ(eval_ctx.heap.allocs.size(), 1);
@@ -623,8 +624,9 @@ TEST_P(stencil_chain, pipeline) {
}

if (split > 0) {
const int intm_size = (W + 2) * (split + 2) * sizeof(short);
const int intm2_size = (W + 4) * (split + 2) * sizeof(short);
const int parallel_extra = max_workers != loop::serial ? split * 2 : 0;
const int intm_size = (W + 2) * (split + parallel_extra + 2) * sizeof(short);
const int intm2_size = (W + 4) * (split + parallel_extra + 2) * sizeof(short);
ASSERT_THAT(eval_ctx.heap.allocs, testing::UnorderedElementsAre(intm_size, intm2_size));
} else {
ASSERT_EQ(eval_ctx.heap.allocs.size(), 2);
3 changes: 2 additions & 1 deletion builder/test/pyramid.cc
@@ -76,7 +76,8 @@ TEST_P(pyramid, pipeline) {
test_context eval_ctx;
p.evaluate(inputs, outputs, eval_ctx);

ASSERT_THAT(eval_ctx.heap.allocs, testing::UnorderedElementsAre((W + 2) / 2 * 2 * sizeof(int)));
const int parallel_extra = max_workers != loop::serial ? 1 : 0;
ASSERT_THAT(eval_ctx.heap.allocs, testing::UnorderedElementsAre((W + 2) / 2 * (2 + parallel_extra) * sizeof(int)));

if (max_workers == loop::serial) {
check_replica_pipeline(define_replica_pipeline(ctx, {in}, {out}));
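A worked version of the updated expectation above, assuming W = 10 and a 4-byte int purely for illustration (the fixture's actual W may differ): the pipelined schedule needs one extra fold of the intermediate compared to the serial schedule.

```cpp
// Illustrative numbers only.
constexpr int W = 10;
constexpr int serial_bytes = (W + 2) / 2 * 2 * sizeof(int);           // 6 * 2 * 4 = 48
constexpr int pipelined_bytes = (W + 2) / 2 * (2 + 1) * sizeof(int);  // 6 * 3 * 4 = 72
```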
52 changes: 52 additions & 0 deletions runtime/evaluate.cc
@@ -236,6 +236,54 @@ class evaluator {
return result;
}

index_t eval_semaphore_init(const call* op) {
assert(op->args.size() == 2);
index_t* sem = reinterpret_cast<index_t*>(eval(op->args[0]));
index_t count = eval(op->args[1], 0);
context.thread_pool->atomic_call([=]() { *sem = count; });
return 1;
}

SLINKY_NO_STACK_PROTECTOR index_t eval_semaphore_signal(const call* op) {
assert(op->args.size() % 2 == 0);
std::size_t sem_count = op->args.size() / 2;
index_t** sems = SLINKY_ALLOCA(index_t*, sem_count);
index_t* counts = SLINKY_ALLOCA(index_t, sem_count);
for (std::size_t i = 0; i < sem_count; ++i) {
sems[i] = reinterpret_cast<index_t*>(eval(op->args[i * 2 + 0]));
counts[i] = eval(op->args[i * 2 + 1], 1);
}
context.thread_pool->atomic_call([=]() {
for (std::size_t i = 0; i < sem_count; ++i) {
*sems[i] += counts[i];
}
});
return 1;
}

SLINKY_NO_STACK_PROTECTOR index_t eval_semaphore_wait(const call* op) {
assert(op->args.size() % 2 == 0);
std::size_t sem_count = op->args.size() / 2;
index_t** sems = SLINKY_ALLOCA(index_t*, sem_count);
index_t* counts = SLINKY_ALLOCA(index_t, sem_count);
for (std::size_t i = 0; i < sem_count; ++i) {
sems[i] = reinterpret_cast<index_t*>(eval(op->args[i * 2 + 0]));
counts[i] = eval(op->args[i * 2 + 1], 1);
}
context.thread_pool->wait_for([=]() {
// Check we can acquire all of the semaphores before acquiring any of them.
for (std::size_t i = 0; i < sem_count; ++i) {
if (*sems[i] < counts[i]) return false;
}
// Acquire them all.
for (std::size_t i = 0; i < sem_count; ++i) {
*sems[i] -= counts[i];
}
return true;
});
return 1;
}

index_t eval_trace_begin(const call* op) {
assert(op->args.size() == 1);
const char* name = reinterpret_cast<const char*>(eval(op->args[0]));
@@ -275,6 +323,10 @@ class evaluator {
case intrinsic::buffer_size_bytes: return eval_buffer_metadata(op);
case intrinsic::buffer_at: return reinterpret_cast<index_t>(eval_buffer_at(op));

case intrinsic::semaphore_init: return eval_semaphore_init(op);
case intrinsic::semaphore_signal: return eval_semaphore_signal(op);
case intrinsic::semaphore_wait: return eval_semaphore_wait(op);

case intrinsic::trace_begin: return eval_trace_begin(op);
case intrinsic::trace_end: return eval_trace_end(op);

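The evaluators above signal and wait on several semaphores atomically: a wait only decrements any of its semaphores once all of them can be decremented, using the thread pool's atomic_call and wait_for. A standalone equivalent of those semantics, sketched with a mutex and condition variable rather than slinky's thread pool:

```cpp
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <vector>

// Illustrative sketch, not slinky's implementation.
struct multi_semaphore {
  std::mutex m;
  std::condition_variable cv;

  // Add counts[i] to each semaphore, then wake any waiters (cf. eval_semaphore_signal).
  void signal(const std::vector<std::int64_t*>& sems, const std::vector<std::int64_t>& counts) {
    {
      std::lock_guard<std::mutex> l(m);
      for (std::size_t i = 0; i < sems.size(); ++i) *sems[i] += counts[i];
    }
    cv.notify_all();
  }

  // Block until every semaphore can be acquired, then acquire them all (cf. eval_semaphore_wait).
  void wait(const std::vector<std::int64_t*>& sems, const std::vector<std::int64_t>& counts) {
    std::unique_lock<std::mutex> l(m);
    cv.wait(l, [&]() {
      // Check we can acquire all of the semaphores before acquiring any of them.
      for (std::size_t i = 0; i < sems.size(); ++i) {
        if (*sems[i] < counts[i]) return false;
      }
      // Acquire them all.
      for (std::size_t i = 0; i < sems.size(); ++i) *sems[i] -= counts[i];
      return true;
    });
  }
};
```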
14 changes: 14 additions & 0 deletions runtime/expr.h
@@ -107,6 +107,14 @@ enum class intrinsic {
// This function returns the address of the element x in (buf, x_0, x_1, ...). x can be any rank, including 0.
buffer_at,

// These functions implement counting semaphores.
// The first argument of all of these semaphore helpers is a pointer to an index_t that will be used as the semaphore,
// and the second argument is a count.
semaphore_init,
// wait and signal can take multiple semaphores as a sequence of (sem, count) pairs of arguments.
semaphore_signal,
semaphore_wait,

// Calls the tracing callback with the first argument. Returns a token that should be passed to end_trace.
trace_begin,
trace_end,
@@ -682,6 +690,12 @@ expr buffer_at(expr buf, expr at0, Args... at) {

box_expr dims_bounds(span<const dim_expr> dims);

expr semaphore_init(expr sem, expr count = expr());
expr semaphore_signal(expr sem, expr count = expr());
expr semaphore_signal(span<const expr> sems, span<const expr> counts = {});
expr semaphore_wait(expr sem, expr count = expr());
expr semaphore_wait(span<const expr> sems, span<const expr> counts = {});

template <typename T>
class symbol_map {
std::vector<std::optional<T>> values;
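As a sketch of the argument convention documented above, here is one plausible way the span overloads could flatten their inputs into (sem, count) pairs, with a default-constructed expression standing in for the implicit count of 1. This illustrates the convention only; it is not the actual slinky implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative only: Expr stands in for slinky's expr type.
template <typename Expr>
std::vector<Expr> make_semaphore_args(const std::vector<Expr>& sems, const std::vector<Expr>& counts) {
  assert(counts.empty() || counts.size() == sems.size());
  std::vector<Expr> args;
  args.reserve(sems.size() * 2);
  for (std::size_t i = 0; i < sems.size(); ++i) {
    args.push_back(sems[i]);
    // A missing count is left undefined here; the evaluator treats it as 1.
    args.push_back(i < counts.size() ? counts[i] : Expr());
  }
  return args;
}
```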