Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Halide interpreter #37

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ log = "0.4.22"
# same version as ruler
z3 = {version = "0.10.0", features = ["static-link-z3"]}
itertools = "0.13.0"
num = "0.3"

serde = "1.0.214"
serde_json = "1.0.132"
250 changes: 98 additions & 152 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use egglog::{EGraph, SerializeConfig};
use ruler::enumo::Pattern;
use ruler::{HashMap, HashSet};
use ruler::{HashMap, HashSet, ValidationResult};
use utils::TERM_PLACEHOLDER;

use std::fmt::Debug;
use std::hash::Hash;
use std::str::FromStr;

use ruler::enumo::{Filter, Metric, Sexp, Workload};
use ruler::enumo::{Sexp, Workload};

use log::info;

Expand Down Expand Up @@ -65,26 +65,16 @@ pub trait Chomper {
result
}

fn run_chompy(
&mut self,
egraph: &mut EGraph,
rules: Vec<Rule>,
mask_to_preds: &HashMap<Vec<bool>, HashSet<String>>,
memo: &mut HashSet<i64>,
) {
let mut found: Vec<bool> = vec![false; rules.len()];

fn run_chompy(&mut self, egraph: &mut EGraph) {
let mut max_eclass_id = 0;

let mut found_rules: HashSet<String> = HashSet::default();

for current_size in 0..MAX_SIZE {
info!("adding programs of size {}:", current_size);

let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
if current_size > 15 {
filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
}
// let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
// if current_size > 4 {
// filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
// }

info!("finding eclass term map...");
let eclass_term_map = self
Expand All @@ -100,20 +90,21 @@ pub trait Chomper {
);

let new_workload = if term_workload.force().is_empty() {
self.atoms().clone().filter(filter)
self.atoms().clone()
} else {
self.productions()
.clone()
.plug(TERM_PLACEHOLDER, &term_workload)
.filter(filter)
};

info!("new workload len: {}", new_workload.force().len());

let atoms = self.atoms().force();

let memo = &mut HashSet::default();

for term in &new_workload.force() {
info!("term: {}", term);
// info!("term: {}", term);
let term_string = self.make_string_not_bad(term.to_string().as_str());
if !atoms.contains(term) && !self.has_var(term) {
continue;
Expand All @@ -125,7 +116,7 @@ pub trait Chomper {
r#"
{term_string}
(set (eclass {term_string}) {max_eclass_id})
"#
"#
)
.as_str(),
)
Expand All @@ -144,109 +135,67 @@ pub trait Chomper {
)
.unwrap();
info!("starting cvec match");
let vals = self.cvec_match(egraph, mask_to_preds, memo);
if vals.non_conditional.is_empty()
|| vals.non_conditional.iter().all(|x| {
found_rules.contains(format!("{:?}", self.generalize_rule(x)).as_str())
})
{
let vals = self.cvec_match(egraph, memo);

if vals.non_conditional.is_empty() && vals.conditional.is_empty() {
break;
}

for (i, rule) in rules.iter().enumerate() {
let lhs = self.make_string_not_bad(rule.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(rule.rhs.to_string().as_str());
if (rule.condition.is_some()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (cond-equal {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
|| (rule.condition.is_none()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
{
found[i] = true;
}
if found.iter().all(|x| *x) {
return;
info!("found {} non-conditional rules", vals.non_conditional.len());
info!("found {} conditional rules", vals.conditional.len());
for val in &vals.conditional {
let generalized = self.generalize_rule(val);
if let ValidationResult::Valid = self.validate_rule(&generalized) {
if utils::does_rule_have_good_vars(&generalized) {
let lhs =
self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs =
self.make_string_not_bad(generalized.rhs.to_string().as_str());
let cond = generalized.condition.as_ref().unwrap();
let pred = self.make_string_not_bad(cond.to_string().as_str());
info!("Conditional rule: if {} then {} ~> {}", pred, lhs, rhs);
self.add_conditional_rewrite(
egraph,
Sexp::from_str(&pred).unwrap(),
Sexp::from_str(&lhs).unwrap(),
Sexp::from_str(&rhs).unwrap(),
);
}
}
}

for val in &vals.non_conditional {
let generalized = self.generalize_rule(val);
if !found_rules.contains(format!("{:?}", generalized).as_str())
&& utils::does_rule_have_good_vars(&generalized)
{
let lhs = self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(generalized.rhs.to_string().as_str());
if egraph
.parse_and_run_program(
None,
format!(
r#"
{lhs}
{rhs}
(check (= {lhs} {rhs}))
"#
if let ValidationResult::Valid = self.validate_rule(&generalized) {
if utils::does_rule_have_good_vars(&generalized) {
let lhs =
self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs =
self.make_string_not_bad(generalized.rhs.to_string().as_str());

if egraph
.parse_and_run_program(
None,
format!(r#"(check (= {} {}))"#, val.lhs, val.rhs).as_str(),
)
.as_str(),
)
.is_err()
{
let validated = self.get_validated_rule(&generalized);
if found_rules.contains(format!("{:?}", validated).as_str()) {
.is_ok()
{
continue;
}
found_rules.insert(format!("{:?}", validated));
if validated.is_none() {
continue;
}
let validated = validated.unwrap();
if validated.condition.is_none() {
info!("Rule: {} -> {}", validated.lhs, validated.rhs);
self.add_rewrite(egraph, validated.lhs, validated.rhs);
} else {
info!(
"Conditional Rule: if {} then {} -> {}",
validated.condition.clone().unwrap(),
validated.lhs,
validated.rhs
);
self.add_conditional_rewrite(
egraph,
validated.condition.unwrap(),
validated.lhs,
validated.rhs,
);
}

self.add_rewrite(
egraph,
Sexp::from_str(&lhs).unwrap(),
Sexp::from_str(&rhs).unwrap(),
);
// TODO: derivability check here
}
} else {
// info!(
// "perfect cvec match but failed validation: {} ~> {}",
// val.lhs, val.rhs
// );
}
}

for val in &vals.conditional {
self.add_conditional_rewrite(
egraph,
val.condition.clone().unwrap(),
val.lhs.clone(),
val.rhs.clone(),
);
}
}
}

Expand Down Expand Up @@ -278,9 +227,15 @@ pub trait Chomper {
let mut id_to_gen_id: HashMap<String, String> = HashMap::default();
let new_lhs = self.generalize_sexp(rule.lhs.clone(), &mut id_to_gen_id);
let new_rhs = self.generalize_sexp(rule.rhs.clone(), &mut id_to_gen_id);

let condition = rule
.condition
.as_ref()
.map(|cond| self.generalize_sexp(cond.clone(), &mut id_to_gen_id));

Rule {
// TODO: later
condition: None,
condition,
lhs: new_lhs,
rhs: new_rhs,
}
Expand Down Expand Up @@ -314,7 +269,6 @@ pub trait Chomper {
fn cvec_match(
&mut self,
egraph: &mut EGraph,
mask_to_preds: &HashMap<Vec<bool>, HashSet<String>>,
// keeps track of what eclass IDs we've seen.
memo: &mut HashSet<i64>,
) -> Rules {
Expand All @@ -323,12 +277,14 @@ pub trait Chomper {
conditional: vec![],
};

println!("hi from cvec match");
let mask_to_preds = self.make_mask_to_preds();

info!("hi from cvec match");
let serialized = egraph.serialize(SerializeConfig::default());
println!("eclasses in egraph: {}", serialized.classes().len());
println!("nodes in egraph: {}", serialized.nodes.len());
info!("eclasses in egraph: {}", serialized.classes().len());
info!("nodes in egraph: {}", serialized.nodes.len());
let eclass_term_map: HashMap<i64, Sexp> = self.reset_eclass_term_map(egraph);
// println!("eclass term map len: {}", eclass_term_map.len());
info!("eclass term map len: {}", eclass_term_map.len());
let ec_keys: Vec<&i64> = eclass_term_map.keys().collect();
for i in 0..ec_keys.len() {
let ec1 = ec_keys[i];
Expand Down Expand Up @@ -357,11 +313,6 @@ pub trait Chomper {
lhs: term1.clone(),
rhs: term2.clone(),
});
result.non_conditional.push(Rule {
condition: None,
lhs: term2.clone(),
rhs: term1.clone(),
});
} else {
if egraph
.parse_and_run_program(
Expand All @@ -386,36 +337,27 @@ pub trait Chomper {
}

if !has_meaningful_diff {
println!("no meaningful diff");
info!("no meaningful diff");
continue;
}

// sufficient and necessary conditions.
// we may want to experiment with just having sufficient conditions.
let masks = mask_to_preds.keys().filter(|mask| {
mask.iter()
.zip(same_vals.iter())
.all(|(mask_val, same_val)| mask_val == same_val)
});
// if the mask is all false, then skip it.
if same_vals.iter().all(|x| !x) {
continue;
}

for mask in masks {
// if the mask is completely false, skip it.
if mask.iter().all(|x| !x) {
continue;
}
let preds = mask_to_preds.get(mask).unwrap();
for pred in preds {
result.conditional.push(Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term1.clone(),
rhs: term2.clone(),
});
result.conditional.push(Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term2.clone(),
rhs: term1.clone(),
});
}
// sufficient and necessary conditions.
if !mask_to_preds.contains_key(&same_vals) {
continue;
}
let preds = mask_to_preds.get(&same_vals).unwrap();
for pred in preds {
let rule = Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term1.clone(),
rhs: term2.clone(),
};
result.conditional.push(rule);
}
}
}
Expand All @@ -426,6 +368,10 @@ pub trait Chomper {
fn add_rewrite(&mut self, egraph: &mut EGraph, lhs: Sexp, rhs: Sexp) {
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());
if term1 == "?a" {
return;
}
info!("Rule: {} ~> {}", term1, term2);
egraph
.parse_and_run_program(
None,
Expand Down Expand Up @@ -454,11 +400,11 @@ pub trait Chomper {
// let _pred = self.make_string_not_bad(cond.to_string().as_str());
// let term1 = self.make_string_not_bad(lhs.to_string().as_str());
// let term2 = self.make_string_not_bad(rhs.to_string().as_str());
// println!(
// info!(
// "adding conditional rewrite: {} -> {} if {}",
// term1, term2, _pred
// );
// println!("term2 has cvec: {:?}", self.interpret_term(&rhs));
// info!("term2 has cvec: {:?}", self.interpret_term(&rhs));
// egraph
// .parse_and_run_program(
// None,
Expand Down Expand Up @@ -490,8 +436,8 @@ pub trait Chomper {
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;
fn make_preds(&self) -> Workload;
fn get_env(&self) -> &HashMap<String, Vec<Value<Self>>>;
fn get_validated_rule(&self, rule: &Rule) -> Option<Rule>;
fn get_env(&self) -> &HashMap<String, CVec<Self>>;
fn validate_rule(&self, rule: &Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn constant_pattern(&self) -> Pattern;
Expand Down
Loading
Loading