Skip to content

Commit

Permalink
refactor(matching,merge): Use ref to store CSTNode in the hashmap key
Browse files Browse the repository at this point in the history
  • Loading branch information
jpedroh committed Oct 28, 2023
1 parent ea34157 commit df47e45
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 42 deletions.
5 changes: 4 additions & 1 deletion matching/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub use matching_entry::MatchingEntry;
pub use matchings::Matchings;
pub use ordered_tree_matching::ordered_tree_matching;

pub fn calculate_matchings<'a>(left: &'a model::CSTNode, right: &'a model::CSTNode) -> Matchings<'a> {
pub fn calculate_matchings<'a>(
left: &'a model::CSTNode,
right: &'a model::CSTNode,
) -> Matchings<'a> {
return ordered_tree_matching::ordered_tree_matching(left, right);
}
21 changes: 9 additions & 12 deletions matching/src/matchings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::matching_entry::MatchingEntry;

#[derive(Debug, Clone)]
pub struct Matchings<'a> {
matching_entries: HashMap<UnorderedPair<CSTNode<'a>>, MatchingEntry>,
matching_entries: HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry>,
}

impl<'a> Matchings<'a> {
Expand All @@ -18,16 +18,16 @@ impl<'a> Matchings<'a> {
}
}

pub fn new(matching_entries: HashMap<UnorderedPair<CSTNode<'a>>, MatchingEntry>) -> Self {
pub fn new(matching_entries: HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry>) -> Self {
Matchings { matching_entries }
}

pub fn find_matching_for(&self, a_node: &'a CSTNode) -> Option<Matching> {
self.matching_entries
.iter()
.find(|(UnorderedPair(left, right), ..)| left == a_node || right == a_node)
.find(|(UnorderedPair(left, right), ..)| left == &a_node || right == &a_node)
.map(|(UnorderedPair(left, right), matching)| {
let matching_node = if left == a_node { right } else { left };
let matching_node = if left == &a_node { right } else { left };
Matching {
matching_node,
score: matching.score,
Expand All @@ -38,8 +38,8 @@ impl<'a> Matchings<'a> {

pub fn get_matching_entry(
&'a self,
left: CSTNode<'a>,
right: CSTNode<'a>,
left: &'a CSTNode<'a>,
right: &'a CSTNode<'a>,
) -> Option<&MatchingEntry> {
self.matching_entries.get(&UnorderedPair(left, right))
}
Expand All @@ -52,7 +52,7 @@ mod tests {
#[test]
fn returns_none_if_a_matching_for_the_node_is_not_found() {
let a_node = CSTNode::Terminal {
kind: "kind".into(),
kind: "kind",
value: "value".into(),
};

Expand All @@ -62,15 +62,12 @@ mod tests {
#[test]
fn returns_some_match_if_a_matching_for_the_node_is_found() {
let a_node = CSTNode::Terminal {
kind: "kind".into(),
kind: "kind",
value: "value".into(),
};

let mut matchings = HashMap::new();
matchings.insert(
UnorderedPair(a_node.clone(), a_node.clone()),
MatchingEntry::new(1, true),
);
matchings.insert(UnorderedPair(&a_node, &a_node), MatchingEntry::new(1, true));

assert_eq!(
Some(Matching {
Expand Down
45 changes: 20 additions & 25 deletions matching/src/ordered_tree_matching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ enum Direction {
#[derive(Clone)]
struct Entry<'a>(
pub Direction,
pub HashMap<UnorderedPair<CSTNode<'a>>, MatchingEntry>,
pub HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry>,
);

impl<'a> Default for Entry<'a> {
Expand All @@ -29,7 +29,7 @@ pub fn ordered_tree_matching<'a>(left: &'a CSTNode, right: &'a CSTNode) -> Match
fn ordered_tree_matching_helper<'a>(
left: &'a CSTNode,
right: &'a CSTNode,
) -> HashMap<UnorderedPair<CSTNode<'a>>, MatchingEntry> {
) -> HashMap<UnorderedPair<&'a CSTNode<'a>>, MatchingEntry> {
match (left, right) {
(
CSTNode::NonTerminal {
Expand All @@ -55,12 +55,7 @@ fn ordered_tree_matching_helper<'a>(
let right_child = children_right.get(j - 1).unwrap();

let w = ordered_tree_matching_helper(left_child, right_child);
let matching = w
.get(&UnorderedPair::new(
left_child.to_owned(),
right_child.to_owned(),
))
.unwrap();
let matching = w.get(&UnorderedPair::new(left_child, right_child)).unwrap();

if matrix_m[i][j - 1] > matrix_m[i - 1][j] {
if matrix_m[i][j - 1] > matrix_m[i - 1][j - 1] + matching.score {
Expand All @@ -84,7 +79,7 @@ fn ordered_tree_matching_helper<'a>(

let mut i = m;
let mut j = n;
let mut children = Vec::<&HashMap<UnorderedPair<CSTNode>, MatchingEntry>>::new();
let mut children = Vec::<&HashMap<UnorderedPair<&'a CSTNode>, MatchingEntry>>::new();

while i >= 1 && j >= 1 {
match matrix_t.get(i).unwrap().get(j).unwrap().0 {
Expand All @@ -102,10 +97,7 @@ fn ordered_tree_matching_helper<'a>(

let matching = MatchingEntry::new(matrix_m[m][n] + root_matching, left == right);
let mut result = HashMap::new();
result.insert(
UnorderedPair::new(left.to_owned(), right.to_owned()),
matching,
);
result.insert(UnorderedPair::new(left, right), matching);
children.into_iter().for_each(|child_matchings| {
child_matchings.iter().for_each(|(key, matching)| {
result.insert(key.to_owned(), matching.to_owned());
Expand All @@ -126,15 +118,15 @@ fn ordered_tree_matching_helper<'a>(
let mut result = HashMap::new();
let is_perfetch_match = kind_left == kind_right && value_left == value_right;
result.insert(
UnorderedPair::new(left.to_owned(), right.to_owned()),
UnorderedPair::new(left, right),
MatchingEntry::new(is_perfetch_match.into(), is_perfetch_match),
);
result
}
(_, _) => {
let mut result = HashMap::new();
result.insert(
UnorderedPair::new(left.to_owned(), right.to_owned()),
UnorderedPair::new(left, right),
MatchingEntry::new(0, false),
);
result
Expand Down Expand Up @@ -162,7 +154,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(1, true)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand All @@ -181,7 +173,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(0, false)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand All @@ -200,7 +192,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(0, false)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand All @@ -219,7 +211,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(0, false)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand All @@ -242,7 +234,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(1, true)),
matchings.get_matching_entry(child.clone(), child)
matchings.get_matching_entry(&child, &child)
)
}

Expand All @@ -268,7 +260,10 @@ mod tests {

let matchings = ordered_tree_matching(&left, &right);

assert_eq!(None, matchings.get_matching_entry(left_child, right_child))
assert_eq!(
None,
matchings.get_matching_entry(&left_child, &right_child)
)
}

#[test]
Expand All @@ -295,7 +290,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(2, false)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand All @@ -319,7 +314,7 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(2, true)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}

Expand Down Expand Up @@ -348,12 +343,12 @@ mod tests {

assert_eq!(
Some(&MatchingEntry::new(2, true)),
matchings.get_matching_entry(intermediate.clone(), intermediate)
matchings.get_matching_entry(&intermediate, &intermediate)
);

assert_eq!(
Some(&MatchingEntry::new(3, true)),
matchings.get_matching_entry(left.clone(), right.clone())
matchings.get_matching_entry(&left, &right)
)
}
}
6 changes: 2 additions & 4 deletions merge/src/odered_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ pub fn ordered_merge<'a>(
let matching_base_left = base_left_matchings.find_matching_for(cur_left.unwrap());
let matching_base_right =
base_right_matchings.find_matching_for(cur_right.unwrap());
let bidirectional_matching_left_right = left_right_matchings.get_matching_entry(
cur_left.unwrap().to_owned(),
cur_right.unwrap().to_owned(),
);
let bidirectional_matching_left_right =
left_right_matchings.get_matching_entry(cur_left.unwrap(), cur_right.unwrap());
let left_matching_in_right =
left_right_matchings.find_matching_for(cur_left.unwrap());
let right_matching_in_left =
Expand Down

0 comments on commit df47e45

Please sign in to comment.