Skip to content

Commit

Permalink
args used thingy
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Feb 4, 2025
1 parent 075b512 commit 64a1dbe
Showing 1 changed file with 87 additions and 84 deletions.
171 changes: 87 additions & 84 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ impl<'a> Extractor<'a> {
costs: Default::default(),
total: 0.0.try_into().unwrap(),
term,
args_used: Default::default(),
};
self.costsets.push(costset);
self.costsetmemo
Expand Down Expand Up @@ -409,6 +410,7 @@ pub struct CostSet {
/// Maps classes to the chosen term for the eclass,
/// along with the cost for that term (excluding child costs).
pub costs: HashTrieMap<ClassId, (Term, Cost)>,
pub args_used: HashSet<usize>,
/// The resulting term
pub term: Term,
}
Expand Down Expand Up @@ -488,10 +490,14 @@ impl<'a> Extractor<'a> {
if let Some((existing_term, _existing_cost)) = current_costs.get(eclass) {
(existing_term.clone(), NotNan::new(0.).unwrap())
} else {
let unshared_cost = match other_costs.get(eclass) {
Some((_, cost)) => *cost,
// no cost stored, so it's free
None => NotNan::new(0.).unwrap(),
let unshared_cost = if is_free {
NotNan::new(0.).unwrap()
} else {
match other_costs.get(eclass) {
Some((_, cost)) => *cost,
// no cost stored, so it's free
None => NotNan::new(0.).unwrap(),
}
};

let mut cost = unshared_cost;
Expand Down Expand Up @@ -573,47 +579,6 @@ impl<'a> Extractor<'a> {
}
}

/// Collects into `used` the indices at which `term` projects out of the argument.
///
/// The term is walked recursively through `self.termdag`:
/// * a literal contributes nothing and yields `Some(())`;
/// * `(Get (Arg ..) i)` records `i` in `used` and yields `Some(())`;
/// * a `Get` of anything else is analysed through its first child only;
/// * any other application is analysed through all of its children, stopping at
///   the first child that yields `None`.
///
/// Returns `None` as soon as an `Arg` is reached outside of a `Get` projection,
/// i.e. the argument is used directly (presumably this also covers a non-tuple
/// argument — confirm against callers). When `None` is returned, `used` may
/// already hold indices gathered before the direct use was found.
///
/// # Panics
///
/// Panics if a `Term::Var` is encountered, or if the index of a `Get` applied to
/// an `Arg` is not an integer literal.
fn get_arg_indices_used(&self, term: Term, used: &mut HashSet<usize>) -> Option<()> {
    let (op, operands) = match &term {
        Term::Lit(_) => return Some(()),
        Term::Var(_) => panic!("Found variable in term during extraction"),
        Term::App(head, children) => (head.to_string(), children),
    };

    match op.as_str() {
        // A bare Arg means the argument is consumed directly.
        "Arg" => None,
        "Get" => {
            let projected = self.termdag.get(operands[0]).clone();
            let projects_arg =
                matches!(&projected, Term::App(inner, _) if inner.to_string() == "Arg");
            if !projects_arg {
                // Not a projection of the argument itself: keep searching below it.
                return self.get_arg_indices_used(projected, used);
            }

            // Projection of the argument: only the index matters from here on.
            match self.termdag.get(operands[1]) {
                Term::Lit(Literal::Int(index)) => {
                    used.insert(*index as usize);
                    Some(())
                }
                other => panic!("Expected literal in Get index, got {:?}", other),
            }
        }
        // Any other operator: every child must be analysable, in order.
        _ => operands.iter().try_for_each(|operand| {
            self.get_arg_indices_used(self.termdag.get(*operand).clone(), used)
        }),
    }
}

/// Replaces the leaves of the model_term with children
/// Also adds to the `correspondence` map based on the model term.
fn build_concat(&mut self, model_term: Term, children: &Vec<Term>) -> (Term, usize) {
Expand Down Expand Up @@ -702,6 +667,7 @@ impl<'a> Extractor<'a> {

let mut shared_total = NotNan::new(0.).unwrap();
let mut unshared_total = info.cm.get_op_cost(&node.op);
let mut args_used = HashSet::new();

// special case: when the call is recursive, set super high cost
if node.op == "Call" {
Expand All @@ -722,63 +688,61 @@ impl<'a> Extractor<'a> {

if !info.cm.ignore_children(&node.op) {
for (child_set, enode_child) in child_cost_sets.iter() {
let mut add_to_shared = false;
if enode_child.is_subregion {
children_terms.push(child_set.term.clone());
let (mut new_child, should_add) = if enode_child.is_subregion {
unshared_total += self.subregion_cost(info, nodeid.clone(), child_set);
(child_set.term.clone(), false)
} else if enode_child.is_if_inputs {
// special case- try to only add cost for inputs that are used

// first, get all the indices of the children that are used
let mut used_children = HashSet::new();
let mut used_children: HashSet<usize> = HashSet::new();
for (child_set, enode_child) in child_cost_sets.iter() {
if enode_child.is_subregion {
if let Some(()) = self
.get_arg_indices_used(child_set.term.clone(), &mut used_children)
{
// keep going
} else {
add_to_shared = true;
}
used_children.extend(child_set.args_used.iter());
}
}

if !add_to_shared {
// now that we have which children are used, try to break up the inputs
if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) {
let mut new_input_children = vec![];
for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() {
let (child_term, net_cost) = self.add_term_to_cost_set(
info,
&mut costs,
input_tuple_term.clone(),
&child_set.costs,
!used_children.contains(&idx),
);
shared_total += net_cost;
new_input_children.push(child_term);
}
let (new_term, children_used) =
self.build_concat(child_set.term.clone(), &new_input_children);
assert_eq!(children_used, new_input_children.len());
children_terms.push(new_term);
} else {
add_to_shared = true;
// now that we have which children are used, try to break up the inputs
if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) {
let mut new_input_children = vec![];
for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() {
let (child_term, net_cost) = self.add_term_to_cost_set(
info,
&mut costs,
input_tuple_term.clone(),
&child_set.costs,
!used_children.contains(&idx),
);
shared_total += net_cost;
new_input_children.push(child_term);
}
let (new_term, children_used) =
self.build_concat(child_set.term.clone(), &new_input_children);
assert_eq!(children_used, new_input_children.len());
(new_term, false)
} else {
(child_set.term.clone(), true)
}
} else {
add_to_shared = true;
}
(child_set.term.clone(), true)
};

if add_to_shared {
let (child_term, net_cost) = self.add_term_to_cost_set(
if should_add {
let (new_new_child_term, net_cost) = self.add_term_to_cost_set(
info,
&mut costs,
child_set.term.clone(),
new_child.clone(),
&child_set.costs,
false,
);
shared_total += net_cost;
children_terms.push(child_term);
new_child = new_new_child_term;
}
children_terms.push(new_child);

// if it's not a subregion, add to args_used
if !enode_child.is_subregion {
args_used.extend(child_set.args_used.iter());
}
}
}
Expand All @@ -791,10 +755,49 @@ impl<'a> Extractor<'a> {
}
let total = unshared_total + shared_total;

// for an argument, add all indices
if node.op == "Arg" {
// first argument is type
let ty = self.typecheck_term(&term);
match ty {
Type::TupleT(base_types) => {
for i in 0..base_types.len() {
args_used.insert(i);
}
}
_ => (),
}
}

// for a get of an arg, clear args used except for the one used
if node.op == "Get" {
let arg_term = &child_cost_sets[0].0.term;
match arg_term {
Term::App(symbol, _items) => {
if symbol.to_string() == "Arg" {
let arg_index = child_cost_sets[1].0.term.clone();
match arg_index {
Term::Lit(Literal::Int(i)) => {
args_used.clear();
args_used.insert(i as usize);
}
_ => panic!("Unexpected term in Get: {:?}", arg_term),
}
}
}
_ => panic!("Unexpected term in Get: {:?}", arg_term),
}
}

// swap borrowed costsets back!
std::mem::swap(&mut self.costsets, &mut cost_sets_tmp);

self.costsets.push(CostSet { total, costs, term });
self.costsets.push(CostSet {
total,
costs,
term,
args_used,
});
let index = self.costsets.len() - 1;
self.costsetmemo
.insert((nodeid, child_cost_set_indicies), index);
Expand Down

0 comments on commit 64a1dbe

Please sign in to comment.