-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract.py
22 lines (19 loc) · 869 Bytes
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import logging
import operator

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, QuestionAnsweringPipeline
class Reader:
    """Extractive question-answering reader over a list of passages.

    Wraps a Hugging Face ``QuestionAnsweringPipeline`` built from a single
    model name, runs it once per candidate passage, and ranks the resulting
    answers by the pipeline's ``'score'``.
    """

    # Class-level logger so skipped passages leave a trace instead of
    # disappearing silently.
    _log = logging.getLogger(__name__)

    def __init__(self, model_name, device=0):
        """Load the tokenizer and QA model and build the pipeline.

        Args:
            model_name: Identifier or local path accepted by
                ``from_pretrained`` for both tokenizer and model.
            device: Passed unchanged to ``QuestionAnsweringPipeline``.
                Defaults to ``0`` (the value previously hard-coded, i.e. the
                first CUDA device); per the transformers pipeline docs,
                ``-1`` selects the CPU.
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name)
        self.pipe = QuestionAnsweringPipeline(
            model=self.model, tokenizer=self.tokenizer, device=device
        )

    def extract(self, question, passages):
        """Answer *question* against each passage and rank the answers.

        Args:
            question: The question to ask.
            passages: Iterable of context strings, each queried separately.

        Returns:
            A new list with one answer dict per passage the pipeline handled,
            sorted by ``'score'`` from highest to lowest. Each dict is the
            pipeline's result with an added ``'text'`` key holding the
            passage it was extracted from. Passages for which the pipeline
            raised ``KeyError`` are skipped and logged at WARNING level.
        """
        answers = []
        for passage in passages:
            try:
                answer = self.pipe(question=question, context=passage)
            except KeyError:
                # NOTE(review): why the pipeline raises KeyError for some
                # passages isn't visible here; keep the original best-effort
                # skip, but record it rather than swallowing it silently.
                self._log.warning(
                    "QA pipeline raised KeyError on a passage; skipping it",
                    exc_info=True,
                )
                continue
            answer["text"] = passage
            answers.append(answer)
        return sorted(answers, key=operator.itemgetter("score"), reverse=True)