diff --git a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs index 4e76d49299..ed3ebf3bbe 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_scorer.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_scorer.rs @@ -2,7 +2,7 @@ use crate::docset::{DocSet, TERMINATED}; use crate::fieldnorm::FieldNormReader; use crate::postings::Postings; use crate::query::bm25::Bm25Weight; -use crate::query::phrase_query::{intersection_count, PhraseScorer}; +use crate::query::phrase_query::{intersection_count, intersection_exists, PhraseScorer}; use crate::query::Scorer; use crate::{DocId, Score}; @@ -92,14 +92,17 @@ impl<TPostings: Postings> Scorer for PhraseKind<TPostings> { } } -pub struct PhrasePrefixScorer<TPostings: Postings> { +pub struct PhrasePrefixScorer<TPostings: Postings, const SCORING_ENABLED: bool> { phrase_scorer: PhraseKind<TPostings>, suffixes: Vec<TPostings>, suffix_offset: u32, phrase_count: u32, + suffix_position_buffer: Vec<u32>, } -impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { +impl<TPostings: Postings, const SCORING_ENABLED: bool> + PhrasePrefixScorer<TPostings, SCORING_ENABLED> +{ // If similarity_weight is None, then scoring is disabled. pub fn new( mut term_postings: Vec<(usize, TPostings)>, @@ -107,7 +110,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { fieldnorm_reader: FieldNormReader, suffixes: Vec<TPostings>, suffix_pos: usize, - ) -> PhrasePrefixScorer<TPostings> { + ) -> PhrasePrefixScorer<TPostings, SCORING_ENABLED> { // correct indices so we can merge with our suffix term the PhraseScorer doesn't know about let max_offset = term_postings .iter() @@ -140,6 +143,7 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { suffixes, suffix_offset: (max_offset - suffix_pos) as u32, phrase_count: 0, + suffix_position_buffer: Vec::with_capacity(100), }; if phrase_prefix_scorer.doc() != TERMINATED && !phrase_prefix_scorer.matches_prefix() { phrase_prefix_scorer.advance(); @@ -153,7 +157,6 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { fn matches_prefix(&mut self) -> bool { let mut count = 0; - let mut positions = Vec::new(); let current_doc = self.doc(); let pos_matching = self.phrase_scorer.get_intersection(); for suffix in &mut self.suffixes { @@ -162,16 +165,27 @@ impl<TPostings: Postings> PhrasePrefixScorer<TPostings> { } let doc = suffix.seek(current_doc); if doc == current_doc { - suffix.positions_with_offset(self.suffix_offset, &mut positions); - count += intersection_count(pos_matching, &positions); + suffix.positions_with_offset(self.suffix_offset, &mut self.suffix_position_buffer); + if SCORING_ENABLED { + count += intersection_count(pos_matching, &self.suffix_position_buffer); + } else { + if intersection_exists(pos_matching, &self.suffix_position_buffer) { + return true; + } + } } } + if !SCORING_ENABLED { + return false; + } self.phrase_count = count as u32; count != 0 } } -impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> { +impl<TPostings: Postings, const SCORING_ENABLED: bool> DocSet + for PhrasePrefixScorer<TPostings, SCORING_ENABLED> +{ fn advance(&mut self) -> DocId { loop { let doc = self.phrase_scorer.advance(); @@ -198,9 +212,15 @@ impl<TPostings: Postings> DocSet for PhrasePrefixScorer<TPostings> { } } -impl<TPostings: Postings> Scorer for PhrasePrefixScorer<TPostings> { +impl<TPostings: Postings, const SCORING_ENABLED: bool> Scorer + for PhrasePrefixScorer<TPostings, SCORING_ENABLED> +{ fn score(&mut self) -> Score { + if SCORING_ENABLED { + self.phrase_scorer.score() + } else { + 1.0f32 + } // TODO modify score?? - self.phrase_scorer.score() } } diff --git a/src/query/phrase_prefix_query/phrase_prefix_weight.rs b/src/query/phrase_prefix_query/phrase_prefix_weight.rs index 866c3c2c5a..22e4ae5f7a 100644 --- a/src/query/phrase_prefix_query/phrase_prefix_weight.rs +++ b/src/query/phrase_prefix_query/phrase_prefix_weight.rs @@ -42,11 +42,11 @@ impl PhrasePrefixWeight { Ok(FieldNormReader::constant(reader.max_doc(), 1)) } - pub(crate) fn phrase_scorer( + pub(crate) fn phrase_prefix_scorer<const SCORING_ENABLED: bool>( &self, reader: &SegmentReader, boost: Score, - ) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings>>> { + ) -> crate::Result<Option<PhrasePrefixScorer<SegmentPostings, SCORING_ENABLED>>> { let similarity_weight_opt = self .similarity_weight_opt .as_ref() @@ -128,15 +128,20 @@ impl PhrasePrefixWeight { impl Weight for PhrasePrefixWeight { fn scorer(&self, reader: &SegmentReader, boost: Score) -> crate::Result<Box<dyn Scorer>> { - if let Some(scorer) = self.phrase_scorer(reader, boost)? { - Ok(Box::new(scorer)) + if self.similarity_weight_opt.is_some() { + if let Some(scorer) = self.phrase_prefix_scorer::<true>(reader, boost)? { + return Ok(Box::new(scorer)); + } } else { - Ok(Box::new(EmptyScorer)) + if let Some(scorer) = self.phrase_prefix_scorer::<false>(reader, boost)? { + return Ok(Box::new(scorer)); + } } + Ok(Box::new(EmptyScorer)) } fn explain(&self, reader: &SegmentReader, doc: DocId) -> crate::Result<Explanation> { - let scorer_opt = self.phrase_scorer(reader, 1.0)?; + let scorer_opt = self.phrase_prefix_scorer::<true>(reader, 1.0)?; if scorer_opt.is_none() { return Err(does_not_match(doc)); } @@ -200,7 +205,7 @@ mod tests { .unwrap() .unwrap(); let mut phrase_scorer = phrase_weight - .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .phrase_prefix_scorer::<true>(searcher.segment_reader(0u32), 1.0)? .unwrap(); assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.phrase_count(), 2); @@ -211,6 +216,38 @@ mod tests { Ok(()) } + #[test] + pub fn test_phrase_no_count() -> crate::Result<()> { + let index = create_index(&[ + "aa bb dd cc", + "aa aa bb c dd aa bb cc aa bb dc", + " aa bb cd", + ])?; + let schema = index.schema(); + let text_field = schema.get_field("text").unwrap(); + let searcher = index.reader()?.searcher(); + let phrase_query = PhrasePrefixQuery::new(vec![ + Term::from_field_text(text_field, "aa"), + Term::from_field_text(text_field, "bb"), + Term::from_field_text(text_field, "c"), + ]); + let enable_scoring = EnableScoring::enabled_from_searcher(&searcher); + let phrase_weight = phrase_query + .phrase_prefix_query_weight(enable_scoring) + .unwrap() + .unwrap(); + let mut phrase_scorer = phrase_weight + .phrase_prefix_scorer::<false>(searcher.segment_reader(0u32), 1.0)? + .unwrap(); + assert_eq!(phrase_scorer.doc(), 1); + assert_eq!(phrase_scorer.phrase_count(), 0); + assert_eq!(phrase_scorer.advance(), 2); + assert_eq!(phrase_scorer.doc(), 2); + assert_eq!(phrase_scorer.phrase_count(), 0); + assert_eq!(phrase_scorer.advance(), TERMINATED); + Ok(()) + } + #[test] pub fn test_phrase_count_mid() -> crate::Result<()> { let index = create_index(&["aa dd cc", "aa aa bb c dd aa bb cc aa dc", " aa bb cd"])?; @@ -227,7 +264,7 @@ mod tests { .unwrap() .unwrap(); let mut phrase_scorer = phrase_weight - .phrase_scorer(searcher.segment_reader(0u32), 1.0)? + .phrase_prefix_scorer::<true>(searcher.segment_reader(0u32), 1.0)? .unwrap(); assert_eq!(phrase_scorer.doc(), 1); assert_eq!(phrase_scorer.phrase_count(), 2); diff --git a/src/query/phrase_query/mod.rs b/src/query/phrase_query/mod.rs index 7b8d3e0074..89b22cef1a 100644 --- a/src/query/phrase_query/mod.rs +++ b/src/query/phrase_query/mod.rs @@ -3,8 +3,8 @@ mod phrase_scorer; mod phrase_weight; pub use self::phrase_query::PhraseQuery; -pub(crate) use self::phrase_scorer::intersection_count; pub use self::phrase_scorer::PhraseScorer; +pub(crate) use self::phrase_scorer::{intersection_count, intersection_exists}; pub use self::phrase_weight::PhraseWeight; #[cfg(test)] diff --git a/src/query/phrase_query/phrase_scorer.rs b/src/query/phrase_query/phrase_scorer.rs index 147aef29b3..a8c4e4babd 100644 --- a/src/query/phrase_query/phrase_scorer.rs +++ b/src/query/phrase_query/phrase_scorer.rs @@ -58,7 +58,7 @@ pub struct PhraseScorer<TPostings: Postings> { } /// Returns true if and only if the two sorted arrays contain a common element -fn intersection_exists(left: &[u32], right: &[u32]) -> bool { +pub(crate) fn intersection_exists(left: &[u32], right: &[u32]) -> bool { let mut left_index = 0; let mut right_index = 0; while left_index < left.len() && right_index < right.len() {