From 25f42472089e8f02fa3e55578ab2d16e7ad06d1f Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Wed, 8 May 2024 17:20:17 -0500 Subject: [PATCH] Switched symbolic to use egg --- Cargo.toml | 3 +- crates/luminal_metal/src/prim.rs | 10 +- src/shape/symbolic.rs | 362 ++++++++++++++++++++++++------- 3 files changed, 285 insertions(+), 90 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 51c239c4..12c1ad9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,8 +22,7 @@ regex = "1.9.5" rustc-hash = "1.1.0" uuid = { version = "1.7.0", features = ["v4"] } as-any = "0.3.1" -#cas-compute = {git="https://github.com/ElectrifyPro/cas-rs", branch="dev"} -#cas-parser = {git="https://github.com/ElectrifyPro/cas-rs", branch="dev"} +egg = "0.9.5" [dev-dependencies] dfdx = { version = "0.13", features = ["f16"] } diff --git a/crates/luminal_metal/src/prim.rs b/crates/luminal_metal/src/prim.rs index 7e9f8349..64d6dbd8 100644 --- a/crates/luminal_metal/src/prim.rs +++ b/crates/luminal_metal/src/prim.rs @@ -1653,10 +1653,12 @@ impl Compiler for PrimitiveCompiler { .neighbors_directed(output_node, petgraph::Direction::Incoming) .next() .unwrap(); - graph.no_delete.remove(&output_node); - graph.no_delete.insert(src); - let w = graph.to_retrieve.remove(&output_node).unwrap(); - graph.to_retrieve.insert(src, w); + if graph.no_delete.remove(&output_node) { + graph.no_delete.insert(src); + } + if let Some(w) = graph.to_retrieve.remove(&output_node) { + graph.to_retrieve.insert(src, w); + } } else { // Create copy node let copy_node = graph diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index 2ab1fcf4..139e98d7 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -1,3 +1,4 @@ +use egg::*; use std::{ fmt::Debug, ops::{ @@ -144,7 +145,7 @@ where /// A symbolic expression #[derive(Clone, Copy, Hash, Eq)] pub struct GenericExpression { - pub terms: S, + pub terms: S, // Terms in postfix notation } impl PartialEq for GenericExpression @@ -202,23 +203,7 @@ impl std::fmt::Display for GenericExpression { impl GenericExpression { /// Simplify the expression to its minimal terms pub fn simplify(self) -> Self { - self - // // cas-rs doesn't support some ops - // if self.terms.clone().into_iter().any(|i| { - // matches!( - // i, - // Term::Mod | Term::Max | Term::Min | Term::And | Term::Or | Term::Gte | Term::Lt - // ) - // }) { - // return self; - // } - // let str = format!("{self:?}"); - // let mut parser = Parser::new(&str); - // let ast_expr = parser.try_parse_full::().unwrap(); - // let simplified = simplify(&ast_expr.into()); - // let mut storage = S::default(); - // storage.extend(cas_expr_to_luminal_expr(&simplified)); - // Self { terms: storage } + egg_simplify(self) } pub fn as_num(&self) -> Option { @@ -676,67 +661,279 @@ impl> BitOrAssign for GenericExpression Vec { -// match expr { -// Expr::Primary(Primary::Symbol(symb)) => vec![Term::Var(symb.chars().next().unwrap())], -// Expr::Primary(Primary::Integer(num)) => vec![Term::Num(num.to_i32().unwrap())], -// Expr::Add(terms) => { -// let mut result = Vec::new(); -// for term in terms { -// result.extend(cas_expr_to_luminal_expr(term)); -// } -// for _ in 1..terms.len() { -// result.push(Term::Add); -// } -// result -// } -// Expr::Mul(terms) => { -// if let Expr::Exp(expr, pow) = &terms[1] { -// if let Some(Some(-1)) = pow.as_integer().map(|i| i.to_i32()) { -// assert!(terms.len() == 2); -// let mut v = cas_expr_to_luminal_expr(expr); -// v.extend(cas_expr_to_luminal_expr(&terms[0])); -// v.push(Term::Div); -// return v; -// } -// } -// if let Expr::Exp(expr, pow) = &terms[0] { -// if let Some(Some(-1)) = pow.as_integer().map(|i| i.to_i32()) { -// assert!(terms.len() == 2); -// let mut v = cas_expr_to_luminal_expr(expr); -// v.extend(cas_expr_to_luminal_expr(&terms[1])); -// v.push(Term::Div); -// return v; -// } -// } -// let mut result = Vec::new(); -// for term in terms { -// result.extend(cas_expr_to_luminal_expr(term)); -// } -// for _ in 1..terms.len() { -// result.push(Term::Mul); -// } -// result -// } -// Expr::Exp(base, exponent) => { -// let pow = exponent -// .as_integer() -// .unwrap() -// .to_usize() -// .expect("Pow is not positive!"); -// let mut result = Vec::new(); -// for _ in 0..pow { -// result.extend(cas_expr_to_luminal_expr(base)); -// } -// for _ in 0..pow - 1 { -// result.push(Term::Mul); -// } -// result -// } -// _ => panic!("Invalid term encountered: {expr:?}"), -// } -// } +define_language! { + enum SimpleLanguage { + Num(i32), + "+" = Add([Id; 2]), + "*" = Mul([Id; 2]), + Symbol(Symbol), + } +} + +fn luminal_to_egg(expr: &GenericExpression) -> RecExpr { + let mut stack = Vec::new(); + + for term in expr.terms.iter_ref() { + match term { + Term::Num(_) | Term::Var(_) => stack.push(format!("{term:?}")), + _ => { + let left = stack.pop().unwrap(); + let right = stack.pop().unwrap(); + let subexpr = format!("({term:?} {left} {right})"); + stack.push(subexpr); + } + } + } + + stack.pop().unwrap().parse().unwrap() +} + +fn egg_to_luminal(expr: RecExpr) -> GenericExpression { + fn create_postfix(expr: &[Math]) -> Vec { + match expr.last().unwrap() { + Math::Num(i) => vec![Term::Num(*i)], + Math::Add([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Add], + ] + .concat(), + Math::Sub([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Sub], + ] + .concat(), + Math::Mul([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Mul], + ] + .concat(), + Math::Div([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Div], + ] + .concat(), + Math::Mod([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Mod], + ] + .concat(), + Math::Min([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Min], + ] + .concat(), + Math::Max([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Max], + ] + .concat(), + Math::And([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::And], + ] + .concat(), + Math::Or([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Or], + ] + .concat(), + Math::LessThan([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Lt], + ] + .concat(), + Math::GreaterThanEqual([a, b]) => [ + create_postfix(&expr[..usize::from(*b) + 1]), + create_postfix(&expr[..usize::from(*a) + 1]), + vec![Term::Gte], + ] + .concat(), + Math::Symbol(s) => vec![Term::Var(s.as_str().chars().next().unwrap())], + } + } + let mut terms = S::default(); + terms.extend(create_postfix(expr.as_ref())); + GenericExpression { terms } +} + +type EGraph = egg::EGraph; +type Rewrite = egg::Rewrite; + +define_language! { + enum Math { + Num(i32), + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "*" = Mul([Id; 2]), + "/" = Div([Id; 2]), + "%" = Mod([Id; 2]), + "min" = Min([Id; 2]), + "max" = Max([Id; 2]), + "&&" = And([Id; 2]), + "||" = Or([Id; 2]), + "<" = LessThan([Id; 2]), + ">=" = GreaterThanEqual([Id; 2]), + Symbol(Symbol), + } +} + +#[derive(Default)] +pub struct ConstantFold; +impl Analysis for ConstantFold { + type Data = Option<(i32, PatternAst)>; + + fn make(egraph: &egg::EGraph, enode: &Math) -> Self::Data { + let x = |i: &Id| egraph[*i].data.as_ref().map(|d| d.0); + Some(match enode { + Math::Num(c) => (*c, format!("{}", c).parse().unwrap()), + Math::Add([a, b]) => ( + x(a)?.checked_add(x(b)?)?, + format!("(+ {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Sub([a, b]) => ( + x(a)?.checked_sub(x(b)?)?, + format!("(- {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Mul([a, b]) => ( + x(a)?.checked_mul(x(b)?)?, + format!("(* {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Div([a, b]) if x(b) != Some(0) => { + let (a, b) = (x(a)?, x(b)?); + if a % b != 0 { + return None; + } else { + (a.checked_div(b)?, format!("(/ {a} {b})").parse().unwrap()) + } + } + Math::Mod([a, b]) if x(b) != Some(0) => ( + x(a)?.checked_rem(x(b)?)?, + format!("(% {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Min([a, b]) if x(b) != Some(0) => ( + x(a)?.min(x(b)?), + format!("(min {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Max([a, b]) if x(b) != Some(0) => ( + x(a)?.max(x(b)?), + format!("(max {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::And([a, b]) if x(b) != Some(0) => ( + if x(a)? != 0 && x(b)? != 0 { 1 } else { 0 }, + format!("(&& {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::Or([a, b]) if x(b) != Some(0) => ( + if x(a)? != 0 || x(b)? != 0 { 1 } else { 0 }, + format!("(|| {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::LessThan([a, b]) if x(b) != Some(0) => ( + if x(a)? < x(b)? { 1 } else { 0 }, + format!("(< {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + Math::GreaterThanEqual([a, b]) if x(b) != Some(0) => ( + if x(a)? >= x(b)? { 1 } else { 0 }, + format!("(>= {} {})", x(a)?, x(b)?).parse().unwrap(), + ), + _ => return None, + }) + } + + fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { + merge_option(to, from, |a, b| { + assert_eq!(a.0, b.0, "Merged non-equal constants"); + DidMerge(false, false) + }) + } + + fn modify(egraph: &mut EGraph, id: Id) { + let data = egraph[id].data.clone(); + if let Some((c, pat)) = data { + if egraph.are_explanations_enabled() { + egraph.union_instantiations( + &pat, + &format!("{}", c).parse().unwrap(), + &Default::default(), + "constant_fold".to_string(), + ); + } else { + let added = egraph.add(Math::Num(c)); + egraph.union(id, added); + } + // to not prune, comment this out + egraph[id].nodes.retain(|n| n.is_leaf()); + } + } +} + +fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + let var = var.parse().unwrap(); + move |egraph, _, subst| { + if let Some(n) = &egraph[subst[var]].data { + n.0 != 0 + } else { + true + } + } +} + +fn make_rules() -> Vec { + vec![ + // Communative properties + rewrite!("commute-add"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rewrite!("commute-mul"; "(* ?a ?b)" => "(* ?b ?a)"), + rewrite!("commute-min"; "(min ?a ?b)" => "(min ?b ?a)"), + rewrite!("commute-max"; "(max ?a ?b)" => "(max ?b ?a)"), + rewrite!("commute-and"; "(&& ?a ?b)" => "(&& ?b ?a)"), + rewrite!("commute-or"; "(|| ?a ?b)" => "(|| ?b ?a)"), + // Associative properties + rewrite!("assoc-add"; "(+ ?a (+ ?b ?c))" => "(+ (+ ?a ?b) ?c)"), + rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"), + rewrite!("mul-div-associative"; "(/ (* ?x ?y) ?z)" => "(* ?x (/ ?y ?z))"), + rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"), + // Simple binary reductions + rewrite!("add-0"; "(+ ?a 0)" => "?a"), + rewrite!("mul-0"; "(* ?a 0)" => "0"), + rewrite!("mul-1"; "(* ?a 1)" => "?a"), + rewrite!("div-1"; "(/ ?a 1)" => "?a"), + rewrite!("div-self"; "(/ ?a ?a)" => "1"), + rewrite!("and-0"; "(&& ?a 0)" => "0"), + rewrite!("and-1"; "(&& ?a 1)" => "?a"), + rewrite!("or-0"; "(|| ?a 0)" => "?a"), + rewrite!("or-1"; "(|| ?a 1)" => "1"), + rewrite!("min-i32-max"; "(min ?a 2147483647)" => "?a"), + rewrite!("max-i32-max"; "(max ?a 2147483647)" => "2147483647"), + rewrite!("recip-mul-div"; "(* ?x (/ 1 ?x))" => "1" if is_not_zero("?x")), + rewrite!("add-zero"; "?a" => "(+ ?a 0)"), + rewrite!("mul-one"; "?a" => "(* ?a 1)"), + rewrite!("cancel-sub"; "(- ?a ?a)" => "0"), + rewrite!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")), + // Other + rewrite!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), + rewrite!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), + ] +} + +fn egg_simplify(expr: GenericExpression) -> GenericExpression { + // Convert to egg expression + let expr = luminal_to_egg(&expr); + // Simplify + let runner = Runner::default().with_expr(&expr).run(&make_rules()); + let root = runner.roots[0]; + let extractor = Extractor::new(&runner.egraph, AstSize); + let (_, best) = extractor.find_best(root); + // Convert back to luminal expression + egg_to_luminal(best) +} #[cfg(test)] mod tests { @@ -764,7 +961,7 @@ mod tests { let main = Expression::from('x') - 255; let sub = Expression::from('x') / 2; let new = main.substitute('x', sub); - assert_eq!(new, Expression::from(-255) + (Expression::from('x') / 2)); + assert_eq!(new, (Expression::from('x') / 2) + -255); } #[test] @@ -772,9 +969,6 @@ mod tests { let s = BigExpression::from('s'); let expr = (s.clone() * ((s.clone() - 4) + 1)) + (((s.clone() + 1) * ((s.clone() - 4) + 1)) - (s.clone() * ((s.clone() - 4) + 1))); - assert_eq!( - expr.simplify(), - (((Expression::from('s') * Expression::from('s')) + (Expression::from('s') * -2)) + -3) - ); + assert_eq!(expr.simplify(), ((s.clone() + -3) * (s.clone() + 1))); } }