From f53c41176741418b3cda0f5c350f2cac68b02d83 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Fri, 9 Aug 2024 11:52:45 +0200 Subject: [PATCH] Perfom CPE on linearized evaluator --- examples/nested_evaluation.rs | 8 +- src/evaluate.rs | 239 +++++++++++++++++++++++++++++++++- 2 files changed, 242 insertions(+), 5 deletions(-) diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index ab9a07c..297d9ad 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -78,11 +78,11 @@ fn main() { println!("Op original {:?}", tree.count_operations()); tree.horner_scheme(); println!("Op horner {:?}", tree.count_operations()); - tree.common_subexpression_elimination(20); // test 2^20 options at the most + tree.common_subexpression_elimination(0); println!("op cse {:?}", tree.count_operations()); - tree.common_pair_elimination(); - println!("op cpe {:?}", tree.count_operations()); + //tree.common_pair_elimination(); + //println!("op cpe {:?}", tree.count_operations()); let ce = tree .export_cpp("nested_evaluation.cpp", "evaltest", true) @@ -112,7 +112,7 @@ fn main() { println!("C++ time {:#?}", t.elapsed()); let t2 = tree.map_coeff::(&|r| r.into()); - let mut evaluator: ExpressionEvaluator = t2.linearize(); + let mut evaluator: ExpressionEvaluator = t2.linearize(10); evaluator.evaluate_multiple(¶ms, &mut out); println!("Eval: {}, {}", out[0], out[1]); diff --git a/src/evaluate.rs b/src/evaluate.rs index d94b8a9..2e70b95 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -734,6 +734,237 @@ impl ExpressionEvaluator { } } +impl ExpressionEvaluator { + pub fn remove_common_pairs(&mut self) -> usize { + let mut pairs: HashMap<_, Vec> = HashMap::default(); + + let mut affected_lines = vec![true; self.instructions.len()]; + + for (p, i) in self.instructions.iter().enumerate() { + match i { + Instr::Add(_, a) | Instr::Mul(_, a) => { + let is_add = matches!(i, Instr::Add(_, _)); + for (li, l) in a.iter().enumerate() { + for r in &a[li + 1..] { + pairs.entry((is_add, *l, *r)).or_default().push(p); + } + } + } + _ => {} + } + } + + // for now, ignore pairs with only occurrences on the same line + let mut to_remove: Vec<_> = pairs.clone().into_iter().collect(); + + to_remove.retain_mut(|(_, v)| { + v.dedup(); + v.len() > 1 + }); + + // sort in other direction since we pop + to_remove.sort_by_key(|x| x.1.len()); + + let total_remove = to_remove.len(); + + for x in &mut affected_lines { + *x = false; + } + + let old_len = self.instructions.len(); + + while let Some(((is_add, l, r), lines)) = to_remove.pop() { + if lines.iter().any(|x| affected_lines[*x]) { + continue; + } + + let new_idx = self.stack.len(); + let new_op = if is_add { + Instr::Add(new_idx, vec![l, r]) + } else { + Instr::Mul(new_idx, vec![l, r]) + }; + + self.stack.push(T::default()); + self.instructions.push(new_op); + + for line in lines { + affected_lines[line] = true; + let is_add = matches!(self.instructions[line], Instr::Add(_, _)); + + if let Instr::Add(_, a) | Instr::Mul(_, a) = &mut self.instructions[line] { + for (li, l) in a.iter().enumerate() { + for r in &a[li + 1..] { + let pp = pairs.entry((is_add, *l, *r)).or_default(); + pp.retain(|x| *x != line); + } + } + + if l == r { + let count = a.iter().filter(|x| **x == l).count(); + let pairs = count / 2; + if pairs > 0 { + a.retain(|x| *x != l); + + if count % 2 == 1 { + a.push(l.clone()); + } + + a.extend(std::iter::repeat(new_idx).take(pairs)); + a.sort(); + } + } else { + let mut idx1_count = 0; + let mut idx2_count = 0; + for v in &*a { + if *v == l { + idx1_count += 1; + } + if *v == r { + idx2_count += 1; + } + } + + let pair_count = idx1_count.min(idx2_count); + + if pair_count > 0 { + a.retain(|x| *x != l && *x != r); + + // add back removed indices in cases such as idx1*idx2*idx2 + if idx1_count > pair_count { + a.extend( + std::iter::repeat(l.clone()).take(idx1_count - pair_count), + ); + } + if idx2_count > pair_count { + a.extend( + std::iter::repeat(r.clone()).take(idx2_count - pair_count), + ); + } + + a.extend(std::iter::repeat(new_idx).take(pair_count)); + a.sort(); + } + } + + // update the pairs for this line + for (li, l) in a.iter().enumerate() { + for r in &a[li + 1..] { + pairs.entry((is_add, *l, *r)).or_default().push(line); + } + } + } + } + } + + let mut first_use = vec![]; + for i in self.instructions.drain(old_len..) { + if let Instr::Add(_, a) | Instr::Mul(_, a) = &i { + let mut last_dep = a[0]; + for v in a { + last_dep = last_dep.max(*v); + } + + let ins = if last_dep + 1 <= self.reserved_indices { + 0 + } else { + last_dep + 1 - self.reserved_indices + }; + + first_use.push((ins, i)); + } else { + unreachable!() + } + } + + first_use.sort_by_key(|x| x.0); + + let mut new_instr = vec![]; + let mut i = 0; + let mut j = 0; + + let mut sub_rename = HashMap::default(); + let mut rename_map: Vec<_> = (0..self.reserved_indices).collect(); + + macro_rules! rename { + ($i:expr) => { + if $i >= self.reserved_indices + self.instructions.len() { + sub_rename[&$i] + } else { + rename_map[$i] + } + }; + } + + while i < self.instructions.len() { + let new_pos = new_instr.len() + self.reserved_indices; + + if j < first_use.len() && i == first_use[j].0 { + let (o, a) = match &first_use[j].1 { + Instr::Add(o, a) => (*o, a), + Instr::Mul(o, a) => (*o, a), + _ => unreachable!(), + }; + + let is_add = matches!(&first_use[j].1, Instr::Add(_, _)); + + let new_a = a.iter().map(|x| rename!(*x)).collect::>(); + + if is_add { + new_instr.push(Instr::Add(new_pos, new_a)); + } else { + new_instr.push(Instr::Mul(new_pos, new_a)); + } + + sub_rename.insert(o, new_pos); + + j += 1; + } else { + let mut s = self.instructions[i].clone(); + + match &mut s { + Instr::Add(p, a) | Instr::Mul(p, a) => { + for x in &mut *a { + *x = rename!(*x); + } + + // remove assignments + if a.len() == 1 { + rename_map.push(a[0]); + i += 1; + continue; + } + + *p = new_pos; + } + Instr::Pow(p, b, _) | Instr::BuiltinFun(p, _, b) => { + *b = rename!(*b); + *p = new_pos; + } + Instr::Powf(p, a, b) => { + *a = rename!(*a); + *b = rename!(*b); + *p = new_pos; + } + } + + new_instr.push(s); + rename_map.push(new_pos); + i += 1; + } + } + + for x in &mut self.result_indices { + *x = rename!(*x); + } + + assert!(j == first_use.len()); + + self.instructions = new_instr; + total_remove + } +} + impl ExpressionEvaluator { pub fn optimize_stack(&mut self) { let mut last_use: Vec = vec![0; self.stack.len()]; @@ -1439,7 +1670,7 @@ impl EvalTree { impl EvalTree { /// Create a linear version of the tree that can be evaluated more efficiently. - pub fn linearize(mut self) -> ExpressionEvaluator { + pub fn linearize(mut self, cpe_rounds: usize) -> ExpressionEvaluator { let mut stack = vec![T::default(); self.param_count]; // strip every constant and move them into the stack after the params @@ -1471,6 +1702,12 @@ impl EvalTree { result_indices, }; + for _ in 0..cpe_rounds { + if e.remove_common_pairs() == 0 { + break; + } + } + e.optimize_stack(); e }