Dump progress
ninehusky committed Oct 23, 2024
1 parent 6f719ad commit 9da65bd
Showing 7 changed files with 425 additions and 779 deletions.
164 changes: 111 additions & 53 deletions src/lib.rs
@@ -1,4 +1,4 @@
use egglog::EGraph;
use egglog::{EGraph, SerializeConfig};
use ruler::{HashMap, HashSet, ValidationResult};
use utils::{get_ast_size, TERM_PLACEHOLDER};

@@ -88,46 +88,58 @@ pub trait Chomper {

let mut max_eclass_id = 0;

// corpus[i] contains all programs of size i.
let mut corpus = self.make_initial_corpus();

let mut old_workload = self.atoms();
let mut old_workload = Workload::empty();

// invariant: `corpus` contains all programs of size `i`.
for current_size in 0..MAX_SIZE {
println!("programs of size {}:", current_size);
println!("adding programs of size {}:", current_size);

let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
if current_size > 2 {
if current_size > 15 {
filter = Filter::And(vec![
filter,
Filter::Excludes("(Bitvector ?x (ValueNum ?y))".parse().unwrap()),
]);
}

let new_workload = self
.productions()
.clone()
.plug(TERM_PLACEHOLDER, &old_workload)
.filter(filter);

old_workload = old_workload.append(new_workload.clone());

println!("old workload has length: {}", old_workload.force().len());
info!("finding eclass term map...");
let eclass_term_map = self
.reset_eclass_term_map(egraph)
.values()
.cloned()
.collect::<Vec<_>>();
info!("eclass term map len: {}", eclass_term_map.len());
let term_workload = Workload::new(
eclass_term_map
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>(),
);

let new_workload = if term_workload.force().is_empty() {
self.atoms().clone()
} else {
self.productions()
.clone()
.plug(TERM_PLACEHOLDER, &term_workload)
.filter(filter)
};

println!("new workload len: {}", new_workload.force().len());

for term in &new_workload.force() {
if get_ast_size(term) != current_size {
panic!();
}
println!("term: {}", term);
// TODO: re-include this.
// if get_ast_size(term) < current_size {
// panic!();
// }
let term_string = self.make_string_not_bad(term.to_string().as_str());
egraph
.parse_and_run_program(
None,
format!(
r#"
{term_string}
(set (eclass {term_string}) {max_eclass_id})
{term_string}
(set (eclass {term_string}) {max_eclass_id})
"#
)
.as_str(),
@@ -145,35 +157,57 @@ pub trait Chomper {
break;
}

// loop through conditionals and non-conditionals
for val in vals.conditional.iter().chain(vals.non_conditional.iter()) {
for (i, rule) in rules.iter().enumerate() {
if val.lhs == rule.lhs
&& val.rhs == rule.rhs
&& val.condition == rule.condition
{
found[i] = true;
if found.iter().all(|x| *x) {
return;
}
}
for (i, rule) in rules.iter().enumerate() {
let lhs = self.make_string_not_bad(rule.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(rule.rhs.to_string().as_str());
if egraph
.parse_and_run_program(
None,
format!(
r#"
{lhs}
{rhs}
(run-schedule
(saturate non-cond-rewrites))
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok()
{
found[i] = true;
};
if found.iter().all(|x| *x) {
return;
}
}

for val in &vals.non_conditional {
println!("found rule: {} -> {}", val.lhs, val.rhs);
self.add_rewrite(egraph, val.lhs.clone(), val.rhs.clone());
let lhs = self.make_string_not_bad(val.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(val.rhs.to_string().as_str());
if egraph
.parse_and_run_program(
None,
format!(
r#"
{lhs}
{rhs}
(run-schedule
(saturate non-cond-rewrites))
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_err()
{
self.add_rewrite(egraph, val.lhs.clone(), val.rhs.clone());
};
}
egraph
.parse_and_run_program(
None,
r#"
(run-schedule
(saturate non-cond-rewrites))"#,
)
.unwrap();
}
}

panic!("not all rules were found");
}

@@ -224,6 +258,7 @@ pub trait Chomper {
}
for ec2 in ec_keys.iter().skip(i + 1) {
let term2 = eclass_term_map.get(ec2).unwrap();

let cvec2 = self.interpret_term(term2);

if cvec2.iter().all(|x| x.is_none()) {
@@ -242,6 +277,17 @@ pub trait Chomper {
rhs: term1.clone(),
});
} else {
if egraph
.parse_and_run_program(
None,
format!("(check (cond-equal {term1} {term2}))").as_str(),
)
.is_ok()
{
println!("skipping");
continue;
}

// TODO: check if they are conditionally equal
let mut has_meaningful_diff = false;
let mut matching_count = 0;
@@ -253,15 +299,17 @@ pub trait Chomper {
has_meaningful_diff = true;
}
same_vals.push(has_match);
matching_count += 1;
if has_match {
matching_count += 1;
}
}

if !has_meaningful_diff {
continue;
}

// filter out bad predicates that only match on one value
if matching_count < 2 {
if matching_count < 3 {
continue;
}

@@ -272,7 +320,7 @@ pub trait Chomper {
.all(|(mask_val, same_val)| {
// pred --> lhs == rhs
// pred OR not lhs == rhs
*mask_val || !(same_val)
mask_val == same_val
})
});

@@ -289,6 +337,14 @@ pub trait Chomper {
lhs: term2.clone(),
rhs: term1.clone(),
});
self.add_conditional_rewrite(
egraph,
Sexp::from_str(pred).unwrap(),
term1.clone(),
term2.clone(),
);
// let's just add one.
break;
}
}
}
@@ -317,16 +373,21 @@ pub trait Chomper {
}

fn add_conditional_rewrite(&mut self, egraph: &mut EGraph, cond: Sexp, lhs: Sexp, rhs: Sexp) {
let pred = self.make_string_not_bad(cond.to_string().as_str());
let _pred = self.make_string_not_bad(cond.to_string().as_str());
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());
println!(
"adding conditional rewrite: {} -> {} if {}",
term1, term2, _pred
);
println!("term2 has cvec: {:?}", self.interpret_term(&rhs));
egraph
.parse_and_run_program(
None,
format!(
r#"
(cond-equal {pred} {term1} {term2})
(cond-equal {pred} {term2} {term1})
(cond-equal {term1} {term2})
(cond-equal {term2} {term1})
"#
)
.as_str(),
@@ -339,11 +400,8 @@ pub trait Chomper {
// the productions that add `i` to the size of the program.
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;

fn make_preds(&self) -> Workload;

fn get_env(&self) -> &HashMap<String, Vec<Value<Self>>>;

fn validate_rule(&self, rule: Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
(The diffs for the remaining 6 changed files are not shown.)
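
For reference, the conditional-rewrite filter in the hunk above accepts a predicate only when its boolean mask agrees sample-for-sample with the cvec equality of the two terms (`mask_val == same_val`), the terms differ on at least one sample, and the predicate holds on at least three samples. The following standalone sketch restates that check; it is not code from this repository, and the function name and the `i64` value type are illustrative assumptions.

// A minimal sketch of the predicate test: accept `mask` as the condition
// for `lhs ~> rhs` only if it is true exactly on the samples where the
// two terms' cvecs agree.
fn predicate_guards_equality(
    mask: &[bool],
    lhs_cvec: &[Option<i64>],
    rhs_cvec: &[Option<i64>],
) -> bool {
    // Per-sample agreement of the two candidate terms.
    let same_vals: Vec<bool> = lhs_cvec
        .iter()
        .zip(rhs_cvec.iter())
        .map(|(l, r)| l == r)
        .collect();

    // The terms must differ on some sample; otherwise an unconditional
    // rule is the right candidate instead.
    let has_meaningful_diff = same_vals.iter().any(|s| !*s);

    // Filter out near-vacuous predicates that only match on a couple of samples.
    let matching_count = same_vals.iter().filter(|s| **s).count();

    // `mask_val == same_val` on every sample: the predicate holds exactly
    // where the terms agree, so `mask => (lhs == rhs)` over the samples.
    let mask_matches = mask.iter().zip(same_vals.iter()).all(|(m, s)| m == s);

    has_meaningful_diff && matching_count >= 3 && mask_matches
}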