diff --git a/Cargo.toml b/Cargo.toml index 34f2172..1b0b3ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ log = "0.4.22" # same version as ruler z3 = {version = "0.10.0", features = ["static-link-z3"]} itertools = "0.13.0" +num = "0.3" serde = "1.0.214" serde_json = "1.0.132" diff --git a/src/lib.rs b/src/lib.rs index 9d3f77b..c23baff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,13 @@ use egglog::{EGraph, SerializeConfig}; use ruler::enumo::Pattern; -use ruler::{HashMap, HashSet}; +use ruler::{HashMap, HashSet, ValidationResult}; use utils::TERM_PLACEHOLDER; use std::fmt::Debug; use std::hash::Hash; use std::str::FromStr; -use ruler::enumo::{Filter, Metric, Sexp, Workload}; +use ruler::enumo::{Sexp, Workload}; use log::info; @@ -65,26 +65,16 @@ pub trait Chomper { result } - fn run_chompy( - &mut self, - egraph: &mut EGraph, - rules: Vec, - mask_to_preds: &HashMap, HashSet>, - memo: &mut HashSet, - ) { - let mut found: Vec = vec![false; rules.len()]; - + fn run_chompy(&mut self, egraph: &mut EGraph) { let mut max_eclass_id = 0; - let mut found_rules: HashSet = HashSet::default(); - for current_size in 0..MAX_SIZE { info!("adding programs of size {}:", current_size); - let mut filter = Filter::MetricEq(Metric::Atoms, current_size); - if current_size > 15 { - filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]); - } + // let mut filter = Filter::MetricEq(Metric::Atoms, current_size); + // if current_size > 4 { + // filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]); + // } info!("finding eclass term map..."); let eclass_term_map = self @@ -100,20 +90,21 @@ pub trait Chomper { ); let new_workload = if term_workload.force().is_empty() { - self.atoms().clone().filter(filter) + self.atoms().clone() } else { self.productions() .clone() .plug(TERM_PLACEHOLDER, &term_workload) - .filter(filter) }; info!("new workload len: {}", new_workload.force().len()); let atoms = self.atoms().force(); + let memo = &mut HashSet::default(); + for term in &new_workload.force() { - info!("term: {}", term); + // info!("term: {}", term); let term_string = self.make_string_not_bad(term.to_string().as_str()); if !atoms.contains(term) && !self.has_var(term) { continue; @@ -125,7 +116,7 @@ pub trait Chomper { r#" {term_string} (set (eclass {term_string}) {max_eclass_id}) - "# + "# ) .as_str(), ) @@ -144,109 +135,67 @@ pub trait Chomper { ) .unwrap(); info!("starting cvec match"); - let vals = self.cvec_match(egraph, mask_to_preds, memo); - if vals.non_conditional.is_empty() - || vals.non_conditional.iter().all(|x| { - found_rules.contains(format!("{:?}", self.generalize_rule(x)).as_str()) - }) - { + let vals = self.cvec_match(egraph, memo); + + if vals.non_conditional.is_empty() && vals.conditional.is_empty() { break; } - for (i, rule) in rules.iter().enumerate() { - let lhs = self.make_string_not_bad(rule.lhs.to_string().as_str()); - let rhs = self.make_string_not_bad(rule.rhs.to_string().as_str()); - if (rule.condition.is_some() - && egraph - .parse_and_run_program( - None, - format!( - r#" - (check (cond-equal {lhs} {rhs})) - "# - ) - .as_str(), - ) - .is_ok()) - || (rule.condition.is_none() - && egraph - .parse_and_run_program( - None, - format!( - r#" - (check (= {lhs} {rhs})) - "# - ) - .as_str(), - ) - .is_ok()) - { - found[i] = true; - } - if found.iter().all(|x| *x) { - return; + info!("found {} non-conditional rules", vals.non_conditional.len()); + info!("found {} conditional rules", vals.conditional.len()); + for val in &vals.conditional { + let generalized = self.generalize_rule(val); + if let ValidationResult::Valid = self.validate_rule(&generalized) { + if utils::does_rule_have_good_vars(&generalized) { + let lhs = + self.make_string_not_bad(generalized.lhs.to_string().as_str()); + let rhs = + self.make_string_not_bad(generalized.rhs.to_string().as_str()); + let cond = generalized.condition.as_ref().unwrap(); + let pred = self.make_string_not_bad(cond.to_string().as_str()); + info!("Conditional rule: if {} then {} ~> {}", pred, lhs, rhs); + self.add_conditional_rewrite( + egraph, + Sexp::from_str(&pred).unwrap(), + Sexp::from_str(&lhs).unwrap(), + Sexp::from_str(&rhs).unwrap(), + ); + } } } - for val in &vals.non_conditional { let generalized = self.generalize_rule(val); - if !found_rules.contains(format!("{:?}", generalized).as_str()) - && utils::does_rule_have_good_vars(&generalized) - { - let lhs = self.make_string_not_bad(generalized.lhs.to_string().as_str()); - let rhs = self.make_string_not_bad(generalized.rhs.to_string().as_str()); - if egraph - .parse_and_run_program( - None, - format!( - r#" - {lhs} - {rhs} - (check (= {lhs} {rhs})) - "# + if let ValidationResult::Valid = self.validate_rule(&generalized) { + if utils::does_rule_have_good_vars(&generalized) { + let lhs = + self.make_string_not_bad(generalized.lhs.to_string().as_str()); + let rhs = + self.make_string_not_bad(generalized.rhs.to_string().as_str()); + + if egraph + .parse_and_run_program( + None, + format!(r#"(check (= {} {}))"#, val.lhs, val.rhs).as_str(), ) - .as_str(), - ) - .is_err() - { - let validated = self.get_validated_rule(&generalized); - if found_rules.contains(format!("{:?}", validated).as_str()) { + .is_ok() + { continue; } - found_rules.insert(format!("{:?}", validated)); - if validated.is_none() { - continue; - } - let validated = validated.unwrap(); - if validated.condition.is_none() { - info!("Rule: {} -> {}", validated.lhs, validated.rhs); - self.add_rewrite(egraph, validated.lhs, validated.rhs); - } else { - info!( - "Conditional Rule: if {} then {} -> {}", - validated.condition.clone().unwrap(), - validated.lhs, - validated.rhs - ); - self.add_conditional_rewrite( - egraph, - validated.condition.unwrap(), - validated.lhs, - validated.rhs, - ); - } + + self.add_rewrite( + egraph, + Sexp::from_str(&lhs).unwrap(), + Sexp::from_str(&rhs).unwrap(), + ); + // TODO: derivability check here } + } else { + // info!( + // "perfect cvec match but failed validation: {} ~> {}", + // val.lhs, val.rhs + // ); } } - - for val in &vals.conditional { - self.add_conditional_rewrite( - egraph, - val.condition.clone().unwrap(), - val.lhs.clone(), - val.rhs.clone(), - ); - } } } @@ -278,9 +227,15 @@ pub trait Chomper { let mut id_to_gen_id: HashMap = HashMap::default(); let new_lhs = self.generalize_sexp(rule.lhs.clone(), &mut id_to_gen_id); let new_rhs = self.generalize_sexp(rule.rhs.clone(), &mut id_to_gen_id); + + let condition = rule + .condition + .as_ref() + .map(|cond| self.generalize_sexp(cond.clone(), &mut id_to_gen_id)); + Rule { // TODO: later - condition: None, + condition, lhs: new_lhs, rhs: new_rhs, } @@ -314,7 +269,6 @@ pub trait Chomper { fn cvec_match( &mut self, egraph: &mut EGraph, - mask_to_preds: &HashMap, HashSet>, // keeps track of what eclass IDs we've seen. memo: &mut HashSet, ) -> Rules { @@ -323,12 +277,14 @@ pub trait Chomper { conditional: vec![], }; - println!("hi from cvec match"); + let mask_to_preds = self.make_mask_to_preds(); + + info!("hi from cvec match"); let serialized = egraph.serialize(SerializeConfig::default()); - println!("eclasses in egraph: {}", serialized.classes().len()); - println!("nodes in egraph: {}", serialized.nodes.len()); + info!("eclasses in egraph: {}", serialized.classes().len()); + info!("nodes in egraph: {}", serialized.nodes.len()); let eclass_term_map: HashMap = self.reset_eclass_term_map(egraph); - // println!("eclass term map len: {}", eclass_term_map.len()); + info!("eclass term map len: {}", eclass_term_map.len()); let ec_keys: Vec<&i64> = eclass_term_map.keys().collect(); for i in 0..ec_keys.len() { let ec1 = ec_keys[i]; @@ -357,11 +313,6 @@ pub trait Chomper { lhs: term1.clone(), rhs: term2.clone(), }); - result.non_conditional.push(Rule { - condition: None, - lhs: term2.clone(), - rhs: term1.clone(), - }); } else { if egraph .parse_and_run_program( @@ -386,36 +337,27 @@ pub trait Chomper { } if !has_meaningful_diff { - println!("no meaningful diff"); + info!("no meaningful diff"); continue; } - // sufficient and necessary conditions. - // we may want to experiment with just having sufficient conditions. - let masks = mask_to_preds.keys().filter(|mask| { - mask.iter() - .zip(same_vals.iter()) - .all(|(mask_val, same_val)| mask_val == same_val) - }); + // if the mask is all false, then skip it. + if same_vals.iter().all(|x| !x) { + continue; + } - for mask in masks { - // if the mask is completely false, skip it. - if mask.iter().all(|x| !x) { - continue; - } - let preds = mask_to_preds.get(mask).unwrap(); - for pred in preds { - result.conditional.push(Rule { - condition: Some(Sexp::from_str(pred).unwrap()), - lhs: term1.clone(), - rhs: term2.clone(), - }); - result.conditional.push(Rule { - condition: Some(Sexp::from_str(pred).unwrap()), - lhs: term2.clone(), - rhs: term1.clone(), - }); - } + // sufficient and necessary conditions. + if !mask_to_preds.contains_key(&same_vals) { + continue; + } + let preds = mask_to_preds.get(&same_vals).unwrap(); + for pred in preds { + let rule = Rule { + condition: Some(Sexp::from_str(pred).unwrap()), + lhs: term1.clone(), + rhs: term2.clone(), + }; + result.conditional.push(rule); } } } @@ -426,6 +368,10 @@ pub trait Chomper { fn add_rewrite(&mut self, egraph: &mut EGraph, lhs: Sexp, rhs: Sexp) { let term1 = self.make_string_not_bad(lhs.to_string().as_str()); let term2 = self.make_string_not_bad(rhs.to_string().as_str()); + if term1 == "?a" { + return; + } + info!("Rule: {} ~> {}", term1, term2); egraph .parse_and_run_program( None, @@ -454,11 +400,11 @@ pub trait Chomper { // let _pred = self.make_string_not_bad(cond.to_string().as_str()); // let term1 = self.make_string_not_bad(lhs.to_string().as_str()); // let term2 = self.make_string_not_bad(rhs.to_string().as_str()); - // println!( + // info!( // "adding conditional rewrite: {} -> {} if {}", // term1, term2, _pred // ); - // println!("term2 has cvec: {:?}", self.interpret_term(&rhs)); + // info!("term2 has cvec: {:?}", self.interpret_term(&rhs)); // egraph // .parse_and_run_program( // None, @@ -490,8 +436,8 @@ pub trait Chomper { fn productions(&self) -> Workload; fn atoms(&self) -> Workload; fn make_preds(&self) -> Workload; - fn get_env(&self) -> &HashMap>>; - fn get_validated_rule(&self, rule: &Rule) -> Option; + fn get_env(&self) -> &HashMap>; + fn validate_rule(&self, rule: &Rule) -> ValidationResult; fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec; fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec; fn constant_pattern(&self) -> Pattern; diff --git a/tests/egglog/halide.egg b/tests/egglog/halide.egg new file mode 100644 index 0000000..e5b15d7 --- /dev/null +++ b/tests/egglog/halide.egg @@ -0,0 +1,38 @@ +;;; Halide language definition. + +(datatype Expr + (Lit i64) + (Var String) + (Lt Expr Expr) + (Leq Expr Expr) + (Eq Expr Expr) + (Neq Expr Expr) + (Implies Expr Expr) + (Not Expr) + (Neg Expr) + (And Expr Expr) + (Or Expr Expr) + (Xor Expr Expr) + (Add Expr Expr) + (Sub Expr Expr) + (Mul Expr Expr) + (Div Expr Expr) + (Min Expr Expr) + (Max Expr Expr) + (Select Expr Expr Expr) +) + + +(function eclass (Expr) i64 :merge (min old new)) + +(ruleset eclass-report) +(ruleset non-cond-rewrites) +(ruleset cond-rewrites) + +(rule + ((eclass ?bvterm)) + ((extract "eclass:") + (extract (eclass ?bvterm)) + (extract "candidate term:") + (extract ?bvterm)) + :ruleset eclass-report) diff --git a/tests/halide.rs b/tests/halide.rs new file mode 100644 index 0000000..a51ec00 --- /dev/null +++ b/tests/halide.rs @@ -0,0 +1,423 @@ +use chompy::{CVec, Chomper}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use ruler::{ + enumo::{Sexp, Workload}, + HashMap, ValidationResult, +}; + +use z3::ast::Ast; + +use chompy::utils::TERM_PLACEHOLDER; +use num::Zero; + +pub const CVEC_LEN: usize = 20; + +pub struct HalideChomper { + pub env: ruler::HashMap>, +} + +impl Chomper for HalideChomper { + type Constant = i64; + type Value = i64; + + fn productions(&self) -> ruler::enumo::Workload { + Workload::new(&[ + format!( + "(ternary {} {} {})", + TERM_PLACEHOLDER, TERM_PLACEHOLDER, TERM_PLACEHOLDER + ), + format!("(binary {} {})", TERM_PLACEHOLDER, TERM_PLACEHOLDER), + format!("(unary {})", TERM_PLACEHOLDER), + ]) + .plug("ternary", &Workload::new(&["Select"])) + .plug( + "binary", + &Workload::new(&[ + "Lt", "Leq", "Eq", "Neq", "Implies", "And", "Or", "Xor", "Add", "Sub", "Mul", + "Div", "Min", "Max", + ]), + ) + .plug("unary", &Workload::new(&["Not", "Neg"])) + } + + fn atoms(&self) -> Workload { + Workload::new(&["(Var a)", "(Var b)", "(Lit 1)", "(Lit 0)"]) + } + + fn matches_var_pattern(&self, term: &ruler::enumo::Sexp) -> bool { + match term { + Sexp::List(l) => l.len() == 2 && l[0] == Sexp::Atom("Var".to_string()), + _ => false, + } + } + + fn constant_pattern(&self) -> ruler::enumo::Pattern { + "(Lit ?x)".parse().unwrap() + } + fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> chompy::CVec { + match term { + Sexp::Atom(a) => panic!("Unexpected atom {}", a), + Sexp::List(l) => { + assert!(l.len() > 1); + let op = l[0].to_string(); + match op.as_str() { + "Lit" => { + if let Sexp::Atom(num) = &l[1] { + let parsed: i64 = num.parse().unwrap(); + vec![Some(parsed); CVEC_LEN] + } else { + panic!("Term with weird structure: {}", term) + } + } + "Var" => self.get_env().get(&l[1].to_string()).unwrap().clone(), + _ => { + let zero: i64 = 0; + let one: i64 = 1; + + let children: Vec> = + l[1..].iter().map(|t| self.interpret_term(t)).collect(); + + if let Sexp::Atom(op) = &l[0] { + match children.len() { + 1 => { + let f = |a: Option| -> Option { + if a.is_none() { + return None; + } + let a = a.unwrap(); + match op.as_str() { + "Not" => { + if a == zero { + Some(one.clone()) + } else { + Some(zero.clone()) + } + } + "Neg" => Some(-a), + _ => panic!("Unexpected unary operator {}", op), + } + }; + + children[0].iter().map(|a| f(*a)).collect() + } + 2 => { + let f = |(a, b): (Option, Option)| -> Option { + if a.is_none() || b.is_none() { + return None; + } + let a = a.unwrap(); + let b = b.unwrap(); + match op.as_str() { + "Lt" => { + Some(if a < b { one.clone() } else { zero.clone() }) + } + "Leq" => Some(if a <= b { + one.clone() + } else { + zero.clone() + }), + "Eq" => Some(if a == b { + one.clone() + } else { + zero.clone() + }), + "Neq" => Some(if a != b { + one.clone() + } else { + zero.clone() + }), + "Implies" => { + let p = a != zero; + let q = b != zero; + Some(if p || !q { + one.clone() + } else { + zero.clone() + }) + } + "And" => { + let abool = a != zero; + let bbool = b != zero; + if abool && bbool { + Some(one.clone()) + } else { + Some(zero.clone()) + } + } + "Or" => { + let abool = a != zero; + let bbool = b != zero; + if abool || bbool { + Some(one.clone()) + } else { + Some(zero.clone()) + } + } + "Xor" => { + let abool = a != zero; + let bbool = b != zero; + if abool ^ bbool { + Some(one.clone()) + } else { + Some(zero.clone()) + } + } + "Add" => a.checked_add(b), + "Sub" => a.checked_sub(b), + "Mul" => a.checked_mul(b), + "Div" => { + if b.is_zero() { + Some(zero.clone()) + } else { + a.checked_div(b) + } + } + "Min" => Some(a.min(b)), + "Max" => Some(a.max(b)), + _ => panic!("Unexpected binary operator {}", op), + } + }; + children[0] + .iter() + .zip(children[1].iter()) + .map(|(a, b)| f((*a, *b))) + .into_iter() + .collect() + } + 3 => { + let f = |(a, b, c): (Option, Option, Option)| { + if a.is_none() || b.is_none() || c.is_none() { + return None; + } + let a = a.unwrap(); + let b = b.unwrap(); + let c = c.unwrap(); + match op.as_str() { + "Select" => Some(if a != zero { b } else { c }), + _ => panic!("Unexpected ternary operator {}", op), + } + }; + children[0] + .iter() + .zip(children[1].iter()) + .zip(children[2].iter()) + .map(|((a, b), c)| f((*a, *b, *c))) + .into_iter() + .collect() + } + _ => todo!(), + } + } else { + panic!("Expected atom for function, found {}", op); + } + } + } + } + } + } + + fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec { + let cvec = self.interpret_term(term); + cvec.iter() + .map(|x| { + if x.is_none() { + panic!( + "Expected concrete value for cvec {:?}, but found None", + cvec + ); + } + let x = x.unwrap(); + if x == 0 { + false + } else if x == 1 { + true + } else { + panic!("Expected 0 or 1, but found {} in {:?}", x, cvec); + } + }) + .collect() + } + + fn validate_rule(&self, rule: &chompy::Rule) -> ValidationResult { + let mut cfg = z3::Config::new(); + cfg.set_timeout_msec(1000); + let ctx = z3::Context::new(&cfg); + let solver = z3::Solver::new(&ctx); + let lexpr = sexp_to_z3(&ctx, &rule.lhs); + let rexpr = sexp_to_z3(&ctx, &rule.rhs); + if rule.condition.is_some() { + let assumption = rule.condition.clone().unwrap(); + let aexpr = sexp_to_z3(&ctx, &assumption); + let zero = z3::ast::Int::from_i64(&ctx, 0); + let cond = z3::ast::Bool::not(&aexpr._eq(&zero)); + solver.assert(&z3::ast::Bool::implies(&cond, &lexpr._eq(&rexpr)).not()); + } else { + solver.assert(&lexpr._eq(&rexpr).not()); + } + match solver.check() { + z3::SatResult::Unsat => ValidationResult::Valid, + z3::SatResult::Unknown => ValidationResult::Unknown, + z3::SatResult::Sat => ValidationResult::Invalid, + } + } + + fn make_preds(&self) -> Workload { + // TODO: expand this to have a larger range of predicates. + let depth_1 = Workload::new(&["(Lt var var)", "(Leq var var)", "(Eq var var)"]) + .plug("var", &self.atoms()); + depth_1 + } + + fn get_env(&self) -> &ruler::HashMap> { + &self.env + } +} + +impl HalideChomper { + fn make_env(rng: &mut StdRng) -> HashMap>> { + let mut env = HashMap::default(); + let dummy = HalideChomper { env: env.clone() }; + for atom in &dummy.atoms().force() { + if let Sexp::List(l) = atom { + let atom_type = l[0].clone(); + if atom_type.to_string() == "Var" { + let id = &l[1]; + println!("id: {:?}", id); + let name = id.to_string(); + let mut values = Vec::new(); + for _ in 0..CVEC_LEN { + values.push(Some(rng.gen_range(-10..10))); + } + env.insert(name, values); + } + } + } + env + } +} + +fn sexp_to_z3<'a>(ctx: &'a z3::Context, sexp: &Sexp) -> z3::ast::Int<'a> { + match sexp { + Sexp::Atom(a) => { + // assert that a begins with question mark + assert!(a.starts_with("?")); + z3::ast::Int::new_const(ctx, a[1..].to_string()) + } + Sexp::List(l) => { + assert!(l.len() > 1); + let op = l[0].to_string(); + match op.as_str() { + "Lit" => { + if let Sexp::Atom(num) = &l[1] { + let parsed: i64 = num.parse().unwrap(); + z3::ast::Int::from_i64(ctx, parsed) + } else { + panic!("Lit with weird structure: {:?}", sexp) + } + } + _ => { + let children: Vec = + l[1..].iter().map(|t| sexp_to_z3(ctx, t)).collect(); + let zero = z3::ast::Int::from_i64(ctx, 0); + let one = z3::ast::Int::from_i64(ctx, 1); + let op = l[0].to_string(); + match op.as_str() { + "Lt" => z3::ast::Bool::ite( + &z3::ast::Int::lt(&children[0], &children[1]), + &one, + &zero, + ), + "Leq" => z3::ast::Bool::ite( + &z3::ast::Int::le(&children[0], &children[1]), + &one, + &zero, + ), + "Eq" => z3::ast::Bool::ite( + &z3::ast::Int::_eq(&children[0], &children[1]), + &one, + &zero, + ), + "Neq" => z3::ast::Bool::ite( + &z3::ast::Int::_eq(&children[0], &children[1]), + &zero, + &one, + ), + "Implies" => { + let l_not_z = z3::ast::Bool::not(&children[0]._eq(&zero)); + let r_not_z = z3::ast::Bool::not(&children[1]._eq(&zero)); + z3::ast::Bool::ite( + &z3::ast::Bool::implies(&l_not_z, &r_not_z), + &one, + &zero, + ) + } + "Not" => z3::ast::Bool::ite(&children[0]._eq(&zero), &one, &zero), + "Neg" => z3::ast::Int::unary_minus(&children[0]), + "And" => { + let l_not_z = z3::ast::Bool::not(&children[0]._eq(&zero)); + let r_not_z = z3::ast::Bool::not(&children[1]._eq(&zero)); + z3::ast::Bool::ite( + &z3::ast::Bool::and(ctx, &[&l_not_z, &r_not_z]), + &one, + &zero, + ) + } + "Or" => { + let l_not_z = z3::ast::Bool::not(&children[0]._eq(&zero)); + let r_not_z = z3::ast::Bool::not(&children[1]._eq(&zero)); + z3::ast::Bool::ite( + &z3::ast::Bool::or(ctx, &[&l_not_z, &r_not_z]), + &one, + &zero, + ) + } + "Xor" => { + let l_not_z = z3::ast::Bool::not(&children[0]._eq(&zero)); + let r_not_z = z3::ast::Bool::not(&children[1]._eq(&zero)); + z3::ast::Bool::ite(&z3::ast::Bool::xor(&l_not_z, &r_not_z), &one, &zero) + } + "Add" => z3::ast::Int::add(ctx, &[&children[0], &children[1]]), + "Sub" => z3::ast::Int::sub(ctx, &[&children[0], &children[1]]), + "Mul" => z3::ast::Int::mul(ctx, &[&children[0], &children[1]]), + "Div" => z3::ast::Bool::ite( + &children[1]._eq(&zero), + &zero, + &z3::ast::Int::div(&children[0], &children[1]), + ), + "Min" => z3::ast::Bool::ite( + &z3::ast::Int::le(&children[0], &children[1]), + &children[0], + &children[1], + ), + "Max" => z3::ast::Bool::ite( + &z3::ast::Int::le(&children[0], &children[1]), + &children[1], + &children[0], + ), + "Select" => { + let cond = z3::ast::Bool::not(&children[0]._eq(&zero)); + z3::ast::Bool::ite(&cond, &children[1], &children[2]) + } + _ => panic!("Unexpected operator {}", op), + } + } + } + } + } +} + +pub mod tests { + use chompy::init_egraph; + use egglog::EGraph; + + use super::*; + + #[test] + fn try_inference() { + let env = HalideChomper::make_env(&mut StdRng::seed_from_u64(0)); + let mut chomper = HalideChomper { env }; + let mut egraph = EGraph::default(); + init_egraph!(egraph, "./egglog/halide.egg"); + chomper.run_chompy(&mut egraph); + } +}