diff --git a/src/python/txtai/pipeline/extractor.py b/src/python/txtai/pipeline/extractor.py
index a1ca52fba..7e6cd08cf 100644
--- a/src/python/txtai/pipeline/extractor.py
+++ b/src/python/txtai/pipeline/extractor.py
@@ -13,7 +13,7 @@ class Extractor(Pipeline):
     Class that uses an extractive question-answering model to extract content from a given text context.
     """
 
-    def __init__(self, similarity, path, quantize=False, gpu=True, model=None, tokenizer=None, minscore=None, mintokens=None, topn=None):
+    def __init__(self, similarity, path, quantize=False, gpu=True, model=None, tokenizer=None, minscore=None, mintokens=None, context=None):
         """
         Builds a new extractor.
 
@@ -26,7 +26,7 @@ def __init__(self, similarity, path, quantize=False, gpu=True, model=None, token
             tokenizer: Tokenizer class
             minscore: minimum score to include context match, defaults to None
             mintokens: minimum number of tokens to include context match, defaults to None
-            topn: topn context matches to include, defaults to 3
+            context: topn context matches to include, defaults to 3
         """
 
         # Similarity instance
@@ -45,7 +45,7 @@ def __init__(self, similarity, path, quantize=False, gpu=True, model=None, token
         self.mintokens = mintokens if mintokens is not None else 0.0
 
         # Top N context matches to include for question-answering
-        self.topn = topn if topn else 3
+        self.context = context if context else 3
 
     def __call__(self, queue, texts):
         """
@@ -68,7 +68,7 @@ def __call__(self, queue, texts):
         names, questions, contexts, topns, snippets = [], [], [], [], []
         for x, (name, _, question, snippet) in enumerate(queue):
             # Build context using top n best matching segments
-            topn = sorted(results[x], key=lambda y: y[2], reverse=True)[: self.topn]
+            topn = sorted(results[x], key=lambda y: y[2], reverse=True)[: self.context]
             context = " ".join([text for _, text, _ in sorted(topn, key=lambda y: y[0])])
 
             names.append(name)