Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding facet filter and search #21

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::PyAny;
use pyo3::types::{PyAny, PyDict, PyList, PyTuple};

use crate::document::{extract_value, Document};
use crate::query::Query;
Expand Down Expand Up @@ -305,6 +305,7 @@ impl Index {
&self,
query: &str,
default_field_names: Option<Vec<String>>,
filters: Option<&PyDict>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be part of the QueryParser?
It feels like the query a query with (+yourfillter +<parsed_query>) could be built outside of this function.

I agree an helper would be nice, but this can be in external to the queyr parser I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree that may not be needed, its easier as API and I had problems parsing fields that are not facets with a facets schema:

running this test:

    def test_and_query_parser_default_fields_undefined(self, ram_index):
        query = ram_index.parse_query("winter")

gives this tb:

tests/tantivy_test.py thread '<unnamed>' panicked at 'assertion failed: path.starts_with('/')', /Users/ramon/.cargo/registry/src/github.com-1ecc6299db9ec823/tantivy-0.12.0/src/schema/facet.rs:147:9
stack backtrace:
   0: <std::sys_common::backtrace::_print::DisplayBacktrace as core::fmt::Display>::fmt
   1: core::fmt::write
   2: std::io::Write::write_fmt
   3: std::panicking::default_hook::{{closure}}
   4: std::panicking::default_hook
   5: std::panicking::rust_panic_with_hook
   6: std::panicking::begin_panic
   7: <tantivy::schema::facet::Facet as core::convert::From<&T>>::from
   8: tantivy::schema::facet::Facet::from_text
   9: tantivy::query::query_parser::query_parser::QueryParser::compute_terms_for_string
  10: tantivy::query::query_parser::query_parser::QueryParser::compute_logical_ast_for_leaf
  11: tantivy::query::query_parser::query_parser::QueryParser::compute_logical_ast_from_leaf
  12: tantivy::query::query_parser::query_parser::QueryParser::compute_logical_ast_with_occur
  13: tantivy::query::query_parser::query_parser::QueryParser::compute_logical_ast
  14: tantivy::query::query_parser::query_parser::QueryParser::parse_query_to_logical_ast
  15: tantivy::query::query_parser::query_parser::QueryParser::parse_query
  16: tantivy::index::Index::parse_query
  17: tantivy::index::__init12805926066545533911::__init12805926066545533911::__wrap::{{closure}}
  18: tantivy::index::__init12805926066545533911::__init12805926066545533911::__wrap
  19: _PyMethodDef_RawFastCallKeywords
  20: _PyMethodDescr_FastCallKeywords
  21: call_function
  22: _PyEval_EvalFrameDefault
  23: _PyEval_EvalCodeWithName
  24: _PyFunction_FastCallDict
  25: _PyObject_Call_Prepend
  26: PyObject_Call
  27: _PyEval_EvalFrameDefault
  28: _PyEval_EvalCodeWithName
  29: _PyFunction_FastCallDict
  30: _PyEval_EvalFrameDefault
  31: _PyEval_EvalCodeWithName
  32: _PyFunction_FastCallKeywords
  33: call_function
  34: _PyEval_EvalFrameDefault
  35: _PyEval_EvalCodeWithName
  36: _PyFunction_FastCallKeywords
  37: call_function
  38: _PyEval_EvalFrameDefault
  39: function_code_fastcall
  40: call_function
  41: _PyEval_EvalFrameDefault
  42: _PyEval_EvalCodeWithName
  43: _PyFunction_FastCallDict
  44: _PyObject_Call_Prepend
  45: slot_tp_call
  46: _PyObject_FastCallKeywords
  47: call_function
  48: _PyEval_EvalFrameDefault
  49: function_code_fastcall
  50: call_function
  51: _PyEval_EvalFrameDefault
  52: function_code_fastcall
  53: _PyEval_EvalFrameDefault
  54: _PyEval_EvalCodeWithName
  55: _PyFunction_FastCallKeywords
  56: call_function
  57: _PyEval_EvalFrameDefault
  58: _PyEval_EvalCodeWithName
  59: _PyFunction_FastCallKeywords
  60: call_function
  61: _PyEval_EvalFrameDefault
  62: function_code_fastcall
  63: call_function
  64: _PyEval_EvalFrameDefault
  65: _PyEval_EvalCodeWithName
  66: _PyFunction_FastCallDict
  67: _PyObject_Call_Prepend
  68: slot_tp_call
  69: PyObject_Call
  70: _PyEval_EvalFrameDefault
  71: _PyEval_EvalCodeWithName
  72: _PyFunction_FastCallKeywords
  73: call_function
  74: _PyEval_EvalFrameDefault
  75: _PyEval_EvalCodeWithName
  76: _PyFunction_FastCallKeywords
  77: call_function
  78: _PyEval_EvalFrameDefault
  79: _PyEval_EvalCodeWithName
  80: _PyFunction_FastCallDict
  81: _PyEval_EvalFrameDefault
  82: _PyEval_EvalCodeWithName
  83: _PyFunction_FastCallKeywords
  84: call_function
  85: _PyEval_EvalFrameDefault
  86: _PyEval_EvalCodeWithName
  87: _PyFunction_FastCallKeywords
  88: call_function
  89: _PyEval_EvalFrameDefault
  90: function_code_fastcall
  91: _PyEval_EvalFrameDefault
  92: _PyEval_EvalCodeWithName
  93: _PyFunction_FastCallKeywords
  94: call_function
  95: _PyEval_EvalFrameDefault
  96: _PyEval_EvalCodeWithName
  97: _PyFunction_FastCallKeywords
  98: call_function
  99: _PyEval_EvalFrameDefault
 100: function_code_fastcall
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
fatal runtime error: failed to initiate panic, error 5
Fatal Python error: Aborted

Current thread 0x000000011193ddc0 (most recent call first):
  File "/Users/ramon/floss/tantivy-py/tests/tantivy_test.py", line 121 in test_and_query_parser_default_fields_undefined
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/python.py", line 166 in pytest_pyfunc_call
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 87 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 93 in _hookexec
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/hooks.py", line 286 in __call__
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/python.py", line 1435 in runtest
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 131 in pytest_runtest_call
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 87 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 93 in _hookexec
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/hooks.py", line 286 in __call__
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 207 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 234 in from_call
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 207 in call_runtest_hook
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 182 in call_and_report
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 96 in runtestprotocol
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/runner.py", line 81 in pytest_runtest_protocol
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 87 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 93 in _hookexec
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/hooks.py", line 286 in __call__
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/main.py", line 270 in pytest_runtestloop
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 87 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 93 in _hookexec
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/hooks.py", line 286 in __call__
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/main.py", line 246 in _main
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/main.py", line 196 in wrap_session
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/main.py", line 239 in pytest_cmdline_main
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/callers.py", line 187 in _multicall
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 87 in <lambda>
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/manager.py", line 93 in _hookexec
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pluggy/hooks.py", line 286 in __call__
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/_pytest/config/__init__.py", line 92 in main
  File "/Users/ramon/.pyenv/versions/tantivy/lib/python3.7/site-packages/pytest/__main__.py", line 7 in <module>
  File "/Users/ramon/.pyenv/versions/3.7.7/lib/python3.7/runpy.py", line 85 in _run_code
  File "/Users/ramon/.pyenv/versions/3.7.7/lib/python3.7/runpy.py", line 193 in _run_module_as_main
[1]    27689 abort      RUST_BACKTRACE=1 pipenv run python -m pytest -s -k

) -> PyResult<Query> {
let mut default_fields = vec![];
let schema = self.index.schema();
Expand Down Expand Up @@ -336,6 +337,49 @@ impl Index {
let parser =
tv::query::QueryParser::for_index(&self.index, default_fields);
let query = parser.parse_query(query).map_err(to_pyerr)?;

if let Some(filters_dict) = filters {
let mut query_vec = Vec::new();
query_vec.push((tv::query::Occur::Must, query));
for key_value_any in filters_dict.items() {
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
if key_value.len() != 2 {
continue;
}
let key: String = key_value.get_item(0).extract()?;
let field = schema.get_field(&key).ok_or_else(|| {
exceptions::ValueError::py_err(format!(
"Field `{}` is not defined in the schema.",
key
))
})?;

if let Ok(value_list) =
key_value.get_item(1).downcast::<PyList>()
{
for value_element in value_list {
if let Ok(s) = value_element.extract::<String>() {
let facet = tv::schema::Facet::from_text(&s);
let term =
tv::schema::Term::from_facet(field, &facet);
let term_query = tv::query::TermQuery::new(
term,
tv::schema::IndexRecordOption::Basic,
);
let query: Box<dyn tv::query::Query> =
Box::new(term_query);
query_vec.push((tv::query::Occur::Must, query));
}
}
}
}
}
let boolean_query = tv::query::BooleanQuery::from(query_vec);
return Ok(Query {
inner: Box::new(boolean_query),
});
}

Ok(Query { inner: query })
}
}
115 changes: 111 additions & 4 deletions src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::query::Query;
use crate::{get_field, to_pyerr};
use pyo3::exceptions::ValueError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::PyObjectProtocol;
use std::collections::BTreeMap;
use tantivy as tv;
use tantivy::collector::{Count, MultiCollector, TopDocs};

Expand Down Expand Up @@ -45,6 +47,7 @@ impl ToPyObject for Fruit {
/// Object holding a results successful search.
pub(crate) struct SearchResult {
hits: Vec<(Fruit, DocAddress)>,
facets_result: BTreeMap<String, Vec<(String, u64)>>,
#[pyo3(get)]
/// How many documents matched the query. Only available if `count` was set
/// to true during the search.
Expand All @@ -56,11 +59,17 @@ impl PyObjectProtocol for SearchResult {
fn __repr__(&self) -> PyResult<String> {
if let Some(count) = self.count {
Ok(format!(
"SearchResult(hits: {:?}, count: {})",
self.hits, count
"SearchResult(hits: {:?}, count: {}, facets: {})",
self.hits,
count,
self.facets_result.len()
))
} else {
Ok(format!("SearchResult(hits: {:?})", self.hits))
Ok(format!(
"SearchResult(hits: {:?}, facets: {})",
self.hits,
self.facets_result.len()
))
}
}
}
Expand All @@ -78,6 +87,14 @@ impl SearchResult {
.collect();
Ok(ret)
}

#[getter]
bloodbare marked this conversation as resolved.
Show resolved Hide resolved
fn facets(
&self,
_py: Python,
) -> PyResult<BTreeMap<String, Vec<(String, u64)>>> {
Ok(self.facets_result.clone())
}
}

#[pymethods]
Expand All @@ -94,6 +111,8 @@ impl Searcher {
/// should be ordered by. The field must be declared as a fast field
/// when building the schema. Note, this only works for unsigned
/// fields.
/// facets (PyDict, optional): A dictionary of facet fields and keys to
/// filter.
///
/// Returns `SearchResult` object.
///
Expand All @@ -106,6 +125,7 @@ impl Searcher {
limit: usize,
count: bool,
order_by_field: Option<&str>,
facets: Option<&PyDict>,
) -> PyResult<SearchResult> {
let mut multicollector = MultiCollector::new();

Expand All @@ -115,6 +135,37 @@ impl Searcher {
None
};

let mut facets_requests = BTreeMap::new();

// We create facets collector for each field and terms defined on the facets args
if let Some(facets_dict) = facets {
for key_value_any in facets_dict.items() {
if let Ok(key_value) = key_value_any.downcast::<PyTuple>() {
if key_value.len() != 2 {
continue;
}
let key: String = key_value.get_item(0).extract()?;
let field = get_field(&self.inner.index().schema(), &key)?;

let mut facet_collector =
tv::collector::FacetCollector::for_field(field);

if let Ok(value_list) =
key_value.get_item(1).downcast::<PyList>()
{
for value_element in value_list {
if let Ok(s) = value_element.extract::<String>() {
facet_collector.add_facet(&s);
}
}
let facet_handler =
multicollector.add_collector(facet_collector);
facets_requests.insert(key, facet_handler);
}
}
}
}

let (mut multifruit, hits) = {
if let Some(order_by) = order_by_field {
let field = get_field(&self.inner.index().schema(), order_by)?;
Expand Down Expand Up @@ -162,7 +213,52 @@ impl Searcher {
None => None,
};

Ok(SearchResult { hits, count })
let mut facets_result: BTreeMap<String, Vec<(String, u64)>> =
BTreeMap::new();

// Go though all collectors that are registered
for (key, facet_collector) in facets_requests {
let facet_count = facet_collector.extract(&mut multifruit);
let mut facet_vec = Vec::new();
if let Some(facets_dict) = facets {
match facets_dict.get_item(key.clone()) {
Some(facets_list_by_key) => {
if let Ok(facets_list_by_key_native) =
facets_list_by_key.downcast::<PyList>()
{
for facet_value in facets_list_by_key_native {
if let Ok(s) = facet_value.extract::<String>() {
let facet_value_vec: Vec<(
&tv::schema::Facet,
u64,
)> = facet_count.get(&s).collect();

// Go for all elements on facet and count to add on vector
for (
facet_value_vec_element,
facet_count,
) in facet_value_vec
{
facet_vec.push((
facet_value_vec_element.to_string(),
facet_count,
))
}
}
}
}
}
None => println!("Not found."),
}
}
facets_result.insert(key.clone(), facet_vec);
}

Ok(SearchResult {
hits,
count,
facets_result,
})
}

/// Returns the overall number of documents in the index.
Expand All @@ -171,6 +267,17 @@ impl Searcher {
self.inner.num_docs()
}

fn docn(&self, seg_doc: &PyTuple) -> PyResult<Document> {
bloodbare marked this conversation as resolved.
Show resolved Hide resolved
let seg: u32 = seg_doc.get_item(0).extract()?;
let doc: u32 = seg_doc.get_item(1).extract()?;
let address = tv::DocAddress(seg, doc);
let doc = self.inner.doc(address).map_err(to_pyerr)?;
let named_doc = self.inner.schema().to_named_doc(&doc);
Ok(Document {
field_values: named_doc.0,
})
}

/// Fetches a document from Tantivy's store given a DocAddress.
///
/// Args:
Expand Down
33 changes: 30 additions & 3 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


def schema():
return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").build()
return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").add_facet_field("facet").build()

def create_index(dir=None):
# assume all tests will use the same documents for now
Expand All @@ -27,6 +27,7 @@ def create_index(dir=None):
"now without taking a fish."
),
)
doc.add_facet('facet', tantivy.Facet.from_string("/mytag"))
writer.add_document(doc)
# 2 use the built-in json support
# keys need to coincide with field names
Expand Down Expand Up @@ -117,14 +118,40 @@ def test_and_query_parser_default_fields(self, ram_index):
assert repr(query) == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))"""

def test_and_query_parser_default_fields_undefined(self, ram_index):
query = ram_index.parse_query("winter")
query = ram_index.parse_query("/winter")
assert (
repr(query) == "Query(BooleanQuery { subqueries: ["
"(Should, TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114])))] "
"(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114]))), "
"(Should, TermQuery(Term(field=2,bytes=[119, 105, 110, 116, 101, 114])))] "
"})"
)

def test_and_query_parser_default_fields_facets(self, ram_index):
index = ram_index
query = index.parse_query("old", default_field_names=["title", "body"], filters={"facet": ["/mytag"]})
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 1

query = index.parse_query("old", default_field_names=["title", "body"], filters={"facet": ["/wrongtag"]})
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10)
assert result.count == 0

def test_search_facets(self, ram_index):
index = ram_index
query = index.parse_query("old", default_field_names=["title", "body"])
# look for an intersection of documents
searcher = index.searcher()
result = searcher.search(query, 10, facets={"facet": ["/"]})
assert result.count == 1
assert ('/mytag', 1) in result.facets['facet']
bloodbare marked this conversation as resolved.
Show resolved Hide resolved


bloodbare marked this conversation as resolved.
Show resolved Hide resolved

def test_query_errors(self, ram_index):
index = ram_index
# no "bod" field
Expand Down