Skip to content

Commit

Permalink
Add method to merge evaluators
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Aug 31, 2024
1 parent d246a31 commit cbc87b4
Showing 1 changed file with 169 additions and 0 deletions.
169 changes: 169 additions & 0 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,175 @@ impl<T: Default> ExpressionEvaluator<T> {
}
}

impl<T: Default + Clone + Eq + Hash> ExpressionEvaluator<T> {
/// Merge evaluator `other` into `self`. The parameters must be the same.
pub fn merge(&mut self, mut other: Self, cpe_rounds: Option<usize>) -> Result<(), String> {
if self.param_count != other.param_count {
return Err("Parameter count is different".to_owned());
}

let mut constants = HashMap::default();

for (i, c) in self.stack[self.param_count..self.reserved_indices]
.iter()
.enumerate()
{
constants.insert(c.clone(), i);
}

let old_len = self.stack.len() - self.reserved_indices;

self.stack.truncate(self.reserved_indices);

for c in &other.stack[self.param_count..other.reserved_indices] {
if constants.get(c).is_none() {
let i = constants.len();
constants.insert(c.clone(), i);
self.stack.push(c.clone());
}
}

let new_reserved_indices = self.stack.len();
let mut delta = new_reserved_indices - self.reserved_indices;

// shift stack indices
if delta > 0 {
for i in &mut self.instructions {
match i {
Instr::Add(r, a) | Instr::Mul(r, a) => {
*r += delta;
for aa in a {
if *aa >= self.reserved_indices {
*aa += delta;
}
}
}
Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => {
*r += delta;
if *b >= self.reserved_indices {
*b += delta;
}
}
Instr::Powf(r, b, e) => {
*r += delta;
if *b >= self.reserved_indices {
*b += delta;
}
if *e >= self.reserved_indices {
*e += delta;
}
}
}
}

for x in &mut self.result_indices {
*x += delta;
}
}

delta = old_len + new_reserved_indices - other.reserved_indices;
for i in &mut other.instructions {
match i {
Instr::Add(r, a) | Instr::Mul(r, a) => {
*r += delta;
for aa in a {
if *aa >= other.reserved_indices {
*aa += delta;
} else if *aa >= other.param_count {
*aa = self.param_count + constants[&other.stack[*aa]];
}
}
}
Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => {
*r += delta;
if *b >= other.reserved_indices {
*b += delta;
} else if *b >= other.param_count {
*b = self.param_count + constants[&other.stack[*b]];
}
}
Instr::Powf(r, b, e) => {
*r += delta;
if *b >= other.reserved_indices {
*b += delta;
} else if *b >= other.param_count {
*b = self.param_count + constants[&other.stack[*b]];
}
if *e >= other.reserved_indices {
*e += delta;
} else if *e >= other.param_count {
*e = self.param_count + constants[&other.stack[*e]];
}
}
}
}

for x in &mut other.result_indices {
if *x >= other.reserved_indices {
*x += delta;
} else if *x >= other.param_count {
*x = self.param_count + constants[&other.stack[*x]];
}
}

self.instructions.extend(other.instructions.drain(..));
self.result_indices.extend(other.result_indices.drain(..));
self.reserved_indices = new_reserved_indices;

// undo the stack optimization
let mut unfold = HashMap::default();
for (index, i) in &mut self.instructions.iter_mut().enumerate() {
match i {
Instr::Add(r, a) | Instr::Mul(r, a) => {
for aa in a {
if *aa >= self.reserved_indices {
*aa = unfold[aa];
}
}

unfold.insert(*r, index + self.reserved_indices);
*r = index + self.reserved_indices;
}
Instr::Pow(r, b, _) | Instr::BuiltinFun(r, _, b) => {
if *b >= self.reserved_indices {
*b = unfold[b];
}
unfold.insert(*r, index + self.reserved_indices);
*r = index + self.reserved_indices;
}
Instr::Powf(r, b, e) => {
if *b >= self.reserved_indices {
*b = unfold[b];
}
if *e >= self.reserved_indices {
*e = unfold[e];
}
unfold.insert(*r, index + self.reserved_indices);
*r = index + self.reserved_indices;
}
}
}

for i in &mut self.result_indices {
*i = unfold[i];
}

for _ in 0..self.instructions.len() {
self.stack.push(T::default());
}

for _ in 0..cpe_rounds.unwrap_or(usize::MAX) {
if self.remove_common_pairs() == 0 {
break;
}
}

self.optimize_stack();

Ok(())
}
}

impl<T> ExpressionEvaluator<T> {
pub fn optimize_stack(&mut self) {
let mut last_use: Vec<usize> = vec![0; self.stack.len()];
Expand Down

0 comments on commit cbc87b4

Please sign in to comment.