From 4f51aeabdd178e0fcafef90278a2d1db51076636 Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Sun, 9 Jul 2023 09:22:44 +0200 Subject: [PATCH] [db] Fix edge_count condition stopping the search incorrectly #636 (#637) * Update search_where_test.rs * Update search_where_test.rs * Update query_condition.rs * Update db.rs --- src/agdb/db.rs | 32 ++++++++++---------- src/agdb/query/query_condition.rs | 49 +++++++++++++++++++------------ tests/search_where_test.rs | 40 +++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 35 deletions(-) diff --git a/src/agdb/db.rs b/src/agdb/db.rs index 98e7a0a8..15eabbd1 100644 --- a/src/agdb/db.rs +++ b/src/agdb/db.rs @@ -762,29 +762,29 @@ impl Db { condition: &QueryConditionData, ) -> Result { match condition { - QueryConditionData::Distance(value) => Ok(value.compare(distance)), + QueryConditionData::Distance(value) => Ok(value.compare_distance(distance)), QueryConditionData::Edge => Ok(SearchControl::Continue(index.is_edge())), - QueryConditionData::EdgeCount(value) => { - Ok(if let Some(node) = self.graph.node(index) { + QueryConditionData::EdgeCount(value) => Ok(SearchControl::Continue( + if let Some(node) = self.graph.node(index) { value.compare(node.edge_count()) } else { - SearchControl::Continue(false) - }) - } - QueryConditionData::EdgeCountFrom(value) => { - Ok(if let Some(node) = self.graph.node(index) { + false + }, + )), + QueryConditionData::EdgeCountFrom(value) => Ok(SearchControl::Continue( + if let Some(node) = self.graph.node(index) { value.compare(node.edge_count_from()) } else { - SearchControl::Continue(false) - }) - } - QueryConditionData::EdgeCountTo(value) => { - Ok(if let Some(node) = self.graph.node(index) { + false + }, + )), + QueryConditionData::EdgeCountTo(value) => Ok(SearchControl::Continue( + if let Some(node) = self.graph.node(index) { value.compare(node.edge_count_to()) } else { - SearchControl::Continue(false) - }) - } + false + }, + )), QueryConditionData::Ids(values) => { Ok(SearchControl::Continue(values.iter().any(|id| { index.0 diff --git a/src/agdb/query/query_condition.rs b/src/agdb/query/query_condition.rs index 09712a2b..4512b362 100644 --- a/src/agdb/query/query_condition.rs +++ b/src/agdb/query/query_condition.rs @@ -144,7 +144,7 @@ pub enum Comparison { } impl CountComparison { - pub(crate) fn compare(&self, right: u64) -> SearchControl { + pub(crate) fn compare_distance(&self, right: u64) -> SearchControl { match self { CountComparison::Equal(left) => match right.cmp(left) { std::cmp::Ordering::Less => SearchControl::Continue(false), @@ -182,6 +182,17 @@ impl CountComparison { }, } } + + pub(crate) fn compare(&self, left: u64) -> bool { + match self { + CountComparison::Equal(right) => left == *right, + CountComparison::GreaterThan(right) => left > *right, + CountComparison::GreaterThanOrEqual(right) => left >= *right, + CountComparison::LessThan(right) => left < *right, + CountComparison::LessThanOrEqual(right) => left <= *right, + CountComparison::NotEqual(right) => left != *right, + } + } } impl Comparison { @@ -271,23 +282,23 @@ mod tests { use SearchControl::Continue; use SearchControl::Stop; - assert_eq!(Equal(2).compare(3), Stop(false)); - assert_eq!(Equal(2).compare(2), Stop(true)); - assert_eq!(Equal(2).compare(1), Continue(false)); - assert_eq!(NotEqual(2).compare(3), Continue(true)); - assert_eq!(NotEqual(2).compare(2), Continue(false)); - assert_eq!(NotEqual(2).compare(1), Continue(true)); - assert_eq!(GreaterThan(2).compare(3), Continue(true)); - assert_eq!(GreaterThan(2).compare(2), Continue(false)); - assert_eq!(GreaterThan(2).compare(1), Continue(false)); - assert_eq!(GreaterThanOrEqual(2).compare(3), Continue(true)); - assert_eq!(GreaterThanOrEqual(2).compare(2), Continue(true)); - assert_eq!(GreaterThanOrEqual(2).compare(1), Continue(false)); - assert_eq!(LessThan(2).compare(3), Stop(false)); - assert_eq!(LessThan(2).compare(2), Stop(false)); - assert_eq!(LessThan(2).compare(1), Continue(true)); - assert_eq!(LessThanOrEqual(2).compare(3), Stop(false)); - assert_eq!(LessThanOrEqual(2).compare(2), Stop(true)); - assert_eq!(LessThanOrEqual(2).compare(1), Continue(true)); + assert_eq!(Equal(2).compare_distance(3), Stop(false)); + assert_eq!(Equal(2).compare_distance(2), Stop(true)); + assert_eq!(Equal(2).compare_distance(1), Continue(false)); + assert_eq!(NotEqual(2).compare_distance(3), Continue(true)); + assert_eq!(NotEqual(2).compare_distance(2), Continue(false)); + assert_eq!(NotEqual(2).compare_distance(1), Continue(true)); + assert_eq!(GreaterThan(2).compare_distance(3), Continue(true)); + assert_eq!(GreaterThan(2).compare_distance(2), Continue(false)); + assert_eq!(GreaterThan(2).compare_distance(1), Continue(false)); + assert_eq!(GreaterThanOrEqual(2).compare_distance(3), Continue(true)); + assert_eq!(GreaterThanOrEqual(2).compare_distance(2), Continue(true)); + assert_eq!(GreaterThanOrEqual(2).compare_distance(1), Continue(false)); + assert_eq!(LessThan(2).compare_distance(3), Stop(false)); + assert_eq!(LessThan(2).compare_distance(2), Stop(false)); + assert_eq!(LessThan(2).compare_distance(1), Continue(true)); + assert_eq!(LessThanOrEqual(2).compare_distance(3), Stop(false)); + assert_eq!(LessThanOrEqual(2).compare_distance(2), Stop(true)); + assert_eq!(LessThanOrEqual(2).compare_distance(1), Continue(true)); } } diff --git a/tests/search_where_test.rs b/tests/search_where_test.rs index b48e6558..64dbae07 100644 --- a/tests/search_where_test.rs +++ b/tests/search_where_test.rs @@ -231,6 +231,46 @@ fn search_from_where_edge_count_test() { .query(), &[8, 7, 6], ); + db.exec_ids( + QueryBuilder::search() + .from("root") + .where_() + .edge_count(CountComparison::GreaterThanOrEqual(2)) + .query(), + &[1, 3, 2, 8, 7, 6, 15, 14, 12], + ); + db.exec_ids( + QueryBuilder::search() + .from("root") + .where_() + .edge_count(CountComparison::LessThanOrEqual(2)) + .query(), + &[1, 8, 7, 6, 16, 15, 14, 13, 12], + ); + db.exec_ids( + QueryBuilder::search() + .from("root") + .where_() + .edge_count(CountComparison::LessThan(2)) + .query(), + &[16, 13], + ); + db.exec_ids( + QueryBuilder::search() + .from("root") + .where_() + .edge_count(CountComparison::Equal(2)) + .query(), + &[1, 8, 7, 6, 15, 14, 12], + ); + db.exec_ids( + QueryBuilder::search() + .from("root") + .where_() + .edge_count(CountComparison::NotEqual(2)) + .query(), + &[3, 2, 16, 13], + ); } #[test]