Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enumerate by size #19

Merged
merged 12 commits into from
Oct 24, 2024
184 changes: 137 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use egglog::EGraph;
use ruler::enumo::Pattern;
use ruler::{HashMap, HashSet, ValidationResult};
use utils::TERM_PLACEHOLDER;

use std::fmt::Debug;
use std::hash::Hash;
use std::str::FromStr;

use ruler::enumo::{Sexp, Workload};
use ruler::enumo::{Filter, Metric, Sexp, Workload};

use log::info;

pub mod utils;

pub type Constant<R> = <R as Chomper>::Constant;
pub type CVec<R> = Vec<Option<<R as Chomper>::Constant>>;
pub type Value<R> = <R as Chomper>::Value;
Expand All @@ -24,6 +28,8 @@ pub struct Rules {
pub conditional: Vec<Rule>,
}

pub const MAX_SIZE: usize = 30;

#[macro_export]
macro_rules! init_egraph {
($egraph:ident, $path:expr) => {
Expand Down Expand Up @@ -62,41 +68,63 @@ pub trait Chomper {
fn run_chompy(
&mut self,
egraph: &mut EGraph,
test_name: &str,
rules: Vec<Rule>,
atoms: &Workload,
mask_to_preds: &HashMap<Vec<bool>, HashSet<String>>,
) {
let mut found: Vec<bool> = vec![false; rules.len()];

let mut old_workload = atoms.clone();
let mut max_eclass_id = 0;

const MAX_ITERATIONS: usize = 2;
for _ in 0..MAX_ITERATIONS {
let new_workload = self.make_terms(&old_workload);
old_workload = new_workload.clone();
info!(
"{}: new workload has {} terms",
test_name,
new_workload.force().len()
for current_size in 0..MAX_SIZE {
info!("adding programs of size {}:", current_size);

let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
if current_size > 15 {
filter = Filter::And(vec![filter, Filter::Excludes(self.get_constant_pattern())]);
}

info!("finding eclass term map...");
let eclass_term_map = self
.reset_eclass_term_map(egraph)
.values()
.cloned()
.collect::<Vec<_>>();
info!("eclass term map len: {}", eclass_term_map.len());
let term_workload = Workload::new(
eclass_term_map
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>(),
);

let new_workload = if term_workload.force().is_empty() {
self.atoms().clone()
} else {
self.productions()
.clone()
.plug(TERM_PLACEHOLDER, &term_workload)
.filter(filter)
};

info!("new workload len: {}", new_workload.force().len());

for term in &new_workload.force() {
let term_string = self.make_string_not_bad(term.to_string().as_str());
egraph
.parse_and_run_program(
None,
format!(
r#"
{term_string}
(set (eclass {term_string}) {max_eclass_id})
{term_string}
(set (eclass {term_string}) {max_eclass_id})
"#
)
.as_str(),
)
.unwrap();
max_eclass_id += 1;
}

loop {
info!("starting cvec match");
let vals = self.cvec_match(egraph, mask_to_preds);
Expand All @@ -106,42 +134,83 @@ pub trait Chomper {
break;
}

// loop through conditionals and non-conditionals
for val in vals.conditional.iter().chain(vals.non_conditional.iter()) {
for (i, rule) in rules.iter().enumerate() {
if val.lhs == rule.lhs
&& val.rhs == rule.rhs
&& val.condition == rule.condition
{
found[i] = true;
if found.iter().all(|x| *x) {
return;
}
}
for (i, rule) in rules.iter().enumerate() {
let lhs = self.make_string_not_bad(rule.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(rule.rhs.to_string().as_str());
if (rule.condition.is_some()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (cond-equal {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
|| (rule.condition.is_none()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
{
found[i] = true;
}
if found.iter().all(|x| *x) {
return;
}
}

for val in &vals.non_conditional {
let lhs = self.make_string_not_bad(val.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(val.rhs.to_string().as_str());
if egraph
.parse_and_run_program(
None,
format!(
r#"
{lhs}
{rhs}
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_err()
{
self.add_rewrite(egraph, val.lhs.clone(), val.rhs.clone());
};
}

for val in &vals.conditional {
self.add_conditional_rewrite(
egraph,
val.condition.clone().unwrap().clone(),
val.condition.clone().unwrap(),
val.lhs.clone(),
val.rhs.clone(),
);
}
for val in &vals.non_conditional {
self.add_rewrite(egraph, val.lhs.clone(), val.rhs.clone());
}

egraph
.parse_and_run_program(
None,
r#"
(run-schedule
(saturate non-cond-rewrites))"#,
(saturate non-cond-rewrites))
"#,
)
.unwrap();
}
}

panic!("not all rules were found");
}

Expand Down Expand Up @@ -192,6 +261,7 @@ pub trait Chomper {
}
for ec2 in ec_keys.iter().skip(i + 1) {
let term2 = eclass_term_map.get(ec2).unwrap();

let cvec2 = self.interpret_term(term2);

if cvec2.iter().all(|x| x.is_none()) {
Expand All @@ -210,7 +280,18 @@ pub trait Chomper {
rhs: term1.clone(),
});
} else {
// TODO: check if they are conditionally equal
if egraph
.parse_and_run_program(
None,
format!("(check (cond-equal {term1} {term2}))").as_str(),
)
.is_ok()
{
// TODO: we're going to ignore multiple conditionals for now, there are too many.
info!("skipping");
continue;
}

let mut has_meaningful_diff = false;
let mut matching_count = 0;
let mut same_vals: Vec<bool> = vec![];
Expand All @@ -221,30 +302,33 @@ pub trait Chomper {
has_meaningful_diff = true;
}
same_vals.push(has_match);
matching_count += 1;
if has_match {
matching_count += 1;
}
}

if !has_meaningful_diff {
continue;
}

// filter out bad predicates that only match on one value
if matching_count < 2 {
if matching_count == 1 {
continue;
}

// we want sufficient conditions, not sufficent and necessary.
// sufficient and necessary conditions.
// we may want to experiment with just having sufficient conditions.
let masks = mask_to_preds.keys().filter(|mask| {
mask.iter()
.zip(same_vals.iter())
.all(|(mask_val, same_val)| {
// pred --> lhs == rhs
// pred OR not lhs == rhs
*mask_val || !(same_val)
})
.all(|(mask_val, same_val)| mask_val == same_val)
});

for mask in masks {
// if the mask is completely false, skip it.
if mask.iter().all(|x| !x) {
continue;
}
let preds = mask_to_preds.get(mask).unwrap();
for pred in preds {
result.conditional.push(Rule {
Expand Down Expand Up @@ -285,30 +369,36 @@ pub trait Chomper {
}

fn add_conditional_rewrite(&mut self, egraph: &mut EGraph, cond: Sexp, lhs: Sexp, rhs: Sexp) {
let pred = self.make_string_not_bad(cond.to_string().as_str());
// TODO: @ninehusky: let's brainstorm ways to encode conditional equality with respect to a
// specific condition (see #20).
let _pred = self.make_string_not_bad(cond.to_string().as_str());
ninehusky marked this conversation as resolved.
Show resolved Hide resolved
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());
info!(
"adding conditional rewrite: {} -> {} if {}",
term1, term2, _pred
);
info!("term2 has cvec: {:?}", self.interpret_term(&rhs));
egraph
.parse_and_run_program(
None,
format!(
r#"
(cond-equal {pred} {term1} {term2})
(cond-equal {pred} {term2} {term1})
(cond-equal {term1} {term2})
(cond-equal {term2} {term1})
"#
)
.as_str(),
)
.unwrap();
}

// applies the given productions to the old terms to get some new workload
fn make_terms(&self, old_terms: &Workload) -> Workload;
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;
fn make_preds(&self) -> Workload;

fn get_env(&self) -> &HashMap<String, Vec<Value<Self>>>;

fn validate_rule(&self, rule: Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn get_constant_pattern(&self) -> Pattern;
}
31 changes: 31 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use ruler::enumo::Sexp;

// Atoms with this name do not count toward the size
// of a production.
pub const TERM_PLACEHOLDER: &str = "?term";

/// Returns the size of `term` when it is measured as a production.
///
/// Every atom contributes 1, except atoms equal to [`TERM_PLACEHOLDER`],
/// which contribute nothing.
pub fn get_production_size(term: &Sexp) -> usize {
    let skip_placeholders = true;
    get_size(term, skip_placeholders)
}

/// Returns the size of `term` as a plain AST.
///
/// Every atom contributes 1, including atoms equal to
/// [`TERM_PLACEHOLDER`].
pub fn get_ast_size(term: &Sexp) -> usize {
    let skip_placeholders = false;
    get_size(term, skip_placeholders)
}

/// Counts the atoms in `term`, recursing through nested lists.
///
/// When `skip_placeholders` is `true`, atoms equal to [`TERM_PLACEHOLDER`]
/// are left out of the count. A list adds nothing on its own: its size is
/// the sum of its items' sizes, so an empty list has size 0.
fn get_size(term: &Sexp, skip_placeholders: bool) -> usize {
    match term {
        // Placeholder atoms are free, but only when the caller asks for that.
        Sexp::Atom(name) if skip_placeholders && name == TERM_PLACEHOLDER => 0,
        Sexp::Atom(_) => 1,
        Sexp::List(children) => children
            .iter()
            .map(|child| get_size(child, skip_placeholders))
            .sum(),
    }
}
Loading