From ef714c09c72c7f8a4bd35199a6df1b7321b5334d Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Wed, 26 Jun 2024 14:08:17 +0800 Subject: [PATCH] updating openai prompt --- prompts/prompt_openai.json | 6 +----- query_generators/openai.py | 6 ++++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/prompts/prompt_openai.json b/prompts/prompt_openai.json index 0e096b9..725a100 100644 --- a/prompts/prompt_openai.json +++ b/prompts/prompt_openai.json @@ -1,14 +1,10 @@ [ { "role": "system", - "content": "Your role is to convert text question to PostgreSQL queries, given a database schema." + "content": "Your role is to convert a user question to a PostgreSQL query, given a database schema." }, { "role": "user", "content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}{glossary}\nThis query will run on a database whose schema is represented in this string:\n{table_metadata_string}\n{k_shot_prompt}\nReturn only the SQL query, and nothing else." - }, - { - "role": "assistant", - "content": "Given the database schema, here is the SQL query that answers `{user_question}`:\n```sql" } ] diff --git a/query_generators/openai.py b/query_generators/openai.py index 80c2972..73634d6 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -154,7 +154,8 @@ def generate_query( try: sys_prompt = chat_prompt[0]["content"] user_prompt = chat_prompt[1]["content"] - assistant_prompt = chat_prompt[2]["content"] + if len(chat_prompt) == 3: + assistant_prompt = chat_prompt[2]["content"] except: raise ValueError("Invalid prompt file. Please use prompt_openai.md") user_prompt = user_prompt.format( @@ -171,7 +172,8 @@ def generate_query( messages = [] messages.append({"role": "system", "content": sys_prompt}) messages.append({"role": "user", "content": user_prompt}) - messages.append({"role": "assistant", "content": assistant_prompt}) + if len(chat_prompt) == 3: + messages.append({"role": "assistant", "content": assistant_prompt}) function_to_run = None package = None