Skip to content

Commit

Permalink
perf: Remove expensive call to find matching in children (#66)
Browse files Browse the repository at this point in the history
During the merge process, given two nodes, left and right, we need to
iterate over their children to determine how to merge them. In each
iteration, we try to find a match for the current left child among the
right children and vice versa. Previously, we were examining all
captured matches across the entire program tree, which was highly
inefficient.

This PR changes the implementation to only consider the matches captured
in the other node's children, significantly reducing the number of
iterations required. Initial empirical observations showed a performance
boost of nearly 3x.
  • Loading branch information
jpedroh authored Jul 27, 2024
1 parent 2d6141e commit 9ddddf8
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
29 changes: 18 additions & 11 deletions matching/src/matchings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ impl<'a> Matchings<'a> {
}

pub fn from_single(key: UnorderedPair<&'a CSTNode>, value: MatchingEntry) -> Self {
let mut matching_entries = HashMap::new();
matching_entries.insert(key, value);
Matchings { matching_entries }
Matchings {
matching_entries: HashMap::from([(key, value)]),
}
}

pub fn new(matching_entries: HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry>) -> Self {
Expand All @@ -48,20 +48,27 @@ impl<'a> Matchings<'a> {
})
}

pub fn get_matching_entry(
pub fn find_matching_node_in_children(
&'a self,
left: &'a CSTNode<'a>,
right: &'a CSTNode<'a>,
) -> Option<&MatchingEntry> {
self.matching_entries.get(&UnorderedPair(left, right))
a_node: &'a CSTNode<'a>,
children: &'a [CSTNode<'a>],
) -> Option<Matching> {
children.iter().find_map(|child| {
let matching_entry = self.get_matching_entry(child, a_node)?;
Some(Matching {
matching_node: child,
score: matching_entry.score,
is_perfect_match: matching_entry.is_perfect_match,
})
})
}

pub fn has_bidirectional_matching(
pub fn get_matching_entry(
&'a self,
left: &'a CSTNode<'a>,
right: &'a CSTNode<'a>,
) -> bool {
self.find_matching_for(left).is_some() && self.find_matching_for(right).is_some()
) -> Option<&MatchingEntry> {
self.matching_entries.get(&UnorderedPair(left, right))
}

pub fn extend(&mut self, matchings: Matchings<'a>) {
Expand Down
10 changes: 6 additions & 4 deletions merge/src/ordered_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ pub fn ordered_merge<'a>(
while let (Some(cur_left), Some(cur_right)) = (cur_left_option, cur_right_option) {
let matching_base_left = base_left_matchings.find_matching_for(cur_left);
let matching_base_right = base_right_matchings.find_matching_for(cur_right);
let left_matching_in_right = left_right_matchings.find_matching_for(cur_left);
let right_matching_in_left = left_right_matchings.find_matching_for(cur_right);
let left_matching_in_right =
left_right_matchings.find_matching_node_in_children(cur_left, right.get_children());
let right_matching_in_left =
left_right_matchings.find_matching_node_in_children(cur_right, left.get_children());
let has_bidirectional_matching_left_right =
left_right_matchings.has_bidirectional_matching(cur_left, cur_right);
left_matching_in_right.is_some() && right_matching_in_left.is_some();

match (
has_bidirectional_matching_left_right,
Expand Down Expand Up @@ -213,7 +215,7 @@ mod tests {
use std::{borrow::Cow, vec};

use matching::{ordered, Matchings};
use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Language, Point};
use model::{cst_node::NonTerminal, cst_node::Terminal, CSTNode, Point};

use crate::{MergeError, MergedCSTNode};

Expand Down
8 changes: 5 additions & 3 deletions merge/src/unordered_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ pub fn unordered_merge<'a>(
}

let matching_base_left = base_left_matchings.find_matching_for(left_child);
let matching_left_right = left_right_matchings.find_matching_for(left_child);
let matching_left_right =
left_right_matchings.find_matching_node_in_children(left_child, right.get_children());

match (matching_base_left, matching_left_right) {
// Added only by left
Expand Down Expand Up @@ -91,7 +92,8 @@ pub fn unordered_merge<'a>(
.filter(|node| !processed_nodes.contains(&node.id()))
{
let matching_base_right = base_right_matchings.find_matching_for(right_child);
let matching_left_right = left_right_matchings.find_matching_for(right_child);
let matching_left_right =
left_right_matchings.find_matching_node_in_children(right_child, left.get_children());

match (matching_base_right, matching_left_right) {
// Added only by right
Expand Down Expand Up @@ -142,7 +144,7 @@ mod tests {
use matching::{unordered::calculate_matchings, Matchings};
use model::{
cst_node::{NonTerminal, Terminal},
CSTNode, Language, Point,
CSTNode, Point,
};

use crate::{MergeError, MergedCSTNode};
Expand Down

0 comments on commit 9ddddf8

Please sign in to comment.