diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index 51a7a572d..19bf7a5ac 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -335,6 +335,7 @@ impl<'a> Extractor<'a> { costs: Default::default(), total: 0.0.try_into().unwrap(), term, + args_used: Default::default(), }; self.costsets.push(costset); self.costsetmemo @@ -409,6 +410,7 @@ pub struct CostSet { /// Maps classes to the chosen term for the eclass, /// along with the cost for that term (excluding child costs). pub costs: HashTrieMap, + pub args_used: HashSet, /// The resulting term pub term: Term, } @@ -488,10 +490,14 @@ impl<'a> Extractor<'a> { if let Some((existing_term, _existing_cost)) = current_costs.get(eclass) { (existing_term.clone(), NotNan::new(0.).unwrap()) } else { - let unshared_cost = match other_costs.get(eclass) { - Some((_, cost)) => *cost, - // no cost stored, so it's free - None => NotNan::new(0.).unwrap(), + let unshared_cost = if is_free { + NotNan::new(0.).unwrap() + } else { + match other_costs.get(eclass) { + Some((_, cost)) => *cost, + // no cost stored, so it's free + None => NotNan::new(0.).unwrap(), + } }; let mut cost = unshared_cost; @@ -573,47 +579,6 @@ impl<'a> Extractor<'a> { } } - /// Given a term, returns what indices of the argument are used in the term - /// Returns None if the type is not a tuple or arg is used directly - fn get_arg_indices_used(&self, term: Term, used: &mut HashSet) -> Option<()> { - match &term { - Term::App(head, children) => { - if head.to_string() == "Arg" { - None - } else if head.to_string() == "Get" { - // check if we are getting an arg - let child = self.termdag.get(children[0]).clone(); - match &child { - Term::App(head, _arg_children) => { - if head.to_string() == "Arg" { - // now we only care about the index - let Term::Lit(Literal::Int(lit)) = self.termdag.get(children[1]) - else { - panic!( - "Expected literal in Get index, got {:?}", - self.termdag.get(children[1]) - ); - }; - used.insert(*lit as usize); - Some(()) - } else { - self.get_arg_indices_used(child, used) - } - } - _ => self.get_arg_indices_used(child, used), - } - } else { - for child in children { - self.get_arg_indices_used(self.termdag.get(*child).clone(), used)?; - } - Some(()) - } - } - Term::Var(_) => panic!("Found variable in term during extraction"), - Term::Lit(_l) => Some(()), - } - } - /// Replaces the leafs of the model_term with children /// Also adds to the `correspondence` map based on the model term. fn build_concat(&mut self, model_term: Term, children: &Vec) -> (Term, usize) { @@ -702,6 +667,7 @@ impl<'a> Extractor<'a> { let mut shared_total = NotNan::new(0.).unwrap(); let mut unshared_total = info.cm.get_op_cost(&node.op); + let mut args_used = HashSet::new(); // special case: when the call is recursive, set super high cost if node.op == "Call" { @@ -722,63 +688,61 @@ impl<'a> Extractor<'a> { if !info.cm.ignore_children(&node.op) { for (child_set, enode_child) in child_cost_sets.iter() { - let mut add_to_shared = false; - if enode_child.is_subregion { - children_terms.push(child_set.term.clone()); + let (mut new_child, should_add) = if enode_child.is_subregion { unshared_total += self.subregion_cost(info, nodeid.clone(), child_set); + (child_set.term.clone(), false) } else if enode_child.is_if_inputs { // special case- try to only add cost for inputs that are used + // first, get all the indices of the children that are used - let mut used_children = HashSet::new(); + let mut used_children: HashSet = HashSet::new(); for (child_set, enode_child) in child_cost_sets.iter() { if enode_child.is_subregion { - if let Some(()) = self - .get_arg_indices_used(child_set.term.clone(), &mut used_children) - { - // keep going - } else { - add_to_shared = true; - } + used_children.extend(child_set.args_used.iter()); } } - if !add_to_shared { - // now that we have which children are used, try to break up the inputs - if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) { - let mut new_input_children = vec![]; - for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() { - let (child_term, net_cost) = self.add_term_to_cost_set( - info, - &mut costs, - input_tuple_term.clone(), - &child_set.costs, - !used_children.contains(&idx), - ); - shared_total += net_cost; - new_input_children.push(child_term); - } - let (new_term, children_used) = - self.build_concat(child_set.term.clone(), &new_input_children); - assert_eq!(children_used, new_input_children.len()); - children_terms.push(new_term); - } else { - add_to_shared = true; + // now that we have which children are used, try to break up the inputs + if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) { + let mut new_input_children = vec![]; + for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() { + let (child_term, net_cost) = self.add_term_to_cost_set( + info, + &mut costs, + input_tuple_term.clone(), + &child_set.costs, + !used_children.contains(&idx), + ); + shared_total += net_cost; + new_input_children.push(child_term); } + let (new_term, children_used) = + self.build_concat(child_set.term.clone(), &new_input_children); + assert_eq!(children_used, new_input_children.len()); + (new_term, false) + } else { + (child_set.term.clone(), true) } } else { - add_to_shared = true; - } + (child_set.term.clone(), true) + }; - if add_to_shared { - let (child_term, net_cost) = self.add_term_to_cost_set( + if should_add { + let (new_new_child_term, net_cost) = self.add_term_to_cost_set( info, &mut costs, - child_set.term.clone(), + new_child.clone(), &child_set.costs, false, ); shared_total += net_cost; - children_terms.push(child_term); + new_child = new_new_child_term; + } + children_terms.push(new_child); + + // if it's not a subregion, add to args_used + if !enode_child.is_subregion { + args_used.extend(child_set.args_used.iter()); } } } @@ -791,10 +755,49 @@ impl<'a> Extractor<'a> { } let total = unshared_total + shared_total; + // for an argument, add all indicies + if node.op == "Arg" { + // first argument is type + let ty = self.typecheck_term(&term); + match ty { + Type::TupleT(base_types) => { + for i in 0..base_types.len() { + args_used.insert(i); + } + } + _ => (), + } + } + + // for a get of an arg, clear args used except for the one used + if node.op == "Get" { + let arg_term = &child_cost_sets[0].0.term; + match arg_term { + Term::App(symbol, _items) => { + if symbol.to_string() == "Arg" { + let arg_index = child_cost_sets[1].0.term.clone(); + match arg_index { + Term::Lit(Literal::Int(i)) => { + args_used.clear(); + args_used.insert(i as usize); + } + _ => panic!("Unexpected term in Get: {:?}", arg_term), + } + } + } + _ => panic!("Unexpected term in Get: {:?}", arg_term), + } + } + // swap borrowed costsets back! std::mem::swap(&mut self.costsets, &mut cost_sets_tmp); - self.costsets.push(CostSet { total, costs, term }); + self.costsets.push(CostSet { + total, + costs, + term, + args_used, + }); let index = self.costsets.len() - 1; self.costsetmemo .insert((nodeid, child_cost_set_indicies), index);