diff --git a/matching/src/matchings.rs b/matching/src/matchings.rs index cf90f81..840f98d 100644 --- a/matching/src/matchings.rs +++ b/matching/src/matchings.rs @@ -8,58 +8,47 @@ use crate::matching_entry::MatchingEntry; #[derive(Debug, Clone)] pub struct Matchings<'a> { - pub matching_entries: HashMap>, MatchingEntry>, + matching_entries: HashMap>, MatchingEntry>, + individual_matchings: HashMap<&'a CSTNode<'a>, &'a CSTNode<'a>>, } impl<'a> Matchings<'a> { pub fn empty() -> Self { Matchings { matching_entries: HashMap::new(), + individual_matchings: HashMap::new(), } } pub fn from_single(key: UnorderedPair<&'a CSTNode>, value: MatchingEntry) -> Self { Matchings { matching_entries: HashMap::from([(key, value)]), + individual_matchings: HashMap::from([(key.0, key.1), (key.1, key.0)]), } } pub fn new(matching_entries: HashMap>, MatchingEntry>) -> Self { - Matchings { matching_entries } + Matchings { + individual_matchings: { + matching_entries + .keys() + .into_iter() + .flat_map(|key| [(key.0, key.1), (key.1, key.0)]) + .collect::, &'a CSTNode<'a>>>() + }, + matching_entries, + } } pub fn find_matching_for(&self, a_node: &'a CSTNode) -> Option { - self.matching_entries - .iter() - .find(|(UnorderedPair(left, right), ..)| { - left.id() == a_node.id() || right.id() == a_node.id() - }) - .map(|(UnorderedPair(left, right), matching)| { - let matching_node = if left.id() == a_node.id() { - right - } else { - left - }; - Matching { - matching_node, - score: matching.score, - is_perfect_match: matching.is_perfect_match, - } - }) - } - - pub fn find_matching_node_in_children( - &'a self, - a_node: &'a CSTNode<'a>, - children: &'a [CSTNode<'a>], - ) -> Option { - children.iter().find_map(|child| { - let matching_entry = self.get_matching_entry(child, a_node)?; - Some(Matching { - matching_node: child, - score: matching_entry.score, - is_perfect_match: matching_entry.is_perfect_match, - }) + let matching_node = self.individual_matchings.get(a_node)?; + let matching_entry = self + .matching_entries + .get(&UnorderedPair(a_node, matching_node))?; + Some(Matching { + matching_node, + score: matching_entry.score, + is_perfect_match: matching_entry.is_perfect_match, }) } @@ -72,6 +61,14 @@ impl<'a> Matchings<'a> { } pub fn extend(&mut self, matchings: Matchings<'a>) { + self.individual_matchings.extend( + matchings + .matching_entries + .keys() + .into_iter() + .flat_map(|key| [(key.0, key.1), (key.1, key.0)]) + .collect::, &'a CSTNode<'a>>>(), + ); self.matching_entries.extend(matchings); } } diff --git a/merge/src/ordered_merge.rs b/merge/src/ordered_merge.rs index ea9ece4..30a8606 100644 --- a/merge/src/ordered_merge.rs +++ b/merge/src/ordered_merge.rs @@ -30,10 +30,8 @@ pub fn ordered_merge<'a>( while let (Some(cur_left), Some(cur_right)) = (cur_left_option, cur_right_option) { let matching_base_left = base_left_matchings.find_matching_for(cur_left); let matching_base_right = base_right_matchings.find_matching_for(cur_right); - let left_matching_in_right = - left_right_matchings.find_matching_node_in_children(cur_left, right.get_children()); - let right_matching_in_left = - left_right_matchings.find_matching_node_in_children(cur_right, left.get_children()); + let left_matching_in_right = left_right_matchings.find_matching_for(cur_left); + let right_matching_in_left = left_right_matchings.find_matching_for(cur_right); let has_bidirectional_matching_left_right = left_matching_in_right.is_some() && right_matching_in_left.is_some(); diff --git a/merge/src/unordered_merge.rs b/merge/src/unordered_merge.rs index 8b7facd..95c77c6 100644 --- a/merge/src/unordered_merge.rs +++ b/merge/src/unordered_merge.rs @@ -39,8 +39,7 @@ pub fn unordered_merge<'a>( } let matching_base_left = base_left_matchings.find_matching_for(left_child); - let matching_left_right = - left_right_matchings.find_matching_node_in_children(left_child, right.get_children()); + let matching_left_right = left_right_matchings.find_matching_for(left_child); match (matching_base_left, matching_left_right) { // Added only by left @@ -92,8 +91,7 @@ pub fn unordered_merge<'a>( .filter(|node| !processed_nodes.contains(&node.id())) { let matching_base_right = base_right_matchings.find_matching_for(right_child); - let matching_left_right = - left_right_matchings.find_matching_node_in_children(right_child, left.get_children()); + let matching_left_right = left_right_matchings.find_matching_for(right_child); match (matching_base_right, matching_left_right) { // Added only by right