Skip to content

Commit

Permalink
Apply a Horner scheme to the evaluation tree
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Jun 29, 2024
1 parent 1bb9d30 commit 3a12ea6
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/nested_evaluation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
let g = Atom::parse("i(y+7)+x*i(y+7)*(y-1)").unwrap();
let h = Atom::parse("y*(1+x*(1+x^2)) + y^2*(1+x*(1+x^2))^2 + 3*(1+x^2)").unwrap();
let i = Atom::parse("y - 1").unwrap();
let k = Atom::parse("x+8").unwrap();
let k = Atom::parse("3*x^3 + 4*x^2 + 6*x +8").unwrap();

let mut const_map = HashMap::default();

Expand Down Expand Up @@ -63,7 +63,10 @@ fn main() {

let params = vec![Atom::parse("x").unwrap()];

let tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, &params);
let mut tree = e.as_view().to_eval_tree(|r| r.clone(), &const_map, &params);

tree.horner_scheme(); // optimize the tree using an occurrence-order Horner scheme

let t2 = tree.map_coeff::<f64, _>(&|r| r.into());
println!("{}", t2.export_cpp()); // print C++ code

Expand Down
149 changes: 149 additions & 0 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,155 @@ impl<T: Clone + Default + PartialEq> EvalTree<T> {
}
}

impl EvalTree<Rational> {
fn apply_horner_scheme(&mut self, scheme: &[EvalTree<Rational>]) {
if scheme.is_empty() {
return;
}

let EvalTree::Add(a) = self else {
return;
};

// TODO: find power to extract, now we do just one

let mut contains = vec![];
let mut rest = vec![];

for x in a {
let mut found = false;
if let EvalTree::Mul(m) = x {
for (p, y) in m.iter_mut().enumerate() {
if let EvalTree::Pow(p) = y {
if p.0 == scheme[0] {
found = true;
if p.1 == 2 {
*y = p.0.clone(); // TODO: prevent clone
} else {
p.1 -= 1;
}
}
} else if y == &scheme[0] {
found = true;
// remove from prod
m.remove(p);
if m.len() == 1 {
*x = m[0].clone();
}
break;
}
}
} else if x == &scheme[0] {
found = true;
*x = EvalTree::Const(Rational::one());
}

if found {
contains.push(x.clone());
} else {
rest.push(x.clone());
}
}

if contains.is_empty() {
*self = EvalTree::Add(rest);
self.apply_horner_scheme(&scheme[1..]);
} else {
let mut c = EvalTree::Mul(vec![EvalTree::Add(contains), scheme[0].clone()]);
c.apply_horner_scheme(&scheme[1..]);

if rest.is_empty() {
*self = c;
} else {
let mut r = EvalTree::Add(rest);
r.apply_horner_scheme(&scheme[1..]);

*self = EvalTree::Add(vec![c, r]);
}
}
}

/// Apply a simple occurrence-order Horner scheme to every addition.
pub fn horner_scheme(&mut self) {
match self {
EvalTree::Const(_) | EvalTree::Parameter(_) | EvalTree::ReadArg(_, _) => {}
EvalTree::Eval(_, _, _, ae, f) => {
for arg in ae {
arg.horner_scheme();
}
f.horner_scheme();
}
EvalTree::Add(a) => {
for arg in &mut *a {
arg.horner_scheme();
}

let mut occurrence = HashMap::default();

for arg in &*a {
match arg {
EvalTree::Mul(m) => {
for aa in m {
if let EvalTree::Pow(p) = aa {
occurrence
.entry(p.0.clone())
.and_modify(|x| *x += 1)
.or_insert(1);
} else {
occurrence
.entry(aa.clone())
.and_modify(|x| *x += 1)
.or_insert(1);
}
}
}
x => {
if let EvalTree::Pow(p) = x {
occurrence
.entry(p.0.clone())
.and_modify(|x| *x += 1)
.or_insert(1);
} else {
occurrence
.entry(x.clone())
.and_modify(|x| *x += 1)
.or_insert(1);
}
}
}
}

occurrence.retain(|_, v| *v > 1);
let mut order: Vec<_> = occurrence.into_iter().collect();
order.sort_by_key(|k| k.1);
let scheme = order.into_iter().map(|(k, _)| k).collect::<Vec<_>>();

self.apply_horner_scheme(&scheme);
}
EvalTree::Mul(a) => {
for arg in a {
arg.horner_scheme();
}
}
EvalTree::Pow(p) => {
p.0.horner_scheme();
}
EvalTree::Powf(p) => {
p.0.horner_scheme();
p.1.horner_scheme();
}
EvalTree::BuiltinFun(_, a) => {
a.horner_scheme();
}
EvalTree::SubExpression(_, r) => {
let mut rr = r.as_ref().clone();
rr.horner_scheme();
*r = Rc::new(rr);
}
}
}
}

impl<T: Clone + Default + Eq + std::hash::Hash> EvalTree<T> {
fn extract_subexpressions(&mut self) {
let mut h = HashMap::default();
Expand Down

0 comments on commit 3a12ea6

Please sign in to comment.