Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding facet filter and search #21

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ __pycache__/
tantivy.so
tantivy/tantivy.cpython*.so
tantivy.egg-info/
.python-version
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tantivy"
version = "0.13.2"
version = "0.14.0"
Copy link
Contributor

@poljar poljar Apr 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the tantivy upgrade should be a separate PR, or is there something this PR needs from 0.14?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using this branch on prod; I can do a new PR for the 0.14 bump if that's better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes please, separate PR for the 0.14 bump would be nice.

readme = "README.md"
authors = ["Damir Jelić <[email protected]>"]
edition = "2018"
Expand All @@ -12,7 +12,7 @@ crate-type = ["cdylib"]

[dependencies]
chrono = "0.4.19"
tantivy = "0.13.2"
tantivy = "0.14"
itertools = "0.9.0"
futures = "0.3.5"

Expand Down
2 changes: 1 addition & 1 deletion src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl Index {
#[staticmethod]
fn exists(path: &str) -> PyResult<bool> {
let directory = MmapDirectory::open(path).map_err(to_pyerr)?;
Ok(tv::Index::exists(&directory))
Ok(tv::Index::exists(&directory).unwrap())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should panic for IO errors.

}

/// The schema of the current index.
Expand Down
25 changes: 23 additions & 2 deletions src/schemabuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,18 @@ impl SchemaBuilder {
///
/// Args:
/// name (str): The name of the field.
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
fn add_bytes_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;
let opts = SchemaBuilder::build_bytes_option(stored, indexed, fast)?;

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name);
builder.add_bytes_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
Expand Down Expand Up @@ -316,4 +323,18 @@ impl SchemaBuilder {

Ok(opts)
}

/// Assemble the `BytesOptions` for a bytes field from three boolean flags.
///
/// Starts from `BytesOptions::default()` and switches on the matching
/// option for every flag that is `true`:
///
/// * `stored`  - apply `set_stored()`.
/// * `indexed` - apply `set_indexed()`.
/// * `fast`    - apply `set_fast()`.
///
/// This function has no failure path; it always returns `Ok`.
fn build_bytes_option(
    stored: bool,
    indexed: bool,
    fast: bool,
) -> PyResult<schema::BytesOptions> {
    let mut options = schema::BytesOptions::default();

    if stored {
        options = options.set_stored();
    }
    if indexed {
        options = options.set_indexed();
    }
    if fast {
        options = options.set_fast();
    }

    Ok(options)
}
}
106 changes: 102 additions & 4 deletions src/searcher.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![allow(clippy::new_ret_no_self)]

use crate::{document::Document, get_field, query::Query, to_pyerr};
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::{exceptions::PyValueError, prelude::*, PyObjectProtocol};
use std::collections::BTreeMap;
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};

Expand Down Expand Up @@ -41,6 +43,7 @@ impl ToPyObject for Fruit {
/// Object holding the results of a successful search.
pub(crate) struct SearchResult {
hits: Vec<(Fruit, DocAddress)>,
facets_result: BTreeMap<String, Vec<(String, u64)>>,
#[pyo3(get)]
/// How many documents matched the query. Only available if `count` was set
/// to true during the search.
Expand All @@ -52,11 +55,17 @@ impl PyObjectProtocol for SearchResult {
fn __repr__(&self) -> PyResult<String> {
if let Some(count) = self.count {
Ok(format!(
"SearchResult(hits: {:?}, count: {})",
self.hits, count
"SearchResult(hits: {:?}, count: {}, facets: {})",
self.hits,
count,
self.facets_result.len()
))
} else {
Ok(format!("SearchResult(hits: {:?})", self.hits))
Ok(format!(
"SearchResult(hits: {:?}, facets: {})",
self.hits,
self.facets_result.len()
))
}
}
}
Expand All @@ -74,6 +83,16 @@ impl SearchResult {
.collect();
Ok(ret)
}

#[getter]
bloodbare marked this conversation as resolved.
Show resolved Hide resolved
/// Facet counts gathered for the facets requested in the search.
///
/// The map is keyed by facet field name; each entry lists
/// `(facet path, count)` pairs. A copy of the stored map is returned.
fn facets(
    &self,
    _py: Python,
) -> PyResult<BTreeMap<String, Vec<(String, u64)>>> {
    let counts_by_field = self.facets_result.clone();
    Ok(counts_by_field)
}
}

#[pymethods]
Expand All @@ -90,6 +109,8 @@ impl Searcher {
/// should be ordered by. The field must be declared as a fast field
/// when building the schema. Note, this only works for unsigned
/// fields.
/// facets (PyDict, optional): A dictionary of facet fields and keys to
/// filter.
/// offset (int, optional): The offset from which the results have
/// to be returned.
///
Expand All @@ -104,6 +125,7 @@ impl Searcher {
limit: usize,
count: bool,
order_by_field: Option<&str>,
facets: Option<&PyDict>,
offset: usize,
) -> PyResult<SearchResult> {
let mut multicollector = MultiCollector::new();
Expand All @@ -114,6 +136,37 @@ impl Searcher {
None
};

let mut facets_requests = BTreeMap::new();

// Create one facet collector per field in the `facets` argument and register that field's terms on it
if let Some(facets_dict) = facets {
for key_value_any in facets_dict.items() {
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
if key_value.len() != 2 {
continue;
}
let key: String = key_value.get_item(0).extract()?;
let field = get_field(&self.inner.index().schema(), &key)?;

let mut facet_collector =
tv::collector::FacetCollector::for_field(field);

if let Ok(value_list) =
key_value.get_item(1).downcast::<PyList>()
{
for value_element in value_list {
if let Ok(s) = value_element.extract::<String>() {
facet_collector.add_facet(&s);
}
}
let facet_handler =
multicollector.add_collector(facet_collector);
facets_requests.insert(key, facet_handler);
}
}
}
}

let (mut multifruit, hits) = {
if let Some(order_by) = order_by_field {
let field = get_field(&self.inner.index().schema(), order_by)?;
Expand Down Expand Up @@ -162,7 +215,52 @@ impl Searcher {
None => None,
};

Ok(SearchResult { hits, count })
let mut facets_result: BTreeMap<String, Vec<(String, u64)>> =
BTreeMap::new();

// Go through all collectors that are registered
for (key, facet_collector) in facets_requests {
let facet_count = facet_collector.extract(&mut multifruit);
let mut facet_vec = Vec::new();
if let Some(facets_dict) = facets {
match facets_dict.get_item(key.clone()) {
Some(facets_list_by_key) => {
if let Ok(facets_list_by_key_native) =
facets_list_by_key.downcast::<PyList>()
{
for facet_value in facets_list_by_key_native {
if let Ok(s) = facet_value.extract::<String>() {
let facet_value_vec: Vec<(
&tv::schema::Facet,
u64,
)> = facet_count.get(&s).collect();

// Add every (facet, count) pair found under this facet to the result vector
for (
facet_value_vec_element,
facet_count,
) in facet_value_vec
{
facet_vec.push((
facet_value_vec_element.to_string(),
facet_count,
))
}
}
}
}
}
None => println!("Not found."),
}
}
facets_result.insert(key.clone(), facet_vec);
}

Ok(SearchResult {
hits,
count,
facets_result,
})
}

/// Returns the overall number of documents in the index.
Expand Down
65 changes: 54 additions & 11 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@


def schema():
return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").build()
return (
SchemaBuilder()
.add_text_field("title", stored=True)
.add_text_field("body")
.add_facet_field("facet")
.build()
)


def create_index(dir=None):
# assume all tests will use the same documents for now
Expand All @@ -27,6 +34,7 @@ def create_index(dir=None):
"now without taking a fish."
),
)
doc.add_facet("facet", tantivy.Facet.from_string("/mytag"))
writer.add_document(doc)
# 2 use the built-in json support
# keys need to coincide with field names
Expand Down Expand Up @@ -99,7 +107,9 @@ def test_simple_search_in_ram(self, ram_index):

def test_and_query(self, ram_index):
index = ram_index
query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"])
query = index.parse_query(
"title:men AND body:summer", default_field_names=["title", "body"]
)
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
Expand All @@ -114,27 +124,60 @@ def test_and_query(self, ram_index):

def test_and_query_parser_default_fields(self, ram_index):
query = ram_index.parse_query("winter", default_field_names=["title"])
assert repr(query) == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))"""
assert (
repr(query)
== """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))"""
)

def test_and_query_parser_default_fields_undefined(self, ram_index):
query = ram_index.parse_query("winter")
query = ram_index.parse_query("/winter")
assert (
repr(query) == "Query(BooleanQuery { subqueries: ["
"(Should, TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114])))] "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=2,bytes=[119, 105, 110, 116, 101, 114])))] "
"})"
)

def test_and_query_parser_default_fields_facets(self, ram_index):
index = ram_index
query = index.parse_query(
"old +facet:/mytag", default_field_names=["title", "body"]
)
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 1

query = index.parse_query(
"old +facet:/wrong", default_field_names=["title", "body"]
)
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 0

def test_search_facets(self, ram_index):
index = ram_index
query = index.parse_query("old", default_field_names=["title", "body"])
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10, facets={"facet": ["/"]})
assert result.count == 1
assert ("/mytag", 1) in result.facets["facet"]

def test_query_errors(self, ram_index):
index = ram_index
# no "bod" field
with pytest.raises(ValueError):
index.parse_query("bod:men", ["title", "body"])

def test_order_by_search(self):
schema = (SchemaBuilder()
schema = (
SchemaBuilder()
.add_unsigned_field("order", fast="single")
.add_text_field("title", stored=True).build()
.add_text_field("title", stored=True)
.build()
)

index = Index(schema)
Expand All @@ -155,15 +198,13 @@ def test_order_by_search(self):
doc.add_unsigned("order", 1)
doc.add_text("title", "Another test title")


writer.add_document(doc)

writer.commit()
index.reload()

query = index.parse_query("test")


searcher = index.searcher()

result = searcher.search(query, 10, offset=2, order_by_field="order")
Expand All @@ -187,9 +228,11 @@ def test_order_by_search(self):
assert searched_doc["title"] == ["Test title"]

def test_order_by_search_without_fast_field(self):
schema = (SchemaBuilder()
schema = (
SchemaBuilder()
.add_unsigned_field("order")
.add_text_field("title", stored=True).build()
.add_text_field("title", stored=True)
.build()
)

index = Index(schema)
Expand Down