Skip to content

Commit

Permalink
Add field_boosts and fuzzy_fields optional parameters to Index::parse…
Browse files Browse the repository at this point in the history
…_query to expose this QueryParser functionality.
  • Loading branch information
adamreichold committed Feb 5, 2024
1 parent 472bad8 commit 806ce03
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 64 deletions.
161 changes: 97 additions & 64 deletions src/index.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(clippy::new_ret_no_self)]

use std::collections::HashMap;

use pyo3::{exceptions, prelude::*, types::PyAny};

use crate::{
Expand Down Expand Up @@ -358,44 +360,33 @@ impl Index {
///
/// Args:
/// query: the query, following the tantivy query language.
///
/// default_fields_names (List[Field]): A list of fields used to search if no
/// field is specified in the query.
///
#[pyo3(signature = (query, default_field_names = None))]
/// field_boosts: A dictionary keyed on field names which provides default boosts
/// for the query constructed by this method.
///
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
/// triples making queries constructed by this method fuzzy against the given fields
/// and using the given parameters.
/// `prefix` determines if terms which are prefixes of the given term match the query.
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
#[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
pub fn parse_query(
&self,
query: &str,
default_field_names: Option<Vec<String>>,
field_boosts: HashMap<String, tv::Score>,
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
) -> PyResult<Query> {
let mut default_fields = vec![];
let schema = self.index.schema();
if let Some(default_field_names_vec) = default_field_names {
for default_field_name in &default_field_names_vec {
if let Ok(field) = schema.get_field(default_field_name) {
let field_entry = schema.get_field_entry(field);
if !field_entry.is_indexed() {
return Err(exceptions::PyValueError::new_err(
format!(
"Field `{default_field_name}` is not set as indexed in the schema."
),
));
}
default_fields.push(field);
} else {
return Err(exceptions::PyValueError::new_err(format!(
"Field `{default_field_name}` is not defined in the schema."
)));
}
}
} else {
for (field, field_entry) in self.index.schema().fields() {
if field_entry.is_indexed() {
default_fields.push(field);
}
}
}
let parser =
tv::query::QueryParser::for_index(&self.index, default_fields);
let parser = self.prepare_query_parser(
default_field_names,
field_boosts,
fuzzy_fields,
)?;

let query = parser.parse_query(query).map_err(to_pyerr)?;

Ok(Query { inner: query })
Expand All @@ -410,64 +401,106 @@ impl Index {
///
/// Args:
/// query: the query, following the tantivy query language.
///
/// default_fields_names (List[Field]): A list of fields used to search if no
/// field is specified in the query.
///
/// field_boosts: A dictionary keyed on field names which provides default boosts
/// for the query constructed by this method.
///
/// fuzzy_fields: A dictionary keyed on field names which provides (prefix, distance, transpose_cost_one)
/// triples making queries constructed by this method fuzzy against the given fields
/// and using the given parameters.
/// `prefix` determines if terms which are prefixes of the given term match the query.
/// `distance` determines the maximum Levenshtein distance between terms matching the query and the given term.
/// `transpose_cost_one` determines if transpositions of neighbouring characters are counted only once against the Levenshtein distance.
///
/// Returns a tuple containing the parsed query and a list of errors.
///
/// Raises ValueError if a field in `default_field_names` is not defined or marked as indexed.
#[pyo3(signature = (query, default_field_names = None))]
#[pyo3(signature = (query, default_field_names = None, field_boosts = HashMap::new(), fuzzy_fields = HashMap::new()))]
pub fn parse_query_lenient(
&self,
query: &str,
default_field_names: Option<Vec<String>>,
field_boosts: HashMap<String, tv::Score>,
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
py: Python,
) -> PyResult<(Query, Vec<PyObject>)> {
let parser = self.prepare_query_parser(
default_field_names,
field_boosts,
fuzzy_fields,
)?;

let (query, errors) = parser.parse_query_lenient(query);
let errors = errors.into_iter().map(|err| err.into_py(py)).collect();

Ok((Query { inner: query }, errors))
}
}

impl Index {
fn prepare_query_parser(
&self,
default_field_names: Option<Vec<String>>,
field_boosts: HashMap<String, tv::Score>,
fuzzy_fields: HashMap<String, (bool, u8, bool)>,
) -> PyResult<tv::query::QueryParser> {
let schema = self.index.schema();

let default_fields = if let Some(default_field_names_vec) =
let default_fields = if let Some(default_field_names) =
default_field_names
{
default_field_names_vec
.iter()
.map(|field_name| {
schema
.get_field(field_name)
.map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
})
.and_then(|field| {
schema.get_field_entry(field).is_indexed().then_some(field).ok_or(
exceptions::PyValueError::new_err(
format!(
"Field `{field_name}` is not set as indexed in the schema."
),
))
})
}).collect::<Result<Vec<_>, _>>()?
default_field_names.iter().map(|field_name| {
let field = schema.get_field(field_name).map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
})?;

let field_entry = schema.get_field_entry(field);
if !field_entry.is_indexed() {
return Err(exceptions::PyValueError::new_err(
format!("Field `{field_name}` is not set as indexed in the schema.")
));
}

Ok(field)
}).collect::<PyResult<_>>()?
} else {
self.index
.schema()
schema
.fields()
.filter_map(|(f, fe)| fe.is_indexed().then_some(f))
.collect::<Vec<_>>()
.filter(|(_, field_entry)| field_entry.is_indexed())
.map(|(field, _)| field)
.collect()
};

let parser =
let mut parser =
tv::query::QueryParser::for_index(&self.index, default_fields);
let (query, errors) = parser.parse_query_lenient(query);

Python::with_gil(|py| {
let errors =
errors.into_iter().map(|err| err.into_py(py)).collect();
for (field_name, boost) in field_boosts {
let field = schema.get_field(&field_name).map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
})?;
parser.set_field_boost(field, boost);
}

Ok((Query { inner: query }, errors))
})
for (field_name, (prefix, distance, transpose_cost_one)) in fuzzy_fields
{
let field = schema.get_field(&field_name).map_err(|_err| {
exceptions::PyValueError::new_err(format!(
"Field `{field_name}` is not defined in the schema."
))
})?;
parser.set_field_fuzzy(field, prefix, distance, transpose_cost_one);
}

Ok(parser)
}
}

impl Index {
fn register_custom_text_analyzers(index: &tv::Index) {
let analyzers = [
("ar_stem", Language::Arabic),
Expand Down
14 changes: 14 additions & 0 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ def test_and_query_parser_default_fields_undefined(self, ram_index):
== """Query(BooleanQuery { subqueries: [(Should, TermQuery(Term(field=0, type=Str, "winter"))), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
)

def test_parse_query_field_boosts(self, ram_index):
query = ram_index.parse_query("winter", field_boosts={"title": 2.3})
assert (
repr(query)
== """Query(BooleanQuery { subqueries: [(Should, Boost(query=TermQuery(Term(field=0, type=Str, "winter")), boost=2.3)), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
)

def test_parse_query_field_boosts(self, ram_index):
query = ram_index.parse_query("winter", fuzzy_fields={"title": (True, 1, False)})
assert (
repr(query)
== """Query(BooleanQuery { subqueries: [(Should, FuzzyTermQuery { term: Term(field=0, type=Str, "winter"), distance: 1, transposition_cost_one: false, prefix: true }), (Should, TermQuery(Term(field=1, type=Str, "winter")))] })"""
)

def test_query_errors(self, ram_index):
index = ram_index
# no "bod" field
Expand Down

0 comments on commit 806ce03

Please sign in to comment.