Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for svamp #52

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions finetuning/training_configs/few_shot/svamp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
seed_everything: 333
trainer:
default_root_dir: &exp_name results/debug-tmp
# progress_bar_refresh_rate: 1
num_sanity_val_steps: 0
log_every_n_steps: 1
logger+:
- class_path: finetuning.lightning_modules.patches.patched_loggers.PatchedWandbLogger
init_args:
entity: yale-lily
project: unified-codegen
save_dir: *exp_name
name: *exp_name
log_model: False
save_code: True
offline: False
# offline: True
callbacks+:
- class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar
init_args:
refresh_rate: 1

accelerator: gpu
devices: 2
# strategy: deepspeed_stage_2
strategy: ddp_find_unused_parameters_false
precision: 16

model:
class_path: lightning_modules.models.seq2seq_model.Seq2SeqModel
init_args:
transformer_model_name: default-will-cause-error
executor_cls: execution.executors.MathExecutor
max_gen_len: 256
sampling_temp: 0.001
# sampling_temp_at_k: 0.8
# pass_at_k: 50
# max_generation_batches: 5
gradient_ckpt: false
save_raw_generation_results: true
# print_eval_every_n_batches: 1

data:
class_path: lightning_modules.datasets.base_datamodule.FewShotNL2CodeDataModule
init_args:
transformer_model_name: default-will-cause-error
dataset_cls: FewShotMathQADataset
batch_size: 1
val_batch_size: 4
## prompting settings
prompting_init_args:
exemplar_file_path: prompt_files/svamp-idiomatic_code-annotated-4_exemplars.jsonl
num_exemplars: 4
fixed_exemplars: true
exemplar_selection_method: first
add_instruction: true
use_chat_format: false
# val_max_instances: 64
val_set_init_args:
file_path: data/svamp/svamp_test.jsonl
80 changes: 80 additions & 0 deletions preprocessing/preprocess_svamp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Preprocessing script for SVAMP.

A typical example of SVAMP looks like this:
{
"ID": "chal-1",
"Body": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack",
"Question": "How much do you have to pay to buy each pack?",
"Equation": "( 76.0 - 25.0 )",
"Answer": 51.0,
"Type": "Subtraction"
},

And after preprocessing, we want it to look like this:
{
"question": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack. How much do you have to pay to buy each pack?",
"answer": 51.0,
"annotated_code": <only available for prompt examples>,
"metadata": {
"ID": "chal-1",
"Body": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack",
"Question": "How much do you have to pay to buy each pack?",
"Equation": "( 76.0 - 25.0 )",
"Answer": 51.0,
"Type": "Subtraction"
},
}

"""

import json

from typing import Dict, List, Any


# Hand-written Python solutions for the SVAMP problems that serve as few-shot
# prompt exemplars, keyed by the problem's "ID" field. main() treats every
# problem whose ID appears here as a prompt example and every other problem as
# a test example. Each program is stored as a single string with "\n" between
# statements and assigns its final result to a variable named `answer`.
ANNOTATION_DICT = {
"chal-10": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start",
"chal-11": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird",
"chal-12": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs",
"chal-23": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups",
}

def preprocess_svamp_instance(example: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one raw SVAMP record into the unified question/answer format.

    The "Body" and "Question" fields are merged into a single question
    string. Surrounding whitespace is ignored, and a period is appended to
    the body only when it does not already end with sentence-final
    punctuation (".", "?" or "!"), so the merged text never contains
    artifacts such as ". ." or "?.". An empty body or question is simply
    left out of the merged text.

    Args:
        example: A raw SVAMP record containing at least the keys "Body",
            "Question" and "Answer".

    Returns:
        A dict with the keys "question" (merged text), "answer" (the value
        of "Answer") and "metadata" (the untouched input record itself,
        not a copy).

    Raises:
        KeyError: If "Body", "Question" or "Answer" is missing.
    """
    body = example["Body"].strip()
    question = example["Question"].strip()

    # Terminate the body as a sentence so the two parts read naturally.
    if body and not body.endswith((".", "?", "!")):
        body += "."

    # Skip empty parts so no stray separator ends up in the merged text.
    merged_question = " ".join(part for part in (body, question) if part)

    return {
        "question": merged_question,
        "answer": example["Answer"],
        "metadata": example,
    }

def main(
    input_path: str = "data/svamp/SVAMP.json",
    prompt_path: str = "prompt_files/svamp-idiomatic_code-annotated-4_exemplars.jsonl",
    test_path: str = "data/svamp/svamp_test.jsonl",
) -> None:
    """Preprocess the raw SVAMP dump and write the prompt and test splits.

    Every record is converted with preprocess_svamp_instance(). Records whose
    "ID" has an entry in ANNOTATION_DICT become prompt exemplars and receive
    that entry as "annotated_code"; all other records form the test set.
    Both splits keep the order of the input file and are written as JSON
    lines (one JSON object per line).

    Args:
        input_path: Path of the raw SVAMP JSON file (a JSON list of records).
        prompt_path: Output path for the annotated prompt exemplars.
        test_path: Output path for the test examples.
    """
    # Explicit UTF-8 so reading does not depend on the platform's locale.
    with open(input_path, "r", encoding="utf-8") as f:
        examples = json.load(f)

    print(f"loaded {len(examples)} examples")

    # Preprocess and split in a single pass over the data.
    prompt_examples: List[Dict[str, Any]] = []
    test_examples: List[Dict[str, Any]] = []
    for example in examples:
        processed = preprocess_svamp_instance(example)
        annotated_code = ANNOTATION_DICT.get(example["ID"])
        if annotated_code is None:
            test_examples.append(processed)
        else:
            # Attach the hand-written program to the prompt exemplar.
            processed["annotated_code"] = annotated_code
            prompt_examples.append(processed)

    # save the prompt and test sets
    print(f"Saving {len(prompt_examples)} prompt examples and {len(test_examples)} test examples")
    for output_path, subset in ((prompt_path, prompt_examples), (test_path, test_examples)):
        # Plain "w" is enough: the files are only written, never read back.
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(example) + "\n" for example in subset)

if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions prompt_files/svamp-idiomatic_code-annotated-4_exemplars.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{"question": "A waiter had some customers. After 9 customers left he still had 12 customers. How many customers did he have at the start?", "answer": 21.0, "metadata": {"ID": "chal-10", "Body": "A waiter had some customers. After 9 customers left he still had 12 customers.", "Question": "How many customers did he have at the start?", "Equation": "( 9.0 + 12.0 )", "Answer": 21.0, "Type": "Addition"}, "annotated_code": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start"}
{"question": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them. How many more storks than birds are sitting on the fence?", "answer": 1.0, "metadata": {"ID": "chal-11", "Body": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them.", "Question": "How many more storks than birds are sitting on the fence?", "Equation": "( 6.0 - ( 3.0 + 2.0 ) )", "Answer": 1.0, "Type": "Subtraction"}, "annotated_code": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird"}
{"question": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs. How many chairs do they have in the backyard?", "answer": 143.0, "metadata": {"ID": "chal-12", "Body": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs", "Question": "How many chairs do they have in the backyard?", "Equation": "( 11.0 * 13.0 )", "Answer": 143.0, "Type": "Multiplication"}, "annotated_code": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs"}
{"question": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection. How many groups are there?", "answer": 10.0, "metadata": {"ID": "chal-23", "Body": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection", "Question": "How many groups are there?", "Equation": "( 180.0 / 18.0 )", "Answer": 10.0, "Type": "Common-Division"}, "annotated_code": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups"}
4 changes: 4 additions & 0 deletions svamp-idiomatic_code-annotated-4_exemplars.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{"question": "A waiter had some customers. After 9 customers left he still had 12 customers. How many customers did he have at the start?", "answer": 21.0, "metadata": {"ID": "chal-10", "Body": "A waiter had some customers. After 9 customers left he still had 12 customers.", "Question": "How many customers did he have at the start?", "Equation": "( 9.0 + 12.0 )", "Answer": 21.0, "Type": "Addition"}, "annotated_code": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start"}
{"question": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them. How many more storks than birds are sitting on the fence?", "answer": 1.0, "metadata": {"ID": "chal-11", "Body": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them.", "Question": "How many more storks than birds are sitting on the fence?", "Equation": "( 6.0 - ( 3.0 + 2.0 ) )", "Answer": 1.0, "Type": "Subtraction"}, "annotated_code": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird"}
{"question": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs. How many chairs do they have in the backyard?", "answer": 143.0, "metadata": {"ID": "chal-12", "Body": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs", "Question": "How many chairs do they have in the backyard?", "Equation": "( 11.0 * 13.0 )", "Answer": 143.0, "Type": "Multiplication"}, "annotated_code": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs"}
{"question": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection. How many groups are there?", "answer": 10.0, "metadata": {"ID": "chal-23", "Body": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection", "Question": "How many groups are there?", "Equation": "( 180.0 / 18.0 )", "Answer": 10.0, "Type": "Common-Division"}, "annotated_code": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups"}
Loading