diff --git a/matching/src/matchings.rs b/matching/src/matchings.rs index e100740..cf90f81 100644 --- a/matching/src/matchings.rs +++ b/matching/src/matchings.rs @@ -19,9 +19,9 @@ impl<'a> Matchings<'a> { } pub fn from_single(key: UnorderedPair<&'a CSTNode>, value: MatchingEntry) -> Self { - let mut matching_entries = HashMap::new(); - matching_entries.insert(key, value); - Matchings { matching_entries } + Matchings { + matching_entries: HashMap::from([(key, value)]), + } } pub fn new(matching_entries: HashMap>, MatchingEntry>) -> Self { @@ -48,20 +48,27 @@ impl<'a> Matchings<'a> { }) } - pub fn get_matching_entry( + pub fn find_matching_node_in_children( &'a self, - left: &'a CSTNode<'a>, - right: &'a CSTNode<'a>, - ) -> Option<&MatchingEntry> { - self.matching_entries.get(&UnorderedPair(left, right)) + a_node: &'a CSTNode<'a>, + children: &'a [CSTNode<'a>], + ) -> Option { + children.iter().find_map(|child| { + let matching_entry = self.get_matching_entry(child, a_node)?; + Some(Matching { + matching_node: child, + score: matching_entry.score, + is_perfect_match: matching_entry.is_perfect_match, + }) + }) } - pub fn has_bidirectional_matching( + pub fn get_matching_entry( &'a self, left: &'a CSTNode<'a>, right: &'a CSTNode<'a>, - ) -> bool { - self.find_matching_for(left).is_some() && self.find_matching_for(right).is_some() + ) -> Option<&MatchingEntry> { + self.matching_entries.get(&UnorderedPair(left, right)) } pub fn extend(&mut self, matchings: Matchings<'a>) { diff --git a/merge/src/ordered_merge.rs b/merge/src/ordered_merge.rs index 32b9b88..ea9ece4 100644 --- a/merge/src/ordered_merge.rs +++ b/merge/src/ordered_merge.rs @@ -30,10 +30,12 @@ pub fn ordered_merge<'a>( while let (Some(cur_left), Some(cur_right)) = (cur_left_option, cur_right_option) { let matching_base_left = base_left_matchings.find_matching_for(cur_left); let matching_base_right = base_right_matchings.find_matching_for(cur_right); - let left_matching_in_right = left_right_matchings.find_matching_for(cur_left); - let right_matching_in_left = left_right_matchings.find_matching_for(cur_right); + let left_matching_in_right = + left_right_matchings.find_matching_node_in_children(cur_left, right.get_children()); + let right_matching_in_left = + left_right_matchings.find_matching_node_in_children(cur_right, left.get_children()); let has_bidirectional_matching_left_right = - left_right_matchings.has_bidirectional_matching(cur_left, cur_right); + left_matching_in_right.is_some() && right_matching_in_left.is_some(); match ( has_bidirectional_matching_left_right, @@ -213,7 +215,7 @@ mod tests { use std::{borrow::Cow, vec}; use matching::{ordered, Matchings}; - use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Language, Point}; + use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Point}; use crate::{MergeError, MergedCSTNode}; diff --git a/merge/src/unordered_merge.rs b/merge/src/unordered_merge.rs index 19cd895..8b7facd 100644 --- a/merge/src/unordered_merge.rs +++ b/merge/src/unordered_merge.rs @@ -39,7 +39,8 @@ pub fn unordered_merge<'a>( } let matching_base_left = base_left_matchings.find_matching_for(left_child); - let matching_left_right = left_right_matchings.find_matching_for(left_child); + let matching_left_right = + left_right_matchings.find_matching_node_in_children(left_child, right.get_children()); match (matching_base_left, matching_left_right) { // Added only by left @@ -91,7 +92,8 @@ pub fn unordered_merge<'a>( .filter(|node| !processed_nodes.contains(&node.id())) { let matching_base_right = base_right_matchings.find_matching_for(right_child); - let matching_left_right = left_right_matchings.find_matching_for(right_child); + let matching_left_right = + left_right_matchings.find_matching_node_in_children(right_child, left.get_children()); match (matching_base_right, matching_left_right) { // Added only by right @@ -142,7 +144,7 @@ mod tests { use matching::{unordered::calculate_matchings, Matchings}; use model::{ cst_node::{NonTerminal, Terminal}, - CSTNode, Language, Point, + CSTNode, Point, }; use crate::{MergeError, MergedCSTNode};