This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

MAGIC source code and results
arian-askari committed Jul 17, 2024
1 parent 1a40d7a commit d2da6f0
Showing 3,389 changed files with 573,103 additions and 8 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
95 changes: 87 additions & 8 deletions README.md
@@ -1,14 +1,93 @@
# MAGIC: Generating Self-Correction Guideline for In-Context Text-to-SQL
[![](https://img.shields.io/badge/MAGIC-brightgreen)](./readme.md)
[![](https://img.shields.io/badge/SELF-Correction-blue)](./readme.md)

This is the repository for the paper ["MAGIC: Generating Self-Correction Guideline for In-Context Text-to-SQL"](https://arxiv.org/abs/2406.12692).

If you use MAGIC's scripts, results, or trajectory data, please cite the following reference:

```bibtex
@article{askari2024magic,
title={MAGIC: Generating Self-Correction Guideline for In-Context Text-to-SQL},
author={Askari, Arian and Poelitz, Christian and Tang, Xinye},
journal={arXiv preprint arXiv:2406.12692},
year={2024}
}
```

## MAGIC

<img src="./figures/magic.svg" alt="MAGIC overview figure">


## Instructions for reproducing experiments

The code for reproducing the experiments is located in the `magic` folder, and the bash scripts for running each component are in the `scripts` folder. Run everything from the repository root after downloading it.

### Setup
First, install the libraries listed in the `requirements.txt` file. Use the following commands:

- `conda create --name magic python=3.10.14`
- `conda activate magic`
- `python3 -m pip install -r requirements.txt`

Before running any experiments, download the BIRD and SPIDER datasets and store them in the `data` folder. For evaluation, use the official evaluation script released with each dataset. Set the `api_key` in the `configs/api_config.json` file.
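
A minimal sanity check, assuming the placeholder layout shipped in this commit (`data/bird/*.json`, `data/spider/*.json`, and `configs/api_config.json`), could look like the following sketch; adjust the paths if your layout differs:

```python
import json
from pathlib import Path

# Placeholder paths from this commit; adjust if your layout differs.
expected_files = [
    "data/bird/train.json",
    "data/bird/dev.json",
    "data/spider/train.json",
    "data/spider/dev.json",
    "configs/api_config.json",
]

missing = [p for p in expected_files if not Path(p).exists()]
if missing:
    raise FileNotFoundError(f"Missing files: {missing}")

# configs/api_config.json is a JSON list of model entries (see the file later in this diff).
with open("configs/api_config.json") as f:
    api_config = json.load(f)

for entry in api_config:
    # The committed file ships with placeholder values such as "api_key".
    assert entry["api_key"] != "api_key", "Replace the placeholder api_key with a real key."

print("Data files and API config look ready.")
```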

## Scripts
### 1. Multi-agent feedback generation with MAGIC
Run the feedback generation with the following command:

```bash
./scripts/multi_agent_feedback_generation.sh
```

Modify the paths and parameters, such as the number of threads, by editing the bash file.
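
For orientation only, the interaction loop implied by the released trajectory files (feedback agent, then correction agent, with the manager revising the prompt between iterations, capped at 5 iterations) can be sketched as below. This is not the project's implementation; the callables are hypothetical stand-ins for the actual agents:

```python
from typing import Callable, Dict, List


def multi_agent_feedback_loop(
    question: str,
    incorrect_sql: str,
    feedback_agent: Callable[[str, str], str],         # (question, sql) -> feedback
    correction_agent: Callable[[str, str, str], str],  # (question, sql, feedback) -> new sql
    manager_revise: Callable[[str, str, str], str],    # (prompt, feedback, sql) -> revised prompt
    is_correct: Callable[[str], bool],                  # executes/compares the candidate SQL
    max_iterations: int = 5,                            # cap reported in the trajectory file names
) -> Dict:
    """Conceptual sketch of the interaction recorded in the trajectory files."""
    trajectory: List[Dict] = []
    prompt, sql = question, incorrect_sql
    for i in range(1, max_iterations + 1):
        feedback = feedback_agent(question, sql)
        trajectory.append({"prompt": prompt, "response": feedback,
                           "caller_agent": f"feedback_agent_call_{i}"})
        sql = correction_agent(question, sql, feedback)
        trajectory.append({"prompt": prompt, "response": sql,
                           "caller_agent": f"correction_agent_call_{i}"})
        if is_correct(sql):
            return {"success": True, "trajectory": trajectory}
        revised_prompt = manager_revise(prompt, feedback, sql)
        trajectory.append({"prompt": prompt, "response": revised_prompt,
                           "caller_agent": f"manager_revises_prompt_iteration_{i}"})
        prompt = revised_prompt
    return {"success": False, "trajectory": trajectory}
```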

### 2. Guideline generation with the MAGIC manager agent
Run the guideline generation with the following command:

```bash
./scripts/guideline_generation.sh
```

You can modify the path and parameters by editing the bash file.

### 3. Self-correction with the guideline automatically generated by MAGIC

Run the self-correction with the following command:

```bash
./scripts/self_correction_with_guideline.sh
```


You can modify the path and parameters by editing the bash file.
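
Step 2 writes the guideline as JSON Lines, one `{"guideline_iteration_<n>": "<guideline text>"}` object per batch (see `magic/guideline_generation.py` in this commit). As a rough sketch, not the project's actual prompt construction, the latest guideline can be pulled out like this before it is injected into the self-correction prompt:

```python
import jsonlines

# Default output path from magic/guideline_generation.py; adjust to your run.
GUIDELINE_PATH = "./src/results/MAGIC-Guideline/guideline_progress_per_batch.json"


def load_latest_guideline(path: str = GUIDELINE_PATH) -> str:
    """Return the most recently written guideline from the JSONL produced in step 2."""
    latest = None
    with jsonlines.open(path) as reader:
        for obj in reader:
            # Each line holds a single key, e.g. "guideline_iteration_42".
            latest = next(iter(obj.values()))
    return latest


guideline = load_latest_guideline()
# The self-correction step then includes this guideline in the prompt alongside the
# question, the database schema, and the initially predicted SQL (see the bash script).
```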

## Results

The results of the experiments and the trajectory data of the agents' interactions are stored in the `results` folder.

## Datasets

We conducted experiments with MAGIC on two open-source datasets: [BIRD](https://bird-bench.github.io/) (licensed under CC BY-SA 4.0) and [SPIDER](https://github.com/taoyds/spider) (licensed under Apache-2.0), both of which are suitable for research use.

### SPIDER dataset
The SPIDER dataset consists of 10,181 questions and 5,693 unique complex SQL queries sourced from 200 databases across 138 domains. Each domain contains multiple tables, and the dataset is split into training, development, and test sets with 8,659, 1,034, and 2,147 examples, respectively. The queries are categorized into four difficulty levels based on factors such as SQL keywords, nested subqueries, and the use of column selections and aggregations.

### BIRD dataset
The BIRD dataset comprises 12,751 question-SQL pairs extracted from 95 relatively large databases (33.4 GB) spanning 37 professional domains, including blockchain and healthcare. BIRD introduces external knowledge to enhance the accuracy of SQL query generation, adding complexity to the task.

### Trajectory dataset (Ours)
The trajectory dataset, generated by MAGIC, is stored in the `./results/MAGIC-trajectory/` folder. This dataset captures interactions between agents aimed at correcting incorrect SQL queries and generating guidelines. Each trajectory data file includes the following fields:
- **prompt**: the input prompt given to the agent.
- **response**: the agent's response to the prompt.
- **caller_agent**: identifies which agent was called in each iteration, with values such as `feedback_agent_call_{iteration_number}`, `correction_agent_call_{iteration_number}`, and `manager_revises_prompt_iteration_{iteration_number}`.

The trajectory data file names include "success-True" or "success-False," indicating whether the interaction between agents successfully led to self-correction. A "False" value indicates that the maximum number of iterations (5) was reached without achieving successful self-correction.
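
A minimal sketch for loading and summarizing the trajectory files, assuming the field names and file-name convention described above (this mirrors how `magic/guideline_generation.py` parses them):

```python
import glob
import json
import os

records = []
for path in glob.glob("./results/MAGIC-trajectory/*"):
    name = os.path.basename(path)
    # File names look like "trajectory-<question_id>-success-<True|False>...".
    question_id = name.split("trajectory-")[1].split("-success")[0]
    success = "success-True" in name
    with open(path) as f:
        trajectory = json.load(f)  # list of {"prompt", "response", "caller_agent"} dicts
    records.append({"question_id": question_id, "success": success,
                    "n_turns": len(trajectory)})

solved = sum(r["success"] for r in records)
print(f"{solved} / {len(records)} trajectories ended in successful self-correction")
```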


## Disclaimer
The purpose of this research project is solely focused on generating self-correction guidelines for incorrect SQL queries through agent interactions. The trajectory data released here is intended strictly for research purposes related to the text-to-SQL task and is not authorized for any other use.

## Contributing

8 changes: 8 additions & 0 deletions configs/api_config.json
@@ -0,0 +1,8 @@
[
{"model": "gpt-4-turbo-0125-spot",
"api_key": "api_key",
"api_type": "api_type",
"base_url": "base_url",
"api_version": "api_version"
}
]
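
How this config is consumed is defined in `utils/llm_apis.py`, which is not among the files loaded in this view; purely as an assumption, a lookup by the `model` field (the code passes `model_name="gpt-4-turbo-0125-spot"`, matching the entry above) might look like:

```python
import json


def get_model_config(model_name: str, config_path: str = "configs/api_config.json") -> dict:
    """Hypothetical helper: return the config entry whose "model" field matches model_name."""
    with open(config_path) as f:
        entries = json.load(f)  # the file is a JSON list of model entries
    for entry in entries:
        if entry["model"] == model_name:
            return entry
    raise KeyError(f"No config entry for model {model_name!r}")


cfg = get_model_config("gpt-4-turbo-0125-spot")
```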
1 change: 1 addition & 0 deletions data/bird/dev.json
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/bird/dev_databases/info.txt
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/bird/train.json
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/bird/train_databases/info.txt
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/spider/dev.json
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/spider/dev_databases/info.txt
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/spider/train.json
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions data/spider/train_databases/info.txt
@@ -0,0 +1 @@
Download this from the BIRD or SPIDER dataset distributions (for BIRD, see https://bird-bench.github.io/; the data is available at https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip and https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip)
1 change: 1 addition & 0 deletions figures/magic.svg
212 changes: 212 additions & 0 deletions magic/guideline_generation.py
@@ -0,0 +1,212 @@
import pandas as pd
import argparse
import json
import glob
import os  # used for portable path handling below
import time
import jsonlines
from utils.sql_extractor import extract_sql_query
from utils.llm_apis import getCompletionGPT4


def call_gpt4(
prompt,
model_key_name,
data_params,
prompt_type,
trajectory_log,
caller_agent="unknown",
):
if prompt_type == "string":
messages = [{"role": "user", "content": prompt}]
else:
messages = prompt

response = getCompletionGPT4(
messages, model_name=model_key_name, data_params=data_params, retry=False
)

# Log the interaction
trajectory_log.append(
{"prompt": prompt, "response": response, "caller_agent": caller_agent}
)

return response


def main(initial_pred_path, gold_df_path, trajectory_path, guideline_out_path):
# Load prediction object
with open(initial_pred_path, "r") as f:
pred_obj = json.load(f)

# Load gold dataframe
gold_df = pd.read_json(gold_df_path)

# Get the current time string
timestr = time.strftime("%Y%m%d-%H%M%S")

# Format the guideline output trajectory_path
guideline_out_path = guideline_out_path.format(timestr=timestr)

# Initialize an empty guideline
guideline_format = """
[number]. **[Reminder of mistake]**
- Question: "Question"
- **Incorrect SQL generated by me**: ```Incorrect corrected sql ```
- **Corrected SQL generated by me**: ```sql corrected sql ```
- **Negative and strict step-by-step ask-to-myself questions to prevent same mistake again**:
"""
current_guideline = """
"""

files = list(glob.glob(trajectory_path))
model_key_name = "gpt-4-turbo-0125-spot"
data_params = {
"top_p": 1.0,
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"max_tokens": 4096,
"stream": False,
"n": 1,
"temperature": 0.0,
}

guideline_materials = []

for index_, f in enumerate(files):
try:
            # os.path.basename works for both POSIX and Windows paths.
            file_name = os.path.basename(f)
            question_id = file_name.split("trajectory-")[1].split("-success")[0]
            success = "True" in file_name
            if success:
question = gold_df.iloc[int(question_id)]["question"]
if "ratio" in question:
continue
                with open(f, "r") as traj_file:
                    trajectory_object = json.load(traj_file)
latest_correct_sql = trajectory_object[-1]["response"]
latest_correct_sql = extract_sql_query(
latest_correct_sql, return_None=False
)

latest_feedback_that_worked = str(trajectory_object[-2]["response"])
initially_incorrect_predicted_sql = pred_obj[str(question_id)]

all_incorrect_sqls_by_correction_agents = []
for obj in trajectory_object[0:-1]:
if (
"caller_agent" in obj
and "correction_agent_call" in obj["caller_agent"]
):
failed_correction = extract_sql_query(
obj["response"], return_None=False
)
all_incorrect_sqls_by_correction_agents.append(
failed_correction
)

guideline_material = f"""
Question: {question}
Feedback: {latest_feedback_that_worked}
Incorrect sql 1: {initially_incorrect_predicted_sql}
"""
for incorrect_sql_index, correction_sql in enumerate(
all_incorrect_sqls_by_correction_agents
):
guideline_material += f"""
Incorrect sql {incorrect_sql_index + 2}: {correction_sql},
"""
guideline_material += f"""
Successfully Corrected SQL using the feedback: {latest_correct_sql}
"""
guideline_materials.append(guideline_material)

if len(guideline_materials) >= 10:
print("Updating guideline....")
user_prompt = f"""
# Guideline format:
{guideline_format}
# Guideline so far:
{current_guideline}
# Recent mistakes that must be aggregate to Guideline:
{guideline_materials}
# Updated Guideline (Return the entire guideline):
"""

prompt = [{"role": "user", "content": user_prompt}]
prompt_type = "message"
no_error = False
while not no_error:
try:
current_guideline = call_gpt4(
prompt,
model_key_name,
data_params,
prompt_type,
[],
caller_agent="unknown",
)
no_error = True
except Exception as e:
print(f"Retrying... error was: {e}")
pass

with jsonlines.open(guideline_out_path, mode="a") as jsonl_write:
obj = {f"guideline_iteration_{index_}": current_guideline}
jsonl_write.write(obj)
guideline_materials = []
print("Updated!")
        except Exception as e:
            # Log and skip trajectories that cannot be processed instead of failing silently.
            print(f"Skipping {f}: {e}")
            continue


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generating guideline by MAGIC.")
parser.add_argument(
"--initial_pred_path",
type=str,
default=r"./data/bird/train_initial_pred.json",
required=True,
help="Path to the initial prediction system file.",
)

parser.add_argument(
"--gold_df_path",
type=str,
default=r"./data/bird/train_df.json",
required=True,
help="Path to the gold data frame file.",
)

parser.add_argument(
"--trajectory_path",
type=str,
default=r"./src/results/MAGIC-trajectory/*",
required=True,
help="Path to the directory or files to process.",
)

parser.add_argument(
"--guideline_out_path",
type=str,
default=r"./src/results/MAGIC-Guideline/guideline_progress_per_batch.json",
required=True,
help="Template trajectory_path for the guideline output file.",
)

args = parser.parse_args()

initial_pred_path = args.initial_pred_path
gold_df_path = args.gold_df_path
trajectory_path = args.trajectory_path
guideline_out_path = args.guideline_out_path

print("Initial Prediction Path:", initial_pred_path)
print("Gold Data Frame Path:", gold_df_path)
print("Trajectory Path:", trajectory_path)
print("Guideline Output Path Template:", guideline_out_path)

main(initial_pred_path, gold_df_path, trajectory_path, guideline_out_path)
# Example: python3 -u -m magic.guideline_generation --initial_pred_path "./data/bird/train_initial_pred.json" --gold_df_path "./data/bird/train_df.json" --trajectory_path "./src/results/MAGIC-trajectory/*" --guideline_out_path "./src/results/MAGIC-Guideline/guideline_progress_per_batch.json"
