Skip to content

Commit

Permalink
[Feat] Uniformity analysis for SPIR-V (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Feb 2, 2025
1 parent e0734da commit 63a1c59
Show file tree
Hide file tree
Showing 14 changed files with 481 additions and 151 deletions.
2 changes: 2 additions & 0 deletions crates/cubecl-opt/src/analyses/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{
dominance::{Dominators, PostDominators},
liveness::Liveness,
post_order::PostOrder,
uniformity::Uniformity,
};

/// An analysis used by optimization passes. Unlike optimization passes, analyses can have state
Expand Down Expand Up @@ -62,5 +63,6 @@ impl Optimizer {
self.invalidate_analysis::<Dominators>();
self.invalidate_analysis::<PostDominators>();
self.invalidate_analysis::<Liveness>();
self.invalidate_analysis::<Uniformity>();
}
}
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/analyses/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod dominance;
pub mod integer_range;
pub mod liveness;
pub mod post_order;
pub mod uniformity;
pub mod writes;

pub use base::*;
239 changes: 239 additions & 0 deletions crates/cubecl-opt/src/analyses/uniformity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use cubecl_ir::{
Builtin, Operation, OperationReflect, Plane, Synchronization, Variable, VariableKind,
};
use petgraph::{graph::EdgeIndex, visit::EdgeRef};
use std::collections::{HashMap, HashSet};

use crate::{ControlFlow, NodeIndex, Optimizer};

use super::Analysis;

/// Analysis state recording which blocks and variables are plane uniform.
///
/// The maps are filled by walking the control flow graph from the entry block; query the
/// result through `is_var_uniform` and `is_block_uniform`.
#[derive(Default, Clone)]
pub struct Uniformity {
/// Uniformity recorded for each block reached so far.
block_uniformity: HashMap<NodeIndex, bool>,
/// Uniformity recorded for each variable written by a phi node or an instruction.
variable_uniformity: HashMap<Variable, bool>,
/// Edges already followed during the current traversal; cleared whenever a variable that was
/// recorded as uniform is found to be non-uniform, so that all edges get revisited.
visited: HashSet<EdgeIndex>,
}

impl Analysis for Uniformity {
    /// Builds the analysis by starting from an empty state and running it over `opt`.
    fn init(opt: &mut Optimizer) -> Self {
        let mut analysis = Self::default();
        analysis.run(opt);
        analysis
    }
}

impl Uniformity {
/// Runs the analysis over the whole program.
///
/// The entry block is recorded as uniform, then the traversal is repeated from the entry
/// block until a full pass completes without `analyze_block` reporting an invalidation.
fn run(&mut self, opt: &Optimizer) {
    let entry = opt.entry();
    self.block_uniformity.insert(entry, true);
    loop {
        // `None` means a variable was downgraded to non-uniform mid-pass; start over.
        if self.analyze_block(opt, entry).is_some() {
            break;
        }
    }
}

/// Propagates uniformity through `block_id` and, recursively, through every block reachable
/// over edges that have not been visited yet.
///
/// Phi nodes and instruction outputs are recorded via `mark_uniformity`, then the block's
/// control flow decides the uniformity recorded for its successor blocks.
///
/// Returns `None` as soon as a variable that had been recorded as uniform turns out not to
/// be; `run` reacts to that by restarting the traversal from the entry block. Returns
/// `Some(())` once this block and everything reachable from it was processed without such an
/// invalidation.
///
/// # Panics
///
/// Panics if no uniformity has been recorded for `block_id` yet, or (presumably) if a block
/// without explicit control flow has no successor.
fn analyze_block(&mut self, opt: &Optimizer, block_id: NodeIndex) -> Option<()> {
    let block = opt.block(block_id);
    let mut block_uniform = self.block_uniformity[&block_id];

    // A phi is uniform only if every incoming value and the block it comes from are uniform,
    // and the phi itself sits in a uniform block.
    for phi in block.phi_nodes.borrow().iter() {
        let entries_uniform = phi.entries.iter().all(|entry| {
            self.is_block_uniform(entry.block) && self.is_var_uniform(entry.value)
        });
        self.mark_uniformity(phi.out, entries_uniform && block_uniform)?;
    }

    for inst in block.ops.borrow().values() {
        match (&inst.operation, inst.out) {
            // Barriers are matched before the output check. Filtering on `out` first would
            // skip every barrier that defines no output variable, leaving this arm dead.
            // NOTE(review): assumes barriers carry no output - confirm against the IR.
            (Operation::Synchronization(sync), _) => match sync {
                Synchronization::SyncUnits | Synchronization::SyncStorage => {
                    block_uniform = true;
                }
            },
            // Nothing to record for other instructions that don't define a variable.
            (_, None) => {}
            (Operation::Plane(plane), Some(out)) => match plane {
                // Elect returns true on only one unit, so it's always non-uniform
                Plane::Elect => self.mark_uniformity(out, false)?,
                // Reductions are always uniform if executed in uniform control flow
                Plane::Sum(_)
                | Plane::Prod(_)
                | Plane::Min(_)
                | Plane::Max(_)
                | Plane::All(_)
                | Plane::Any(_)
                | Plane::Ballot(_) => self.mark_uniformity(out, block_uniform)?,
                // Broadcast maps to shuffle or broadcast: if either the id or the value is
                // uniform, so is the output; otherwise it is not.
                Plane::Broadcast(op) => {
                    let input_uniform =
                        self.is_var_uniform(op.lhs) || self.is_var_uniform(op.rhs);
                    self.mark_uniformity(out, input_uniform && block_uniform)?;
                }
            },
            // Any other operation is uniform only if it is pure, all of its arguments are
            // uniform and it executes in uniform control flow.
            (op, Some(out)) => {
                let is_uniform =
                    op.is_pure() && self.is_all_uniform(op.args()) && block_uniform;
                self.mark_uniformity(out, is_uniform)?;
            }
        }
    }

    match &*block.control_flow.borrow() {
        ControlFlow::IfElse {
            cond,
            then,
            or_else,
            merge,
        } => {
            // Branch targets inherit the condition's uniformity; the merge block only
            // depends on the uniformity of this block.
            let branch_uniform = self.is_var_uniform(*cond) && block_uniform;
            self.block_uniformity.insert(*then, branch_uniform);
            self.block_uniformity.insert(*or_else, branch_uniform);
            if let Some(merge) = merge {
                self.block_uniformity.insert(*merge, block_uniform);
            }
        }
        ControlFlow::Switch {
            value,
            default,
            branches,
            merge,
        } => {
            let branch_uniform = self.is_var_uniform(*value) && block_uniform;
            self.block_uniformity.insert(*default, branch_uniform);
            for branch in branches {
                self.block_uniformity.insert(branch.1, branch_uniform);
            }
            if let Some(merge) = merge {
                self.block_uniformity.insert(*merge, block_uniform);
            }
        }
        ControlFlow::Loop {
            body,
            continue_target,
            merge,
        } => {
            // If we don't know the break condition, we can't detect whether it's uniform
            self.block_uniformity.insert(block_id, false);
            self.block_uniformity.insert(*body, false);
            self.block_uniformity.insert(*continue_target, false);
            self.block_uniformity.insert(*merge, false);
        }
        ControlFlow::LoopBreak {
            break_cond,
            body,
            continue_target,
            merge,
        } => {
            // The header, body, continue target and merge block all share the uniformity of
            // the break condition.
            let loop_uniform = self.is_var_uniform(*break_cond) && block_uniform;
            self.block_uniformity.insert(block_id, loop_uniform);
            self.block_uniformity.insert(*body, loop_uniform);
            self.block_uniformity.insert(*continue_target, loop_uniform);
            self.block_uniformity.insert(*merge, loop_uniform);
        }
        ControlFlow::Return => {}
        ControlFlow::None => {
            // Plain fallthrough: an already recorded successor is OR-ed with this block's
            // uniformity, an unrecorded one simply takes it over.
            let successor = opt.successors(block_id)[0];
            self.block_uniformity
                .entry(successor)
                .and_modify(|it| {
                    *it |= block_uniform;
                })
                .or_insert(block_uniform);
        }
    }

    for edge in opt.program.edges(block_id) {
        // `insert` returns `true` only for edges that were not visited before.
        if self.visited.insert(edge.id()) {
            self.analyze_block(opt, edge.target())?;
        }
    }

    Some(())
}

/// Records `value` as the uniformity of `var`.
///
/// A variable seen for the first time simply stores `value`. A variable seen before keeps
/// the logical AND of its stored state and `value`, so it can only ever move from uniform
/// to non-uniform.
///
/// Returns `None` when that downgrade happens: the visited edge set is cleared so that all
/// edges are revisited. This only happens for loopback edges, where an uninitialized
/// variable was assumed to be uniform but actually isn't. Returns `Some(())` otherwise.
fn mark_uniformity(&mut self, var: Variable, value: bool) -> Option<()> {
    match self.variable_uniformity.get_mut(&var) {
        None => {
            self.variable_uniformity.insert(var, value);
        }
        Some(known) => {
            let was_uniform = *known;
            *known = was_uniform && value;
            if was_uniform && !value {
                self.visited.clear();
                return None;
            }
        }
    }
    Some(())
}

/// Whether every variable in `args` is uniform.
///
/// A missing argument list (`None`) is treated as non-uniform; an empty list is uniform.
fn is_all_uniform(&self, args: Option<Vec<Variable>>) -> bool {
    match args {
        Some(vars) => vars.into_iter().all(|var| self.is_var_uniform(var)),
        None => false,
    }
}

/// Whether a variable is plane uniform
///
/// Mutable locals are always reported as non-uniform, while arrays, shared memory, global
/// bindings and constants are always reported as uniform. Builtins are classified by kind.
/// Every other variable is looked up in the recorded analysis state and counts as uniform
/// while nothing has been recorded for it.
pub fn is_var_uniform(&self, var: Variable) -> bool {
    match var.kind {
        VariableKind::LocalMut { .. } => false,

        VariableKind::Builtin(builtin) => match builtin {
            // Cube-wide values and the plane size
            Builtin::CubePos
            | Builtin::CubePosX
            | Builtin::CubePosY
            | Builtin::CubePosZ
            | Builtin::CubeDim
            | Builtin::CubeDimX
            | Builtin::CubeDimY
            | Builtin::CubeDimZ
            | Builtin::CubeCount
            | Builtin::CubeCountX
            | Builtin::CubeCountY
            | Builtin::CubeCountZ
            | Builtin::PlaneDim => true,
            // Per-unit positions
            Builtin::UnitPosPlane
            | Builtin::AbsolutePos
            | Builtin::AbsolutePosX
            | Builtin::AbsolutePosY
            | Builtin::AbsolutePosZ
            | Builtin::UnitPos
            | Builtin::UnitPosX
            | Builtin::UnitPosY
            | Builtin::UnitPosZ => false,
        },

        VariableKind::ConstantArray { .. }
        | VariableKind::SharedMemory { .. }
        | VariableKind::GlobalInputArray(_)
        | VariableKind::GlobalOutputArray(_)
        | VariableKind::GlobalScalar(_)
        | VariableKind::ConstantScalar(_) => true,

        VariableKind::LocalArray { .. }
        | VariableKind::LocalConst { .. }
        | VariableKind::Versioned { .. }
        | VariableKind::Matrix { .. }
        | VariableKind::Slice { .. }
        | VariableKind::Pipeline { .. } => match self.variable_uniformity.get(&var) {
            Some(&recorded) => recorded,
            // Not analyzed (yet): optimistically assume uniform.
            None => true,
        },
    }
}

/// Whether a block is recorded as uniform.
///
/// Blocks with no recorded state are reported as uniform.
pub fn is_block_uniform(&self, block: NodeIndex) -> bool {
    match self.block_uniformity.get(&block) {
        Some(&uniform) => uniform,
        None => true,
    }
}
}
4 changes: 1 addition & 3 deletions crates/cubecl-opt/src/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,7 @@ impl Optimizer {
true => Comparison::LowerEqual,
false => Comparison::Lower,
};
let tmp = *self
.allocator
.create_local_restricted(Item::new(Elem::Bool));
let tmp = *self.allocator.create_local(Item::new(Elem::Bool));
self.program[header].ops.borrow_mut().push(Instruction::new(
op(BinaryOperator {
lhs: i,
Expand Down
32 changes: 24 additions & 8 deletions crates/cubecl-opt/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use cubecl_ir::{FloatKind, IntKind, UIntKind};
use petgraph::visit::EdgeRef;

use crate::{
analyses::{const_len::Slices, integer_range::Ranges, liveness::Liveness},
analyses::{const_len::Slices, liveness::Liveness, uniformity::Uniformity},
gvn::{BlockSets, Constant, Expression, GvnState, Instruction, Local, Value, ValueTable},
ControlFlow,
};
Expand All @@ -17,7 +17,6 @@ const DEBUG_GVN: bool = false;
impl Display for Optimizer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let slices = self.analysis_cache.try_get::<Slices>().unwrap_or_default();
let ranges = self.analysis_cache.try_get::<Ranges>().unwrap_or_default();

f.write_str("Slices:\n")?;
for (var_id, slice) in slices.iter() {
Expand All @@ -41,6 +40,10 @@ impl Display for Optimizer {
.analysis_cache
.try_get::<Liveness>()
.unwrap_or_else(|| Rc::new(Liveness::empty(self)));
let uniformity = self
.analysis_cache
.try_get::<Uniformity>()
.unwrap_or_default();

if DEBUG_GVN {
writeln!(f, "# Value Table:")?;
Expand All @@ -50,7 +53,11 @@ impl Display for Optimizer {
for node in self.program.node_indices() {
let id = node.index();
let bb = &self.program[node];
writeln!(f, "bb{id} {{")?;
let uniform = match uniformity.is_block_uniform(node) {
true => "uniform ",
false => "",
};
writeln!(f, "{uniform}bb{id} {{")?;
if DEBUG_GVN {
let block_sets = &global_nums
.block_sets
Expand All @@ -74,7 +81,11 @@ impl Display for Optimizer {
write!(f, "[bb{}: ", entry.block.index())?;
write!(f, "{}]", entry.value)?;
}
f.write_str(";\n")?;
let is_uniform = match uniformity.is_var_uniform(phi.out) {
true => " @ uniform",
false => "",
};
writeln!(f, ";{is_uniform}\n")?;
}
if !bb.phi_nodes.borrow().is_empty() {
writeln!(f)?;
Expand All @@ -86,10 +97,15 @@ impl Display for Optimizer {
continue;
}

let range = op.out.map(|var| ranges.range_of(self, &var));
let range = range.map(|it| format!(" range: {it};")).unwrap_or_default();

writeln!(f, " {op_fmt};{range}")?;
let is_uniform = match op
.out
.map(|out| uniformity.is_var_uniform(out))
.unwrap_or(false)
{
true => " @ uniform",
false => "",
};
writeln!(f, " {op_fmt};{is_uniform}")?;
}
match &*bb.control_flow.borrow() {
ControlFlow::IfElse {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod phi_frontiers;
mod transformers;
mod version;

pub use analyses::uniformity::Uniformity;
pub use block::*;
pub use control_flow::*;
pub use petgraph::graph::{EdgeIndex, NodeIndex};
Expand Down
Loading

0 comments on commit 63a1c59

Please sign in to comment.