Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding facet filter and search #21

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 111 additions & 4 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::query::Query;
use crate::{get_field, to_pyerr};
use pyo3::exceptions::ValueError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::PyObjectProtocol;
use std::collections::BTreeMap;
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};

Expand Down Expand Up @@ -45,6 +47,7 @@ impl ToPyObject for Fruit {
/// Object holding a results successful search.
pub(crate) struct SearchResult {
hits: Vec<(Fruit, DocAddress)>,
facets_result: BTreeMap<String, Vec<(String, u64)>>,
#[pyo3(get)]
/// How many documents matched the query. Only available if `count` was set
/// to true during the search.
Expand All @@ -56,11 +59,17 @@ impl PyObjectProtocol for SearchResult {
fn __repr__(&self) -> PyResult<String> {
if let Some(count) = self.count {
Ok(format!(
"SearchResult(hits: {:?}, count: {})",
self.hits, count
"SearchResult(hits: {:?}, count: {}, facets: {})",
self.hits,
count,
self.facets_result.len()
))
} else {
Ok(format!("SearchResult(hits: {:?})", self.hits))
Ok(format!(
"SearchResult(hits: {:?}, facets: {})",
self.hits,
self.facets_result.len()
))
}
}
}
Expand All @@ -78,6 +87,14 @@ impl SearchResult {
.collect();
Ok(ret)
}

#[getter]
bloodbare marked this conversation as resolved.
Show resolved Hide resolved
fn facets(
&self,
_py: Python,
) -> PyResult<BTreeMap<String, Vec<(String, u64)>>> {
Ok(self.facets_result.clone())
}
}

#[pymethods]
Expand All @@ -94,6 +111,8 @@ impl Searcher {
/// should be ordered by. The field must be declared as a fast field
/// when building the schema. Note, this only works for unsigned
/// fields.
/// facets (PyDict, optional): A dictionary of facet fields and keys to
/// filter.
///
/// Returns `SearchResult` object.
///
Expand All @@ -106,6 +125,7 @@ impl Searcher {
limit: usize,
count: bool,
order_by_field: Option<&str>,
facets: Option<&PyDict>,
) -> PyResult<SearchResult> {
let mut multicollector = MultiCollector::new();

Expand All @@ -115,6 +135,37 @@ impl Searcher {
None
};

let mut facets_requests = BTreeMap::new();

// We create facets collector for each field and terms defined on the facets args
if let Some(facets_dict) = facets {
for key_value_any in facets_dict.items() {
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
if key_value.len() != 2 {
continue;
}
let key: String = key_value.get_item(0).extract()?;
let field = get_field(&self.inner.index().schema(), &key)?;

let mut facet_collector =
tv::collector::FacetCollector::for_field(field);

if let Ok(value_list) =
key_value.get_item(1).downcast::<PyList>()
{
for value_element in value_list {
if let Ok(s) = value_element.extract::<String>() {
facet_collector.add_facet(&s);
}
}
let facet_handler =
multicollector.add_collector(facet_collector);
facets_requests.insert(key, facet_handler);
}
}
}
}

let (mut multifruit, hits) = {
if let Some(order_by) = order_by_field {
let field = get_field(&self.inner.index().schema(), order_by)?;
Expand Down Expand Up @@ -162,7 +213,52 @@ impl Searcher {
None => None,
};

Ok(SearchResult { hits, count })
let mut facets_result: BTreeMap<String, Vec<(String, u64)>> =
BTreeMap::new();

// Go though all collectors that are registered
for (key, facet_collector) in facets_requests {
let facet_count = facet_collector.extract(&mut multifruit);
let mut facet_vec = Vec::new();
if let Some(facets_dict) = facets {
match facets_dict.get_item(key.clone()) {
Some(facets_list_by_key) => {
if let Ok(facets_list_by_key_native) =
facets_list_by_key.downcast::<PyList>()
{
for facet_value in facets_list_by_key_native {
if let Ok(s) = facet_value.extract::<String>() {
let facet_value_vec: Vec<(
&tv::schema::Facet,
u64,
)> = facet_count.get(&s).collect();

// Go for all elements on facet and count to add on vector
for (
facet_value_vec_element,
facet_count,
) in facet_value_vec
{
facet_vec.push((
facet_value_vec_element.to_string(),
facet_count,
))
}
}
}
}
}
None => println!("Not found."),
}
}
facets_result.insert(key.clone(), facet_vec);
}

Ok(SearchResult {
hits,
count,
facets_result,
})
}

/// Returns the overall number of documents in the index.
Expand All @@ -171,6 +267,17 @@ impl Searcher {
self.inner.num_docs()
}

fn docn(&self, seg_doc: &PyTuple) -> PyResult<Document> {
bloodbare marked this conversation as resolved.
Show resolved Hide resolved
let seg: u32 = seg_doc.get_item(0).extract()?;
let doc: u32 = seg_doc.get_item(1).extract()?;
let address = tv::DocAddress(seg, doc);
let doc = self.inner.doc(address).map_err(to_pyerr)?;
let named_doc = self.inner.schema().to_named_doc(&doc);
Ok(Document {
field_values: named_doc.0,
})
}

/// Fetches a document from Tantivy's store given a DocAddress.
///
/// Args:
Expand Down
33 changes: 30 additions & 3 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def schema():
return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").build()
return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").add_facet_field("facet").build()

def create_index(dir=None):
# assume all tests will use the same documents for now
Expand All @@ -27,6 +27,7 @@ def create_index(dir=None):
"now without taking a fish."
),
)
doc.add_facet('facet', tantivy.Facet.from_string("/mytag"))
writer.add_document(doc)
# 2 use the built-in json support
# keys need to coincide with field names
Expand Down Expand Up @@ -117,14 +118,40 @@ def test_and_query_parser_default_fields(self, ram_index):
assert repr(query) == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))"""

def test_and_query_parser_default_fields_undefined(self, ram_index):
query = ram_index.parse_query("winter")
query = ram_index.parse_query("/winter")
assert (
repr(query) == "Query(BooleanQuery { subqueries: ["
"(Should, TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114])))] "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=2,bytes=[119, 105, 110, 116, 101, 114])))] "
"})"
)

def test_and_query_parser_default_fields_facets(self, ram_index):
index = ram_index
query = index.parse_query("old +facet:/mytag", default_field_names=["title", "body"])
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 1

query = index.parse_query("old +facet:/wrong", default_field_names=["title", "body"])
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 0

def test_search_facets(self, ram_index):
index = ram_index
query = index.parse_query("old", default_field_names=["title", "body"])
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10, facets={"facet": ["/"]})
assert result.count == 1
assert ('/mytag', 1) in result.facets['facet']
bloodbare marked this conversation as resolved.
Show resolved Hide resolved


bloodbare marked this conversation as resolved.
Show resolved Hide resolved

def test_query_errors(self, ram_index):
index = ram_index
# no "bod" field
Expand Down