Skip to content

Commit

Permalink
Fixes #78 Reading comprehension synthetic data regex improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tleyden committed Dec 4, 2023
1 parent 567c910 commit ba88453
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 46 deletions.
168 changes: 122 additions & 46 deletions dalm/datasets/reading_comprehension_generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,38 +140,59 @@ def create_domain_tokenizer_from_files(directory_or_file: str, csv_column: Optio
return create_domain_tokenizer(os.path.join(temp_dir, "temp.txt"))


def fix_first_prompt(text: str, chat_chain: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Prefix the first prompt of a chat chain with the text it refers to.

    The first entry of chat_chain is replaced by a "user" message whose content is the
    "Based on the following text" instruction built from text, followed by that entry's
    original content. All remaining entries are carried over unchanged.

    @param text: the text the questions are about
    @param chat_chain: chat messages (dicts with "content" and "role" keys); must not be empty
    @return: a new list of chat messages; chat_chain itself is not modified
    @raises IndexError: if chat_chain is empty
    """
    # Read the first prompt by index rather than pop(0) so the caller's list is not
    # mutated as a side effect of building the result.
    first_prompt = chat_chain[0]
    fixed_first_prompt = {
        "content": f"Based on the following text: \n {text}, \n I'd like you to answer a few questions\n"
        + first_prompt["content"],
        "role": "user",
    }
    return [fixed_first_prompt, *chat_chain[1:]]
def wrap_context_with_rag_instruction(context: str) -> str:
    """
    Build the instruction that introduces the source text to the model.

    @param context: the text the follow-up questions refer to
    @return: the context embedded in a "Based on the following text" instruction
    """
    instruction = f"Based on the following text: \n {context}, \n I'd like you to answer a few questions\n"
    return instruction


# TODO: add test
# TODO: Address known issues described in #78
def question_and_answer_extractor(whole_text: str, context: str) -> List[Dict[str, str]] | None:
text_lines = whole_text.split("\n")
question: List[str] = []
answer: List[str] = []
def extract_question(text: str) -> Tuple[bool, str]:
    """
    Look for a question in a single line of text.

    Delegates to extract_question_or_answer with extract_type="question".

    @param text: one line of text
    @return: the (is_question, question_text) tuple produced by extract_question_or_answer
    """
    is_question, question_text = extract_question_or_answer(text, extract_type="question")
    return is_question, question_text

def extract_answer(text: str) -> Tuple[bool, str]:
    """
    Look for an answer in a single line of text.

    Delegates to extract_question_or_answer with extract_type="answer".

    @param text: one line of text
    @return: the (is_answer, answer_text) tuple produced by extract_question_or_answer
    """
    is_answer, answer_text = extract_question_or_answer(text, extract_type="answer")
    return is_answer, answer_text

question_context = False
answer_context = False
def extract_question_or_answer(text: str, extract_type: str = "question") -> Tuple[bool, str]:

# Match a line that starts with any number of junk characters, followed by either "question:"
# or "answer:", followed by any number of spaces (ignored), followed by any number of characters
# that will be captured in a group as the question or answer.
# extraction_regex = rf".*{extract_type}:\s*(.*)"

# Update above to handle the case where the question or answer is in brackets, with
# other text to be ignored inside the brackets
extraction_regex = rf".*\[?{extract_type}[:\]]*(?:.*?\])?\s*(.*)"

match = re.match(extraction_regex, text, re.IGNORECASE)
extracted_text = match.group(1) if match else None
found_extracted = True if extracted_text else False
return found_extracted, extracted_text


def _raw_question_and_answer_extractor(whole_text: str) -> List[Dict[str, str]] | None:
    """
    Scan the text line by line and collect question / answer pairs.

    A two-state machine drives the scan:
    - "waiting_for_question": lines are skipped until extract_question finds a question
    - "waiting_for_answer": the next non-empty line is expected to carry the answer

    If the line following a question carries no answer text, that question is dropped
    rather than being emitted with a missing answer, and the scan goes back to waiting
    for the next question.

    @param whole_text: raw question / answer text, one question or answer per line
    @return: list of {"question": ..., "answer": ...} dicts in order of appearance
        (empty when nothing could be extracted)
    """

    task_regex = r"^\*?\*?task\s*\d*"

    cur_qa_pair: Dict[str, str] = {}
    qa_pairs: List[Dict[str, str]] = []

    state = "waiting_for_question"  # waiting_for_question, waiting_for_answer

    for line in whole_text.split("\n"):
        # NOTE(review): extraction runs on the lowercased line, so the returned questions
        # and answers are lowercased as well - confirm this is intended.
        text = line.strip().lower()

        # ignore empty lines
        if text == "":
            continue

        # "Task" lines are not parsed by this function; only make them visible in the logs.
        if re.match(task_regex, text):
            logger.warning(f"Found a task line: {text}")

        if state == "waiting_for_question":
            is_question, question_text = extract_question(text)
            if is_question:
                state = "waiting_for_answer"
                cur_qa_pair = {"question": question_text}
        elif state == "waiting_for_answer":
            is_answer, answer_text = extract_answer(text)
            state = "waiting_for_question"
            # Only keep the pair when an answer was actually extracted; otherwise the
            # pair would carry a None / empty answer into the chat completions.
            if is_answer:
                cur_qa_pair["answer"] = answer_text
                qa_pairs.append(cur_qa_pair)
        else:
            raise ValueError("Unknown state")

    return qa_pairs


def convert_qa_pairs_to_chat_completions(qa_pairs: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """
    Flatten QA pairs into alternating chat messages.

    Each pair contributes two messages, in this order: a "user" message holding the
    question and an "assistant" message holding the answer.

    @param qa_pairs: dicts with a "question" and an "answer" key
    @return: list of {"content": ..., "role": ...} messages, two per QA pair
    """
    messages: List[Dict[str, str]] = []
    for pair in qa_pairs:
        user_turn = {"content": pair["question"], "role": "user"}
        assistant_turn = {"content": pair["answer"], "role": "assistant"}
        messages.extend((user_turn, assistant_turn))
    return messages

def question_and_answer_extractor(whole_text: str, context: str) -> List[Dict[str, str]] | None:
    """
    Extracts questions and answers from the raw text generated by the large language model.

    The result starts with a "user" message carrying the context instruction, followed by
    one "user" (question) and one "assistant" (answer) message per extracted QA pair.

    @param whole_text: the raw questions and answers generated by the large language model, eg:
        "1. QUESTION: Can you summarize the .. ?
        ANSWER: Population imaging studies generated .."
    @param context: the full dataset text that was used to generate the questions and answers, eg:
        "Population imaging studies generate data for developing and implementing..."
    @return: the list of chat messages, or None when no question / answer pair could be
        extracted from whole_text
    """

    # Extract the qa pairs from whole_text
    qa_pairs = _raw_question_and_answer_extractor(whole_text)

    # Nothing extracted: honour the "| None" contract instead of returning a conversation
    # that holds only the context instruction.
    if not qa_pairs:
        return None

    # The first chat completion input is the context wrapped with a RAG instruction.
    # NOTE(review): this message and the first question are both "user" messages -
    # confirm that consumers accept two consecutive user turns.
    first_chat_completion_input = {
        "content": wrap_context_with_rag_instruction(context),
        "role": "user",
    }

    # Convert the qa pairs to chat completion inputs and append them after the context
    return [first_chat_completion_input, *convert_qa_pairs_to_chat_completions(qa_pairs)]
Loading

0 comments on commit ba88453

Please sign in to comment.