Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
MacOS committed May 5, 2024
2 parents 482605c + a563e08 commit 65a4d1d
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 25 deletions.
2 changes: 1 addition & 1 deletion llmware/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ numpy>=1.23.2
openai>=1.0
pdf2image==1.16.0
pymilvus>=2.3.0
pymongo==4.5.0
pymongo>=4.7.0
pytesseract==0.3.10
sentence-transformers==2.2.2
tabulate==0.9.0
Expand Down
59 changes: 36 additions & 23 deletions llmware/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""The retrieval module implements the Query class.

The query class executes queries against vector databases and depends on a library object.
"""
"""The retrieval module implements the Query class. The Query class provides a high-level interface for executing
a variety of queries on a Library collection, whether instantiated on Mongo, Postgres, or SQLite.
The Query class includes both text retrieval strategies, which operate directly as queries on the text collection
database, as well as vector embedding semantic retrieval strategies, which require the use of o vector DB and that the
embeddings were previously created for the Library. There are also a number of convenience methods that provide
'hybrid' strategies combining elements of semantic and text querying."""


import logging
Expand All @@ -33,14 +37,15 @@


class Query:
"""Implements the query capabilities against a ``library``.

"""Implements the query capabilities against a ``Library` object`.
Query is responsible for executing queries against an indexed library. The library can be semantic, text, custom,
or hybrid. A query object requires a library object as input, which will be the source of the query.
Parameters
----------
library : object
library : Library object
A ``library`` object.
embedding_model : object, default=None
Expand Down Expand Up @@ -103,6 +108,7 @@ class Query:
"Federal Constitutional Law of 1920. The political system of the Second Republic with its nine federal "
"states is based on the constitution of 1920, amended in 1929, which was re-enacted on 1 May 1945. [108] "
"""

def __init__(self, library, embedding_model=None, tokenizer=None, vector_db_api_key=None,
query_id=None, from_hf=False, from_sentence_transformer=False,embedding_model_name=None,
save_history=True, query_mode=None, vector_db=None, model_api_key=None):
Expand Down Expand Up @@ -407,7 +413,6 @@ def query(self, query, query_type="text", result_count=20, results_only=True):

return output_result

# basic simple text query method - only requires entering the query
def text_query (self, query, exact_mode=False, result_count=20, exhaust_full_cursor=False, results_only=True):

""" Execute a basic text query. """
Expand Down Expand Up @@ -693,7 +698,6 @@ def _cursor_to_qr (self, query, cursor_results, result_count=20, exhaust_full_cu

return qr_dict

# basic semantic query
def semantic_query(self, query, result_count=20, embedding_distance_threshold=None, custom_filter=None, results_only=True):

""" Main method to execute a semantic query - only required parameter is the query. """
Expand Down Expand Up @@ -850,7 +854,8 @@ def similar_blocks_embedding(self, block, result_count=20, embedding_distance_th

return results_dict

def dual_pass_query(self, query, result_count=20, primary="text", safety_check=True, custom_filter=None, results_only=True):
def dual_pass_query(self, query, result_count=20, primary="text",
safety_check=True, custom_filter=None, results_only=True):

""" Executes a combination of text and semantic queries and attempts to interweave and re-rank based on
correspondence between the two query attempts. """
Expand All @@ -872,12 +877,14 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
# run dual pass - text + semantic
# Choose appropriate text query method based on custom_filter
if custom_filter:
retrieval_dict_text = self.text_query_with_custom_filter(query, custom_filter, result_count=result_count, results_only=True)
retrieval_dict_text = self.text_query_with_custom_filter(query, custom_filter,
result_count=result_count, results_only=True)
else:
retrieval_dict_text = self.text_query(query, result_count=result_count, results_only=True)

# Semantic query with custom filter
retrieval_dict_semantic = self.semantic_query(query, result_count=result_count, custom_filter=custom_filter, results_only=True)
retrieval_dict_semantic = self.semantic_query(query, result_count=result_count,
custom_filter=custom_filter, results_only=True)

if primary == "text":
first_list = retrieval_dict_text
Expand Down Expand Up @@ -930,17 +937,16 @@ def dual_pass_query(self, query, result_count=20, primary="text", safety_check=T
doc_fn_list.append(qr["file_source"])

retrieval_dict = {"results": merged_results,
"text_results": retrieval_dict_text,
"semantic_results": retrieval_dict_semantic,
"doc_ID": doc_id_list,
"file_source": doc_fn_list}
"text_results": retrieval_dict_text,
"semantic_results": retrieval_dict_semantic,
"doc_ID": doc_id_list,
"file_source": doc_fn_list}

if results_only:
return merged_results

return retrieval_dict


def augment_qr (self, query_result, query_topic, augment_query="semantic"):

""" Augments the set of query results using alternative retrieval strategy. """
Expand Down Expand Up @@ -1038,7 +1044,9 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
if "doc_ID" in doc_id_list:
doc_id_list = doc_id_list["doc_ID"]
else:
logging.warning("warning: could not recognize doc id list requested. by default, will set to all documents in the library collection.")
logging.warning("warning: could not recognize doc id list requested. by default, "
"will set to all documents in the library collection.")

doc_id_list = self.list_doc_id()

if not page_list:
Expand All @@ -1051,7 +1059,8 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):
else:
page_dict = {"doc_ID": {"$in":doc_id_list}, "master_index": {"$in": page_list}}

cursor_results = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key_dict(page_dict)
cursor_results = CollectionRetrieval(self.library_name,
account_name=self.account_name).filter_by_key_dict(page_dict)

output = []

Expand All @@ -1064,12 +1073,12 @@ def page_lookup(self, page_list=None, doc_id_list=None, text_only=False):

return output

# new method to extract whole library
def get_whole_library(self, selected_keys=None):

""" Gets the whole library - and will return as a list in-memory. """

match_results_cursor = CollectionRetrieval(self.library_name, account_name=self.account_name).get_whole_collection()
match_results_cursor = CollectionRetrieval(self.library_name,
account_name=self.account_name).get_whole_collection()

match_results = match_results_cursor.pull_all()

Expand Down Expand Up @@ -1098,7 +1107,6 @@ def get_whole_library(self, selected_keys=None):

return qr

# new method to generate csv files for each table entry
def export_all_tables(self, query="", output_fp=None):

""" Exports all tables, with query option to limit the list from a library. """
Expand All @@ -1110,7 +1118,8 @@ def export_all_tables(self, query="", output_fp=None):

if not query:

match_results = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key("content_type","table")
match_results = CollectionRetrieval(self.library_name,
account_name=self.account_name).filter_by_key("content_type","table")

else:
kv_dict = {"content_type": "table"}
Expand Down Expand Up @@ -1274,7 +1283,8 @@ def list_doc_fn(self):

""" Utility function - returns list of all document names in the library. """

doc_fn_raw_list = CollectionRetrieval(self.library_name, account_name=self.account_name).get_distinct_list("file_source")
doc_fn_raw_list = CollectionRetrieval(self.library_name,
account_name=self.account_name).get_distinct_list("file_source")

doc_fn_out = []
for i, file in enumerate(doc_fn_raw_list):
Expand Down Expand Up @@ -1712,15 +1722,18 @@ def expand_text_result_after(self, block, window_size=400):
return output

def generate_csv_report(self):

"""Generates a csv report from the current query status. """

output = QueryState(self).generate_query_report_current_state()
return output

def filter_by_key_value_range(self, key, value_range, results_only=True):

""" Executes a filter by key value range. """

cursor = CollectionRetrieval(self.library_name, account_name=self.account_name).filter_by_key_value_range(key,value_range)
cursor = CollectionRetrieval(self.library_name,
account_name=self.account_name).filter_by_key_value_range(key,value_range)

query= ""
result_dict = self._cursor_to_qr(query, cursor, exhaust_full_cursor=True)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def glob_fix(package_name, glob):
'openai>=1.0.0',
'pdf2image==1.16.0',
'pymilvus>=2.3.0',
'pymongo==4.5.0',
'pymongo>=4.7.0',
'pytesseract==0.3.10',
'sentence-transformers==2.2.2',
'tabulate==0.9.0',
Expand Down
Binary file added wheel_archives/llmware-0.2.12-py3-none-any.whl
Binary file not shown.

0 comments on commit 65a4d1d

Please sign in to comment.