Skip to content

Commit

Permalink
Fix tagged function system
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Oct 30, 2024
1 parent 04c98f8 commit 158f799
Showing 1 changed file with 14 additions and 37 deletions.
51 changes: 14 additions & 37 deletions src/evaluate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use self_cell::self_cell;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::{
atom::{representation::InlineVar, Atom, AtomOrView, AtomView, Symbol},
atom::{Atom, AtomOrView, AtomView, Symbol},
coefficient::CoefficientView,
combinatorics::unique_permutations,
domains::{
Expand Down Expand Up @@ -87,7 +87,7 @@ impl<'a, T> FunctionMap<'a, T> {
}

self.map.insert(
AtomOrTaggedFunction::Atom(Atom::new_var(name).into()),
AtomOrTaggedFunction::TaggedFunction(name, vec![]),
ConstOrExpr::Expr(rename, 0, args, body),
);

Expand Down Expand Up @@ -121,6 +121,13 @@ impl<'a, T> FunctionMap<'a, T> {
self.tag.get(symbol).cloned().unwrap_or(0)
}

fn get_constant(&self, a: AtomView<'a>) -> Option<&T> {
match self.map.get(&AtomOrTaggedFunction::Atom(a.into())) {
Some(ConstOrExpr::Const(c)) => Some(c),
_ => None,
}
}

fn get(&self, a: AtomView<'a>) -> Option<&ConstOrExpr<'a, T>> {
if let Some(c) = self.map.get(&AtomOrTaggedFunction::Atom(a.into())) {
return Some(c);
Expand All @@ -130,7 +137,7 @@ impl<'a, T> FunctionMap<'a, T> {
let s = aa.get_symbol();
let tag_len = self.get_tag_len(&s);

if tag_len != 0 && aa.get_nargs() >= tag_len {
if aa.get_nargs() >= tag_len {
let tag = aa.iter().take(tag_len).map(|x| x.into()).collect();
return self.map.get(&AtomOrTaggedFunction::TaggedFunction(s, tag));
}
Expand Down Expand Up @@ -3890,34 +3897,8 @@ impl<'a> AtomView<'a> {
return Ok(Expression::Parameter(p));
}

if let Some(c) = fn_map.get(*self) {
return match c {
ConstOrExpr::Const(c) => Ok(Expression::Const(c.clone())),
ConstOrExpr::Expr(name, _tag_len, args, v) => {
if args.len() != 0 {
return Err(format!(
"Function {} called with wrong number of arguments: 0 vs {}",
self,
args.len()
));
}

if let Some(pos) = funcs.iter().position(|f| f.0 == *name) {
Ok(Expression::Eval(pos, vec![]))
} else {
let r = v.to_eval_tree_impl(fn_map, params, args, funcs)?;
funcs.push((
name.clone(),
args.clone(),
SplitExpression {
tree: vec![r.clone()],
subexpressions: vec![],
},
));
Ok(Expression::Eval(funcs.len() - 1, vec![]))
}
}
};
if let Some(c) = fn_map.get_constant(*self) {
return Ok(Expression::Const(c.clone()));
}

match self {
Expand Down Expand Up @@ -3960,12 +3941,8 @@ impl<'a> AtomView<'a> {
));
}

let symb = InlineVar::new(f.get_symbol());
let Some(fun) = fn_map.get(symb.as_view()) else {
return Err(format!(
"Undefined function {}",
State::get_name(f.get_symbol())
));
let Some(fun) = fn_map.get(*self) else {
return Err(format!("Undefined function {}", self));
};

match fun {
Expand Down

0 comments on commit 158f799

Please sign in to comment.