From 71783c2257f786e9110c4c571e6c422cdcda3a8e Mon Sep 17 00:00:00 2001 From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com> Date: Thu, 7 Oct 2021 11:51:43 -0400 Subject: [PATCH] Batch extractor context queries, closes #120 --- src/python/txtai/pipeline/extractor.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/python/txtai/pipeline/extractor.py b/src/python/txtai/pipeline/extractor.py index 8e06ae925..9cde17a3d 100644 --- a/src/python/txtai/pipeline/extractor.py +++ b/src/python/txtai/pipeline/extractor.py @@ -86,6 +86,9 @@ def query(self, queries, texts): list of (id, text, score) """ + if not queries: + return [] + # Tokenize text segments, tokenlist = [], [] for text in texts: @@ -97,25 +100,24 @@ def query(self, queries, texts): # Add index id to segments to preserver ordering after filters segments = list(enumerate(segments)) + # Run batch queries for performance purposes + if isinstance(self.similarity, Similarity): + # Get list of (id, score) - sorted by highest score per query + scores = self.similarity(queries, [t for _, t in segments]) + else: + # Assume this is an embeddings instance, tokenize and run similarity queries + scores = self.similarity.batchsimilarity([self.tokenizer.tokenize(x) for x in queries], tokenlist) + # Build question-context pairs results = [] - for query in queries: + for i, query in enumerate(queries): # Get list of required and prohibited tokens must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1] mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1] # List of matches matches = [] - - # Get list of (id, score) - sorted by highest score - if isinstance(self.similarity, Similarity): - scores = self.similarity(query, [t for _, t in segments]) - else: - # Assume this is an embeddings instance, tokenize and run similarity query - query = self.tokenizer.tokenize(query) - scores = self.similarity.similarity(query, tokenlist) - - for x, score in scores: + for x, score in scores[i]: # Get segment text text = segments[x][1]