Skip to content

Commit

Permalink
Improve performance of expression addition
Browse files Browse the repository at this point in the history
  • Loading branch information
benruijl committed Oct 10, 2024
1 parent a6e3322 commit 280c2a6
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,15 @@ impl<'a> AtomView<'a> {
format!("{}", self.printer(PrintOptions::file()))
}

/// Get the number of terms.
pub fn nterms(&self) -> usize {
if let AtomView::Add(a) = self {
a.get_nargs()
} else {
1
}
}

/// Print statistics about the operation `op`, such as its duration and term growth.
pub fn with_stats<F: Fn(AtomView) -> Atom>(&self, op: F, o: &StatsOptions) -> Atom {
let t = std::time::Instant::now();
Expand Down Expand Up @@ -379,17 +388,6 @@ impl<'a> AtomView<'a> {
}
}

/// Add two atoms and return the buffer that contains the unnormalized result.
fn add_no_norm(&self, workspace: &Workspace, rhs: AtomView<'_>) -> RecycledAtom {
let mut e = workspace.new_atom();
let a = e.to_add();

// TODO: check if self or rhs is add
a.extend(*self);
a.extend(rhs);
e
}

/// Subtract two atoms and return the buffer that contains the unnormalized result.
fn sub_no_norm(&self, workspace: &Workspace, rhs: AtomView<'_>) -> RecycledAtom {
let mut e = workspace.new_atom();
Expand Down Expand Up @@ -435,9 +433,7 @@ impl<'a> AtomView<'a> {

/// Add `self` and `rhs`, writing the result in `out`.
pub fn add_with_ws_into(&self, workspace: &Workspace, rhs: AtomView<'_>, out: &mut Atom) {
self.add_no_norm(workspace, rhs)
.as_view()
.normalize(workspace, out);
self.add_normalized(rhs, workspace, out);
}

/// Subtract `rhs` from `self, writing the result in `out`.
Expand Down Expand Up @@ -631,6 +627,10 @@ impl Atom {
self.as_view().is_one()
}

pub fn nterms(&self) -> usize {
self.as_view().nterms()
}

/// Print the atom using the portable [`PrintOptions::file()`] options.
pub fn to_string(&self) -> String {
format!("{}", self.printer(PrintOptions::file()))
Expand Down
7 changes: 7 additions & 0 deletions src/atom/representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,13 @@ impl Add {
pub(crate) unsafe fn from_raw(raw: RawAtom) -> Add {
Add { data: raw }
}

pub(crate) fn grow_capacity(&mut self, size: usize) {
if size > self.data.capacity() {
let additional = size - self.data.capacity();
self.data.reserve(additional);
}
}
}

impl<'a> VarView<'a> {
Expand Down
126 changes: 126 additions & 0 deletions src/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,117 @@ impl<'a> AtomView<'a> {
}
}
}

/// Add two atoms and normalize the result.
pub(crate) fn add_normalized(&self, rhs: AtomView, ws: &Workspace, out: &mut Atom) {
let a = out.to_add();
a.grow_capacity(self.get_byte_size() + rhs.get_byte_size());

let mut helper = ws.new_atom();
let mut b = ws.new_atom();
if let AtomView::Add(a1) = self {
if let AtomView::Add(a2) = rhs {
let mut s = a1.iter();
let mut t = a2.iter();

let mut curs = s.next();
let mut curst = t.next();
while curs.is_some() || curst.is_some() {
if let Some(ss) = curs {
if let Some(tt) = curst {
match ss.cmp_terms(&tt) {
Ordering::Less => {
a.extend(ss);
curs = s.next();
}
Ordering::Greater => {
a.extend(tt);
curst = t.next();
}
Ordering::Equal => {
b.set_from_view(&ss);
if b.merge_terms(tt, &mut helper) {
if let AtomView::Num(n) = a.as_view() {
if !n.is_zero() {
a.extend(b.as_view());
}
} else {
a.extend(b.as_view());
}
} else {
unreachable!("Equal terms do not merge");
}

curst = t.next();
curs = s.next();
}
}
} else {
a.extend(ss);
curs = s.next();
}
} else if let Some(tt) = curst {
a.extend(tt);
curst = t.next();
}
}

a.set_normalized(true);
return;
}
}

if let AtomView::Add(a1) = self {
let mut found = false;
for x in a1.iter() {
// TODO: find the position of rhs in self with a binary search
if found {
a.extend(x);
continue;
}

match x.cmp_terms(&rhs) {
Ordering::Less => {
a.extend(x);
}
Ordering::Equal => {
found = true;
b.set_from_view(&x);
if b.merge_terms(rhs, &mut helper) {
if let AtomView::Num(n) = a.as_view() {
if !n.is_zero() {
a.extend(b.as_view());
}
} else {
a.extend(b.as_view());
}
} else {
unreachable!("Equal terms do not merge");
}
}
Ordering::Greater => {
found = true;
a.extend(rhs);
a.extend(x);
}
}
}

if !found {
a.extend(rhs);
}

a.set_normalized(true);
} else if let AtomView::Add(_) = rhs {
rhs.add_normalized(*self, ws, out);
} else {
let mut e = ws.new_atom();
let a = e.to_add();
a.extend(*self);
a.extend(rhs);
e.as_view().normalize(ws, out);
}
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1421,4 +1532,19 @@ mod test {
panic!("Expected a Mul");
}
}

#[test]
fn add_normalized() {
let a = Atom::parse("v1 + v2 + v3").unwrap();
let b = Atom::parse("1 + v2 + v4 + v5").unwrap();
assert_eq!(a + b, Atom::parse("v1+2*v2+v3+v4+v5+1").unwrap());

let a = Atom::parse("v1 + v2 + v3").unwrap();
let b = Atom::parse("v4").unwrap();
assert_eq!(a + b, Atom::parse("v1+v2+v3+v4").unwrap());

let a = Atom::parse("v2 + v3 + v4").unwrap();
let b = Atom::parse("v1").unwrap();
assert_eq!(a + b, Atom::parse("v1+v2+v3+v4").unwrap());
}
}

0 comments on commit 280c2a6

Please sign in to comment.