solve #4
xingxuanli committed Jun 3, 2024
1 parent 018ad7b commit dfb85cf
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions utils/retrieval/flashcard.py
@@ -0,0 +1,107 @@
# flashcard
import re

import datasets
import torch
import transformers
from peft import PeftModel
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
)

import utils.globalvar


def formatting_prompts_func(ipt):
    text = f"### Instruction: Answer the question truthfully.\n### Input: {ipt}\n### Output: "
    return text
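
# Illustrative example (hypothetical question, not part of the original commit):
#   formatting_prompts_func("What is the function of hemoglobin?")
#   -> "### Instruction: Answer the question truthfully.\n### Input: What is the function of hemoglobin?\n### Output: "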

### Query Generation ###############################################
def llama2_pipeline(prompt):
    base_model = "meta-llama/Llama-2-7b-hf"
    peft_model = "veggiebird/llama-2-7b-medical-flashcards-8bit"

    # Load the base model, LoRA adapter, and tokenizer only once and cache them
    # in utils.globalvar; 8-bit loading requires bitsandbytes to be installed.
    if utils.globalvar.bio_model is None:
        utils.globalvar.bio_model = AutoModelForCausalLM.from_pretrained(
            base_model,
            use_safetensors=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
        )
        utils.globalvar.bio_model = PeftModel.from_pretrained(utils.globalvar.bio_model, peft_model)
        utils.globalvar.bio_tokenizer = AutoTokenizer.from_pretrained(base_model)
        print("Model loaded...")

    generator = transformers.pipeline(
        "text-generation",
        model=utils.globalvar.bio_model,
        tokenizer=utils.globalvar.bio_tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Greedy decoding (do_sample=False), so top_k has no effect here;
    # max_length counts the prompt tokens as well as the generated ones.
    sequences = generator(
        prompt,
        do_sample=False,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=utils.globalvar.bio_tokenizer.eos_token_id,
        max_length=256,
    )

    return sequences[0]["generated_text"].strip()

###############################################


### Query Knowl. ###############################################
def extract_responses(content):
    # Capture the text after "### Output:" up to the next "###" marker, or to
    # the end of the string if the model stopped without emitting another marker.
    pattern = r"### Output:(.+?)(?:###|$)"
    matches = re.findall(pattern, content, re.DOTALL)
    return [match.strip() for match in matches]
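
# Illustrative example (hypothetical model output, not from the original commit):
#   extract_responses("### Instruction: ...\n### Output: Hemoglobin carries oxygen. ### Instruction: ...")
#   -> ["Hemoglobin carries oxygen."]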


def generate_flashcard_query(input):
    # Returns both the raw generation and the list of extracted "### Output:" spans.
    prompt = formatting_prompts_func(input)
    query = llama2_pipeline(prompt)
    processed_query = extract_responses(query)
    return query, processed_query


def execute_flashcard_query(query):
    # SimCSE encoder used to embed the query; the flashcard dataset is expected
    # to ship a precomputed 'embeddings' column compatible with this model.
    model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
    tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
    dataset = datasets.load_dataset('veggiebird/medical-flashcards')
    dataset = dataset["train"]
    dataset.add_faiss_index(column='embeddings')

    # Embed the query and retrieve the single nearest flashcard via FAISS.
    query_inputs = tokenizer(query, padding=True, truncation=True, return_tensors="pt")
    query_embedding = model(**query_inputs, output_hidden_states=True, return_dict=True).pooler_output.detach().numpy()
    scores, retrieved_examples = dataset.get_nearest_examples("embeddings", query_embedding, k=1)
    pre_knowl = retrieved_examples["output"][0].strip()

    # Keep at most the first three sentence-like segments of the retrieved answer.
    try:
        knowl = ' '.join(re.split(r'(?<=[.:;])\s', pre_knowl)[:3])
    except Exception:
        knowl = pre_knowl
    return knowl
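
# Sketch only (not part of the original commit): one way an 'embeddings' column
# compatible with the FAISS lookup above could be (re)built with the same SimCSE
# encoder. The embedded column name ("output") and the pooler_output pooling
# choice are assumptions, not confirmed by the source.
def build_flashcard_embeddings(dataset, model, tokenizer, batch_size=64):
    def embed_batch(batch):
        inputs = tokenizer(batch["output"], padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
        return {"embeddings": embeddings.cpu().numpy()}

    return dataset.map(embed_batch, batched=True, batch_size=batch_size)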

###############################################


def retrieve_flashcard_knowledge(input, data_point):
    # data_point is accepted for interface consistency but is not used here.
    knowl = ""
    print("Generate query...")
    query, processed_query = generate_flashcard_query(input)
    if len(processed_query) != 0:
        print("Query:", processed_query[0])
        print("Retrieve knowledge...")
        knowl = execute_flashcard_query(processed_query[0])
        print(knowl)
    return knowl
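
# Usage sketch (illustrative, not part of the original commit). Assumes
# utils/globalvar.py defines module-level bio_model = None and bio_tokenizer = None,
# and that bitsandbytes/accelerate are installed for 8-bit loading. The question
# below is a hypothetical example.
if __name__ == "__main__":
    question = "What is the most common cause of iron deficiency anemia in adults?"
    knowledge = retrieve_flashcard_knowledge(question, data_point=None)
    print("Retrieved knowledge:", knowledge)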
