Skip to content

Commit

Permalink
tested dspy's rag with 20shots optimization performs the same as Adal…
Browse files Browse the repository at this point in the history
…Flow's 4 shots
  • Loading branch information
liyin2015 committed Sep 19, 2024
1 parent 31ade8c commit 576cc94
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 22 deletions.
1 change: 1 addition & 0 deletions benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):
data=task_desc_str,
role_desc="Task description for the language model",
param_type=adal.ParameterType.PROMPT,
requires_opt=False,
),
"few_shot_demos": adal.Parameter(
data=None,
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/hotpot_qa/adal_exp/train_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def train_diagnose(
def train(
train_batch_size=4, # larger batch size is not that effective, probably because of llm's lost in the middle
raw_shots: int = 0,
bootstrap_shots: int = 1,
bootstrap_shots: int = 4,
max_steps=1,
num_workers=4,
strategy="constrained",
optimization_order="sequential",
debug=False,
resume_from_ckpt=None,
exclude_input_fields_from_bootstrap_demos=False,
exclude_input_fields_from_bootstrap_demos=True,
):
adal_component = VallinaRAGAdal(
**gpt_3_model,
Expand Down Expand Up @@ -157,8 +157,10 @@ def train(
train(
debug=False,
max_steps=12,
resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json",
# resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json",
)
# random_max_steps_12_ecf16_run_9.json, demo only, val 0.6 to 0.68, test: 0.58-0.61
# random_max_steps_12_7c091_run_1.json, prompt + demo, 0.58 -0.62, test: 0.55 - 0.58
# resume from random_max_steps_12_7c091_run_1.json

# demo only, no input, 4 shots, 0.58-> 0.62, VallinaRAGAdal/constrained_max_steps_12_b0a37_run_1.json
4 changes: 2 additions & 2 deletions benchmarks/hotpot_qa/adal_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

def load_datasets():

trainset = HotPotQA(split="train", size=50)
trainset = HotPotQA(split="train", size=20)
valset = HotPotQA(split="val", size=50)
testset = HotPotQA(split="test", size=100)
testset = HotPotQA(split="test", size=50) # to keep the same as the dspy
print(f"trainset, valset: {len(trainset)}, {len(valset)}, example: {trainset[0]}")
return trainset, valset, testset

Expand Down
68 changes: 51 additions & 17 deletions benchmarks/hotpot_qa/dspy_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

dspy.settings.configure(lm=turbo, rm=colbertv2_wiki17_abstracts)
from adalflow.eval.answer_match_acc import AnswerMatchAcc


def load_datasets():
Expand Down Expand Up @@ -70,6 +71,13 @@ def forward(self, question):


# pred: Prediction


def validate_answer(example, pred, trace=None):
evaluator = AnswerMatchAcc(type="fuzzy_match")
return evaluator.compute_single_item(pred.answer, example["answer"])


def validate_context_and_answer_and_hops(example, pred, trace=None):
# print(f"example: {example}, pred: {pred}, trace: {trace}")
if not dspy.evaluate.answer_exact_match(example, pred):
Expand Down Expand Up @@ -104,7 +112,8 @@ def train(trainset, save_path, filename):
if not os.path.exists(save_path):
os.makedirs(save_path)

teleprompter = BootstrapFewShot(metric=validate_context_and_answer_and_hops)
# teleprompter = BootstrapFewShot(metric=validate_context_and_answer_and_hops)
teleprompter = BootstrapFewShot(metric=validate_answer)
compiled_baleen = teleprompter.compile(
SimplifiedBaleen(),
teacher=SimplifiedBaleen(passages_per_hop=2),
Expand All @@ -129,21 +138,37 @@ def gold_passages_retrieved(example, pred, trace=None):

# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
evaluate_on_hotpotqa = Evaluate(
devset=devset, num_threads=1, display_progress=True, display_table=5
devset=devset,
num_threads=1,
display_progress=True,
display_table=5,
# metric=validate_answer,
)

uncompiled_baleen_retrieval_score = evaluate_on_hotpotqa(
uncompiled_baleen, metric=gold_passages_retrieved, display=False
uncompiled_baleen_answer_score = evaluate_on_hotpotqa(
uncompiled_baleen, metric=validate_answer, display_progress=True
)
print(f"## Answer Score for uncompiled Baleen: {uncompiled_baleen_answer_score}")

compiled_baleen_retrieval_score = evaluate_on_hotpotqa(
compiled_baleen, metric=gold_passages_retrieved
)
if compiled_baleen is None:
return

print(
f"## Retrieval Score for uncompiled Baleen: {uncompiled_baleen_retrieval_score}"
compiled_baleen_answer_score = evaluate_on_hotpotqa(
compiled_baleen, metric=validate_answer, display_progress=True
)
print(f"## Retrieval Score for compiled Baleen: {compiled_baleen_retrieval_score}")
print(f"## Answer Score for compiled Baleen: {compiled_baleen_answer_score}")

# uncompiled_baleen_retrieval_score = evaluate_on_hotpotqa(
# uncompiled_baleen, metric=gold_passages_retrieved, display=False
# )

# compiled_baleen_retrieval_score = evaluate_on_hotpotqa(
# compiled_baleen, metric=gold_passages_retrieved
# )

# print(
# f"## Retrieval Score for uncompiled Baleen: {uncompiled_baleen_retrieval_score}"
# )
# print(f"## Retrieval Score for compiled Baleen: {compiled_baleen_retrieval_score}")


if __name__ == "__main__":
Expand All @@ -155,17 +180,26 @@ def gold_passages_retrieved(example, pred, trace=None):

# Get the prediction. This contains `pred.context` and `pred.answer`.
uncompiled_baleen = SimplifiedBaleen() # uncompiled (i.e., zero-shot) program
pred = uncompiled_baleen(my_question)
# pred = uncompiled_baleen(my_question)

# Print the contexts and the answer.
print(f"Question: {my_question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
turbo.inspect_history(n=3)
# # Print the contexts and the answer.
# print(f"Question: {my_question}")
# print(f"Predicted Answer: {pred.answer}")
# print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
# turbo.inspect_history(n=3)

# Load the datasets.
trainset, devset = load_datasets()
from benchmarks.config import dspy_save_path

validate(
devset, uncompiled_baleen, uncompiled_baleen
) # dspy has 58.0% accuracy untrained. it is very slow at the inference, 3.58s per example

# train the model
compiled_baleen = train(trainset, dspy_save_path, "hotpotqa.json")
validate(devset, compiled_baleen, uncompiled_baleen)

# dspy 16 raw shots, 4 demos
# dspy supports multiple generators, in this case 3. Two query generator and one answer generator, they all choose the same examples.
# accuracy 62.0

0 comments on commit 576cc94

Please sign in to comment.