Skip to content

Commit

Permalink
Perfom CPE on linearized evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Aug 9, 2024
1 parent 654df3f commit f53c411
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 5 deletions.
8 changes: 4 additions & 4 deletions examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ fn main() {
println!("Op original {:?}", tree.count_operations());
tree.horner_scheme();
println!("Op horner {:?}", tree.count_operations());
tree.common_subexpression_elimination(20); // test 2^20 options at the most
tree.common_subexpression_elimination(0);
println!("op cse {:?}", tree.count_operations());

tree.common_pair_elimination();
println!("op cpe {:?}", tree.count_operations());
//tree.common_pair_elimination();
//println!("op cpe {:?}", tree.count_operations());

let ce = tree
.export_cpp("nested_evaluation.cpp", "evaltest", true)
Expand Down Expand Up @@ -112,7 +112,7 @@ fn main() {
println!("C++ time {:#?}", t.elapsed());

let t2 = tree.map_coeff::<f64, _>(&|r| r.into());
let mut evaluator: ExpressionEvaluator<f64> = t2.linearize();
let mut evaluator: ExpressionEvaluator<f64> = t2.linearize(10);

evaluator.evaluate_multiple(&params, &mut out);
println!("Eval: {}, {}", out[0], out[1]);
Expand Down
239 changes: 238 additions & 1 deletion src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,237 @@ impl<T: Real> ExpressionEvaluator<T> {
}
}

impl<T: Default> ExpressionEvaluator<T> {
pub fn remove_common_pairs(&mut self) -> usize {
let mut pairs: HashMap<_, Vec<usize>> = HashMap::default();

let mut affected_lines = vec![true; self.instructions.len()];

for (p, i) in self.instructions.iter().enumerate() {
match i {
Instr::Add(_, a) | Instr::Mul(_, a) => {
let is_add = matches!(i, Instr::Add(_, _));
for (li, l) in a.iter().enumerate() {
for r in &a[li + 1..] {
pairs.entry((is_add, *l, *r)).or_default().push(p);
}
}
}
_ => {}
}
}

// for now, ignore pairs with only occurrences on the same line
let mut to_remove: Vec<_> = pairs.clone().into_iter().collect();

to_remove.retain_mut(|(_, v)| {
v.dedup();
v.len() > 1
});

// sort in other direction since we pop
to_remove.sort_by_key(|x| x.1.len());

let total_remove = to_remove.len();

for x in &mut affected_lines {
*x = false;
}

let old_len = self.instructions.len();

while let Some(((is_add, l, r), lines)) = to_remove.pop() {
if lines.iter().any(|x| affected_lines[*x]) {
continue;
}

let new_idx = self.stack.len();
let new_op = if is_add {
Instr::Add(new_idx, vec![l, r])
} else {
Instr::Mul(new_idx, vec![l, r])
};

self.stack.push(T::default());
self.instructions.push(new_op);

for line in lines {
affected_lines[line] = true;
let is_add = matches!(self.instructions[line], Instr::Add(_, _));

if let Instr::Add(_, a) | Instr::Mul(_, a) = &mut self.instructions[line] {
for (li, l) in a.iter().enumerate() {
for r in &a[li + 1..] {
let pp = pairs.entry((is_add, *l, *r)).or_default();
pp.retain(|x| *x != line);
}
}

if l == r {
let count = a.iter().filter(|x| **x == l).count();
let pairs = count / 2;
if pairs > 0 {
a.retain(|x| *x != l);

if count % 2 == 1 {
a.push(l.clone());
}

a.extend(std::iter::repeat(new_idx).take(pairs));
a.sort();
}
} else {
let mut idx1_count = 0;
let mut idx2_count = 0;
for v in &*a {
if *v == l {
idx1_count += 1;
}
if *v == r {
idx2_count += 1;
}
}

let pair_count = idx1_count.min(idx2_count);

if pair_count > 0 {
a.retain(|x| *x != l && *x != r);

// add back removed indices in cases such as idx1*idx2*idx2
if idx1_count > pair_count {
a.extend(
std::iter::repeat(l.clone()).take(idx1_count - pair_count),
);
}
if idx2_count > pair_count {
a.extend(
std::iter::repeat(r.clone()).take(idx2_count - pair_count),
);
}

a.extend(std::iter::repeat(new_idx).take(pair_count));
a.sort();
}
}

// update the pairs for this line
for (li, l) in a.iter().enumerate() {
for r in &a[li + 1..] {
pairs.entry((is_add, *l, *r)).or_default().push(line);
}
}
}
}
}

let mut first_use = vec![];
for i in self.instructions.drain(old_len..) {
if let Instr::Add(_, a) | Instr::Mul(_, a) = &i {
let mut last_dep = a[0];
for v in a {
last_dep = last_dep.max(*v);
}

let ins = if last_dep + 1 <= self.reserved_indices {
0
} else {
last_dep + 1 - self.reserved_indices
};

first_use.push((ins, i));
} else {
unreachable!()
}
}

first_use.sort_by_key(|x| x.0);

let mut new_instr = vec![];
let mut i = 0;
let mut j = 0;

let mut sub_rename = HashMap::default();
let mut rename_map: Vec<_> = (0..self.reserved_indices).collect();

macro_rules! rename {
($i:expr) => {
if $i >= self.reserved_indices + self.instructions.len() {
sub_rename[&$i]
} else {
rename_map[$i]
}
};
}

while i < self.instructions.len() {
let new_pos = new_instr.len() + self.reserved_indices;

if j < first_use.len() && i == first_use[j].0 {
let (o, a) = match &first_use[j].1 {
Instr::Add(o, a) => (*o, a),
Instr::Mul(o, a) => (*o, a),
_ => unreachable!(),
};

let is_add = matches!(&first_use[j].1, Instr::Add(_, _));

let new_a = a.iter().map(|x| rename!(*x)).collect::<Vec<_>>();

if is_add {
new_instr.push(Instr::Add(new_pos, new_a));
} else {
new_instr.push(Instr::Mul(new_pos, new_a));
}

sub_rename.insert(o, new_pos);

j += 1;
} else {
let mut s = self.instructions[i].clone();

match &mut s {
Instr::Add(p, a) | Instr::Mul(p, a) => {
for x in &mut *a {
*x = rename!(*x);
}

// remove assignments
if a.len() == 1 {
rename_map.push(a[0]);
i += 1;
continue;
}

*p = new_pos;
}
Instr::Pow(p, b, _) | Instr::BuiltinFun(p, _, b) => {
*b = rename!(*b);
*p = new_pos;
}
Instr::Powf(p, a, b) => {
*a = rename!(*a);
*b = rename!(*b);
*p = new_pos;
}
}

new_instr.push(s);
rename_map.push(new_pos);
i += 1;
}
}

for x in &mut self.result_indices {
*x = rename!(*x);
}

assert!(j == first_use.len());

self.instructions = new_instr;
total_remove
}
}

impl<T> ExpressionEvaluator<T> {
pub fn optimize_stack(&mut self) {
let mut last_use: Vec<usize> = vec![0; self.stack.len()];
Expand Down Expand Up @@ -1439,7 +1670,7 @@ impl<T: Clone + PartialEq> EvalTree<T> {

impl<T: Clone + Default + PartialEq> EvalTree<T> {
/// Create a linear version of the tree that can be evaluated more efficiently.
pub fn linearize(mut self) -> ExpressionEvaluator<T> {
pub fn linearize(mut self, cpe_rounds: usize) -> ExpressionEvaluator<T> {
let mut stack = vec![T::default(); self.param_count];

// strip every constant and move them into the stack after the params
Expand Down Expand Up @@ -1471,6 +1702,12 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
result_indices,
};

for _ in 0..cpe_rounds {
if e.remove_common_pairs() == 0 {
break;
}
}

e.optimize_stack();
e
}
Expand Down

0 comments on commit f53c411

Please sign in to comment.