From 9ddddf84724800f0df922cbe956f72f6281c3396 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Pedro=20Henrique?= Date: Sat, 27 Jul 2024 01:46:07 -0300 Subject: [PATCH] perf: Remove expensive call to find matching in children (#66) During the merge process, given two nodes, left and right, we need to iterate over their children to determine how to merge them. In each iteration, we try to find a match for the current left child among the right children and vice versa. Previously, we were examining all captured matches across the entire program tree, which was highly inefficient. This PR changes the implementation to only consider the matches captured in the other node's children, significantly reducing the number of iterations required. Initial observations in empirical scenarios showed a performance boost of nearly 3x. --- matching/src/matchings.rs | 29 ++++++++++++++++++----------- merge/src/ordered_merge.rs | 10 ++++++---- merge/src/unordered_merge.rs | 8 +++++--- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/matching/src/matchings.rs b/matching/src/matchings.rs index e100740..cf90f81 100644 --- a/matching/src/matchings.rs +++ b/matching/src/matchings.rs @@ -19,9 +19,9 @@ impl<'a> Matchings<'a> { } pub fn from_single(key: UnorderedPair<&'a CSTNode>, value: MatchingEntry) -> Self { - let mut matching_entries = HashMap::new(); - matching_entries.insert(key, value); - Matchings { matching_entries } + Matchings { + matching_entries: HashMap::from([(key, value)]), + } } pub fn new(matching_entries: HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry>) -> Self { @@ -48,20 +48,27 @@ impl<'a> Matchings<'a> { }) } - pub fn get_matching_entry( + pub fn find_matching_node_in_children( &'a self, - left: &'a CSTNode<'a>, - right: &'a CSTNode<'a>, - ) -> Option<&MatchingEntry> { - self.matching_entries.get(&UnorderedPair(left, right)) + a_node: &'a CSTNode<'a>, + children: &'a [CSTNode<'a>], + ) -> Option<Matching> { + children.iter().find_map(|child| { + let matching_entry = 
self.get_matching_entry(child, a_node)?; + Some(Matching { + matching_node: child, + score: matching_entry.score, + is_perfect_match: matching_entry.is_perfect_match, + }) + }) } - pub fn has_bidirectional_matching( + pub fn get_matching_entry( &'a self, left: &'a CSTNode<'a>, right: &'a CSTNode<'a>, - ) -> bool { - self.find_matching_for(left).is_some() && self.find_matching_for(right).is_some() + ) -> Option<&MatchingEntry> { + self.matching_entries.get(&UnorderedPair(left, right)) } pub fn extend(&mut self, matchings: Matchings<'a>) { diff --git a/merge/src/ordered_merge.rs b/merge/src/ordered_merge.rs index 32b9b88..ea9ece4 100644 --- a/merge/src/ordered_merge.rs +++ b/merge/src/ordered_merge.rs @@ -30,10 +30,12 @@ pub fn ordered_merge<'a>( while let (Some(cur_left), Some(cur_right)) = (cur_left_option, cur_right_option) { let matching_base_left = base_left_matchings.find_matching_for(cur_left); let matching_base_right = base_right_matchings.find_matching_for(cur_right); - let left_matching_in_right = left_right_matchings.find_matching_for(cur_left); - let right_matching_in_left = left_right_matchings.find_matching_for(cur_right); + let left_matching_in_right = + left_right_matchings.find_matching_node_in_children(cur_left, right.get_children()); + let right_matching_in_left = + left_right_matchings.find_matching_node_in_children(cur_right, left.get_children()); let has_bidirectional_matching_left_right = - left_right_matchings.has_bidirectional_matching(cur_left, cur_right); + left_matching_in_right.is_some() && right_matching_in_left.is_some(); match ( has_bidirectional_matching_left_right, @@ -213,7 +215,7 @@ mod tests { use std::{borrow::Cow, vec}; use matching::{ordered, Matchings}; - use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Language, Point}; + use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Point}; use crate::{MergeError, MergedCSTNode}; diff --git a/merge/src/unordered_merge.rs b/merge/src/unordered_merge.rs index 
19cd895..8b7facd 100644 --- a/merge/src/unordered_merge.rs +++ b/merge/src/unordered_merge.rs @@ -39,7 +39,8 @@ pub fn unordered_merge<'a>( } let matching_base_left = base_left_matchings.find_matching_for(left_child); - let matching_left_right = left_right_matchings.find_matching_for(left_child); + let matching_left_right = + left_right_matchings.find_matching_node_in_children(left_child, right.get_children()); match (matching_base_left, matching_left_right) { // Added only by left @@ -91,7 +92,8 @@ pub fn unordered_merge<'a>( .filter(|node| !processed_nodes.contains(&node.id())) { let matching_base_right = base_right_matchings.find_matching_for(right_child); - let matching_left_right = left_right_matchings.find_matching_for(right_child); + let matching_left_right = + left_right_matchings.find_matching_node_in_children(right_child, left.get_children()); match (matching_base_right, matching_left_right) { // Added only by right @@ -142,7 +144,7 @@ mod tests { use matching::{unordered::calculate_matchings, Matchings}; use model::{ cst_node::{NonTerminal, Terminal}, - CSTNode, Language, Point, + CSTNode, Point, }; use crate::{MergeError, MergedCSTNode};