From 39687cefff7b3139f33ae201347ba49618646541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Pedro=20Henrique?= Date: Sat, 15 Jun 2024 16:54:53 -0300 Subject: [PATCH] fix(matching): Incorrect calculation of perfect node matching (#58) Previously, we were only running an equality check between both nodes contents. Now, we rely on the property that two nodes are equal if their matching score equals the sum of their trees' sizes. This makes the check more accurate and more aligned with the strategies we're using in the tool. --- Cargo.lock | 1 + matching/Cargo.toml | 1 + matching/src/lib.rs | 23 ++++---- matching/src/matching_entry.rs | 6 ++- matching/src/matchings.rs | 5 +- matching/src/ordered/mod.rs | 53 ++++++++----------- matching/src/unordered/assignment_problem.rs | 5 +- matching/src/unordered/unique_label.rs | 5 +- matching/tests/perfect_matching.rs | 55 ++++++++++++++++++++ model/src/cst_node.rs | 16 ++++++ 10 files changed, 115 insertions(+), 55 deletions(-) create mode 100644 matching/tests/perfect_matching.rs diff --git a/Cargo.lock b/Cargo.lock index 12af7ea..46ab217 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,7 @@ dependencies = [ "log", "matching_handlers", "model", + "parsing", "pathfinding", "unordered-pair", "uuid", diff --git a/matching/Cargo.toml b/matching/Cargo.toml index 8058d77..6f00aa7 100644 --- a/matching/Cargo.toml +++ b/matching/Cargo.toml @@ -14,3 +14,4 @@ pathfinding = "4.9.1" [dev-dependencies] uuid = { workspace = true } +parsing = { path = "../parsing" } diff --git a/matching/src/lib.rs b/matching/src/lib.rs index bda7aa1..6aa2842 100644 --- a/matching/src/lib.rs +++ b/matching/src/lib.rs @@ -47,21 +47,18 @@ pub fn calculate_matchings<'a>( let is_perfect_match = kind_left == kind_right && value_left == value_right; Matchings::from_single( UnorderedPair(left, right), - MatchingEntry::new(is_perfect_match.into(), is_perfect_match), + MatchingEntry::new(left, right, is_perfect_match.into()), ) } - (_, _) => Matchings::from_single(UnorderedPair(left, right), MatchingEntry::new(0, false)), + (_, _) => Matchings::empty(), } } #[cfg(test)] mod tests { + use crate::{calculate_matchings, matching_configuration::MatchingConfiguration}; use model::{cst_node::Terminal, CSTNode, Point}; - use crate::{ - calculate_matchings, matching_configuration::MatchingConfiguration, MatchingEntry, - }; - #[test] fn two_terminal_nodes_matches_with_a_score_of_one_if_they_have_the_same_kind_and_value() { let left = CSTNode::Terminal(Terminal { @@ -84,10 +81,9 @@ mod tests { let matching_configuration = MatchingConfiguration::default(); let matchings = calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(1, true)), - matchings.get_matching_entry(&left, &right) - ) + let left_right_matching = matchings.get_matching_entry(&left, &right).unwrap(); + assert_eq!(1, left_right_matching.score); + assert!(left_right_matching.is_perfect_match); } #[test] @@ -112,9 +108,8 @@ mod tests { let matching_configuration = MatchingConfiguration::default(); let matchings = calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(0, false)), - matchings.get_matching_entry(&left, &right) - ) + let left_right_matching = matchings.get_matching_entry(&left, &right).unwrap(); + assert_eq!(0, left_right_matching.score); + assert!(!left_right_matching.is_perfect_match); } } diff --git a/matching/src/matching_entry.rs b/matching/src/matching_entry.rs index a2757e8..1f06b34 100644 --- a/matching/src/matching_entry.rs +++ b/matching/src/matching_entry.rs @@ -1,3 +1,5 @@ +use model::CSTNode; + #[derive(Clone, Debug, PartialEq, Eq)] pub struct MatchingEntry { pub score: usize, @@ -5,10 +7,10 @@ pub struct MatchingEntry { } impl MatchingEntry { - pub fn new(score: usize, is_perfect_match: bool) -> Self { + pub fn new(left: &CSTNode, right: &CSTNode, score: usize) -> Self { MatchingEntry { score, - is_perfect_match, + is_perfect_match: (2 * score) == (left.get_tree_size() + right.get_tree_size()), } } } diff --git a/matching/src/matchings.rs b/matching/src/matchings.rs index 2435fc0..e100740 100644 --- a/matching/src/matchings.rs +++ b/matching/src/matchings.rs @@ -118,7 +118,10 @@ mod tests { }); let mut matchings = HashMap::new(); - matchings.insert(UnorderedPair(&a_node, &a_node), MatchingEntry::new(1, true)); + matchings.insert( + UnorderedPair(&a_node, &a_node), + MatchingEntry::new(&a_node, &a_node, 1), + ); assert_eq!( Some(Matching { diff --git a/matching/src/ordered/mod.rs b/matching/src/ordered/mod.rs index b4b8810..cbe27bf 100644 --- a/matching/src/ordered/mod.rs +++ b/matching/src/ordered/mod.rs @@ -80,10 +80,7 @@ pub fn calculate_matchings<'a>( let mut matchings = Matchings::from_single( UnorderedPair(left, right), - MatchingEntry::new( - matrix_m[m][n] + root_matching, - left.contents() == right.contents(), - ), + MatchingEntry::new(left, right, matrix_m[m][n] + root_matching), ); while i >= 1 && j >= 1 { @@ -108,7 +105,7 @@ pub fn calculate_matchings<'a>( #[cfg(test)] mod tests { - use crate::{matching_entry::MatchingEntry, *}; + use crate::MatchingConfiguration; use model::{ cst_node::{NonTerminal, Terminal}, language, CSTNode, Language, Point, @@ -144,10 +141,10 @@ mod tests { let matching_configuration = MatchingConfiguration::default(); let matchings = super::calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(1, true)), - matchings.get_matching_entry(&child, &child) - ) + let child_matching = matchings.get_matching_entry(&child, &child); + assert!(child_matching.is_some()); + assert_eq!(1, child_matching.unwrap().score); + assert!(child_matching.unwrap().is_perfect_match) } #[test] @@ -188,11 +185,9 @@ mod tests { let matching_configuration = MatchingConfiguration::from(Language::Java); let matchings = super::calculate_matchings(&left, &right, &matching_configuration); - - assert_eq!( - None, - matchings.get_matching_entry(&left_child, &right_child) - ) + assert!(matchings + .get_matching_entry(&left_child, &right_child) + .is_none()) } #[test] @@ -234,10 +229,9 @@ mod tests { let matching_configuration = MatchingConfiguration::from(language::Language::Java); let matchings = super::calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(2, false)), - matchings.get_matching_entry(&left, &right) - ) + let left_right_matchings = matchings.get_matching_entry(&left, &right).unwrap(); + assert_eq!(2, left_right_matchings.score); + assert!(!left_right_matchings.is_perfect_match); } #[test] @@ -271,10 +265,9 @@ mod tests { let matching_configuration = MatchingConfiguration::from(language::Language::Java); let matchings = super::calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(2, true)), - matchings.get_matching_entry(&left, &right) - ) + let left_right_matchings = matchings.get_matching_entry(&left, &right).unwrap(); + assert_eq!(2, left_right_matchings.score); + assert!(left_right_matchings.is_perfect_match); } #[test] @@ -317,14 +310,14 @@ mod tests { let matching_configuration = MatchingConfiguration::default(); let matchings = super::calculate_matchings(&left, &right, &matching_configuration); - assert_eq!( - Some(&MatchingEntry::new(2, true)), - matchings.get_matching_entry(&intermediate, &intermediate) - ); + let intermediate_matching = matchings + .get_matching_entry(&intermediate, &intermediate) + .unwrap(); + assert_eq!(2, intermediate_matching.score); + assert!(intermediate_matching.is_perfect_match); - assert_eq!( - Some(&MatchingEntry::new(3, true)), - matchings.get_matching_entry(&left, &right) - ) + let left_right_matching = matchings.get_matching_entry(&left, &right).unwrap(); + assert_eq!(3, left_right_matching.score); + assert!(left_right_matching.is_perfect_match); } } diff --git a/matching/src/unordered/assignment_problem.rs b/matching/src/unordered/assignment_problem.rs index bba1ec0..eb0d973 100644 --- a/matching/src/unordered/assignment_problem.rs +++ b/matching/src/unordered/assignment_problem.rs @@ -84,10 +84,7 @@ fn solve_assignment_problem<'a>( result.extend(Matchings::from_single( UnorderedPair(left, right), - MatchingEntry { - score: max_matching as usize + 1, - is_perfect_match: left.contents() == right.contents(), - }, + MatchingEntry::new(left, right, max_matching as usize + 1), )); result diff --git a/matching/src/unordered/unique_label.rs b/matching/src/unordered/unique_label.rs index cafc136..5c961f8 100644 --- a/matching/src/unordered/unique_label.rs +++ b/matching/src/unordered/unique_label.rs @@ -51,10 +51,7 @@ pub fn calculate_matchings<'a>( result.extend(Matchings::from_single( UnorderedPair(left, right), - MatchingEntry { - score: sum + root_matching, - is_perfect_match: left.contents() == right.contents(), - }, + MatchingEntry::new(left, right, sum + root_matching), )); result diff --git a/matching/tests/perfect_matching.rs b/matching/tests/perfect_matching.rs new file mode 100644 index 0000000..4bffd4d --- /dev/null +++ b/matching/tests/perfect_matching.rs @@ -0,0 +1,55 @@ +use matching::matching_configuration::MatchingConfiguration; +use model::language::Language; +use parsing::ParserConfiguration; + +#[test] +fn the_perfect_matching_calculation_is_correct() -> Result<(), Box> { + let config = ParserConfiguration::from(Language::Java); + let left = parsing::parse_string( + r#""" + public class Main { + static { + int x = 2; + } + + public static void main() { + int a = 0; + } + + public static void teste() { + + } + } + """#, + &config, + )?; + + let right = parsing::parse_string( + r#""" + public class Main { + public static void teste() { + + } + static { + int x = 2; + } + + public static void main() { + int a = 0; + + } + } + """#, + &config, + )?; + + let matching_configuration = MatchingConfiguration::from(Language::Java); + let matchings = matching::calculate_matchings(&left, &right, &matching_configuration); + assert!( + matchings + .get_matching_entry(&left, &right) + .unwrap() + .is_perfect_match + ); + Ok(()) +} diff --git a/model/src/cst_node.rs b/model/src/cst_node.rs index e8c99f5..b615fb7 100644 --- a/model/src/cst_node.rs +++ b/model/src/cst_node.rs @@ -48,6 +48,22 @@ impl CSTNode<'_> { CSTNode::NonTerminal(node) => node.end_position, } } + + fn get_subtree_size(&self) -> usize { + match self { + CSTNode::Terminal(_) => 0, + CSTNode::NonTerminal(node) => node + .children + .iter() + .fold(node.children.len(), |acc, child| { + acc + child.get_subtree_size() + }), + } + } + + pub fn get_tree_size(&self) -> usize { + self.get_subtree_size() + 1 + } } #[derive(Debug, Default, Clone)]