-
Notifications
You must be signed in to change notification settings - Fork 602
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1211 from zc277584121/master
evaluate fiqa openai
- Loading branch information
Showing
1 changed file
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,319 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"# ! python -m pip install openai beir ragas==0.0.17" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"1706\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import json\n", | ||
"import pandas as pd\n", | ||
"import os\n", | ||
"from tqdm import tqdm\n", | ||
"from datasets import Dataset\n", | ||
"from beir import util\n", | ||
"\n", | ||
"\n", | ||
"def prepare_fiqa_without_answer(knowledge_path):\n", | ||
" dataset_name = \"fiqa\"\n", | ||
"\n", | ||
" if not os.path.exists(os.path.join(knowledge_path, f'{dataset_name}.zip')):\n", | ||
" url = (\n", | ||
" \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n", | ||
" dataset_name\n", | ||
" )\n", | ||
" )\n", | ||
" util.download_and_unzip(url, knowledge_path)\n", | ||
"\n", | ||
" data_path = os.path.join(knowledge_path, 'fiqa')\n", | ||
" with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n", | ||
" cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n", | ||
"\n", | ||
" corpus_df = pd.DataFrame(cs)\n", | ||
"\n", | ||
" corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n", | ||
" corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n", | ||
" corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n", | ||
" corpus_df.head()\n", | ||
"\n", | ||
" with open(os.path.join(data_path, \"queries.jsonl\")) as f:\n", | ||
" qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n", | ||
"\n", | ||
" queries_df = pd.DataFrame(qs)\n", | ||
" queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n", | ||
" queries_df = queries_df.drop(columns=[\"metadata\"])\n", | ||
" queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n", | ||
" queries_df.head()\n", | ||
"\n", | ||
" splits = [\"dev\", \"test\", \"train\"]\n", | ||
" split_df = {}\n", | ||
" for s in splits:\n", | ||
" split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n", | ||
" columns=[\"score\"]\n", | ||
" )\n", | ||
"\n", | ||
" final_split_df = {}\n", | ||
" for split in split_df:\n", | ||
" df = queries_df.merge(split_df[split], on=\"query-id\")\n", | ||
" df = df.merge(corpus_df, on=\"corpus-id\")\n", | ||
" df = df.drop(columns=[\"corpus-id\"])\n", | ||
" grouped = df.groupby(\"query-id\").apply(\n", | ||
" lambda x: pd.Series(\n", | ||
" {\n", | ||
" \"question\": x[\"question\"].sample().values[0],\n", | ||
" \"ground_truths\": x[\"ground_truth\"].tolist(),\n", | ||
" }\n", | ||
" )\n", | ||
" )\n", | ||
"\n", | ||
" grouped = grouped.reset_index()\n", | ||
" grouped = grouped.drop(columns=\"query-id\")\n", | ||
" final_split_df[split] = grouped\n", | ||
"\n", | ||
" return final_split_df\n", | ||
"\n", | ||
"\n", | ||
"knowledge_datas_path = './knowledge_datas'\n", | ||
"fiqa_path = os.path.join(knowledge_datas_path, 'fiqa_doc.txt')\n", | ||
"\n", | ||
"if not os.path.exists(knowledge_datas_path):\n", | ||
" os.mkdir(knowledge_datas_path)\n", | ||
"contexts_list = []\n", | ||
"answer_list = []\n", | ||
"\n", | ||
"final_split_df = prepare_fiqa_without_answer(knowledge_datas_path)\n", | ||
"\n", | ||
"docs = []\n", | ||
"\n", | ||
"split = 'test'\n", | ||
"for ds in final_split_df[split][\"ground_truths\"]:\n", | ||
" docs.extend([d for d in ds])\n", | ||
"print(len(docs))\n", | ||
"\n", | ||
"docs_str = '\\n'.join(docs)\n", | ||
"with open(fiqa_path, 'w') as f:\n", | ||
" f.write(docs_str)\n", | ||
"\n", | ||
"split = 'test'\n", | ||
"question_list = final_split_df[split][\"question\"].to_list()\n", | ||
"ground_truth_list = final_split_df[split][\"ground_truths\"].to_list()" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"outputs": [], | ||
"source": [ | ||
"import time\n", | ||
"from openai import OpenAI\n", | ||
"\n", | ||
"client = OpenAI()\n", | ||
"\n", | ||
"# Set OPENAI_API_KEY in your environment value\n", | ||
"client.api_key = os.getenv('OPENAI_API_KEY')\n", | ||
"\n", | ||
"\n", | ||
"class OpenAITimeoutException(Exception):\n", | ||
" pass\n", | ||
"\n", | ||
"\n", | ||
"def get_content_from_retrieved_message(message):\n", | ||
" # Extract the message content\n", | ||
" message_content = message.content[0].text\n", | ||
" annotations = message_content.annotations\n", | ||
" contexts = []\n", | ||
" for annotation in annotations:\n", | ||
" message_content.value = message_content.value.replace(annotation.text, f'')\n", | ||
" if (file_citation := getattr(annotation, 'file_citation', None)):\n", | ||
" contexts.append(file_citation.quote)\n", | ||
" if len(contexts) == 0:\n", | ||
" contexts = ['empty context.']\n", | ||
" return message_content.value, contexts\n", | ||
"\n", | ||
"\n", | ||
"def try_get_answer_contexts(assistant_id, question, timeout_seconds=120):\n", | ||
" thread = client.beta.threads.create(\n", | ||
" messages=[\n", | ||
" {\n", | ||
" \"role\": \"user\",\n", | ||
" \"content\": question,\n", | ||
" }\n", | ||
" ]\n", | ||
" )\n", | ||
" thread_id = thread.id\n", | ||
" run = client.beta.threads.runs.create(\n", | ||
" thread_id=thread_id,\n", | ||
" assistant_id=assistant_id,\n", | ||
" )\n", | ||
" start_time = time.time()\n", | ||
" while True:\n", | ||
" elapsed_time = time.time() - start_time\n", | ||
" if elapsed_time > timeout_seconds:\n", | ||
" raise Exception(\"OpenAI retrieving answer Timeout!\")\n", | ||
"\n", | ||
" run = client.beta.threads.runs.retrieve(\n", | ||
" thread_id=thread_id,\n", | ||
" run_id=run.id\n", | ||
" )\n", | ||
" if run.status == 'completed':\n", | ||
" break\n", | ||
" messages = client.beta.threads.messages.list(\n", | ||
" thread_id=thread_id\n", | ||
" )\n", | ||
" assert len(messages.data) > 1\n", | ||
" res, contexts = get_content_from_retrieved_message(messages.data[0])\n", | ||
" response = client.beta.threads.delete(thread_id)\n", | ||
" assert response.deleted is True\n", | ||
" return contexts, res\n", | ||
"\n", | ||
"\n", | ||
"def get_answer_contexts_from_assistant(question, assistant_id, timeout_seconds=120, retry_num=6):\n", | ||
" res = 'failed. please retry.'\n", | ||
" contexts = ['failed. please retry.']\n", | ||
" try:\n", | ||
" for _ in range(retry_num):\n", | ||
" try:\n", | ||
" contexts, res = try_get_answer_contexts(assistant_id, question, timeout_seconds)\n", | ||
" break\n", | ||
" except OpenAITimeoutException as e:\n", | ||
" print('OpenAI retrieving answer Timeout, retry...')\n", | ||
" continue\n", | ||
" except Exception as e:\n", | ||
" print(e)\n", | ||
" return res, contexts" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
" 0%| | 0/648 [03:45<?, ?it/s]\n", | ||
"\n", | ||
"KeyboardInterrupt\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"file = client.files.create(\n", | ||
" file=open(fiqa_path, \"rb\"),\n", | ||
" purpose='assistants'\n", | ||
")\n", | ||
"\n", | ||
"# Add the file to the assistant\n", | ||
"assistant = client.beta.assistants.create(\n", | ||
" instructions=\"You are a customer support chatbot. You must use your retrieval tool to retrieve relevant knowledge to best respond to customer queries.\",\n", | ||
" model=\"gpt-4-1106-preview\",\n", | ||
" tools=[{\"type\": \"retrieval\"}],\n", | ||
" file_ids=[file.id]\n", | ||
")\n", | ||
"\n", | ||
"for question in tqdm(question_list):\n", | ||
" answer, contexts = get_answer_contexts_from_assistant(question, assistant.id)\n", | ||
" # print(f'answer = {answer}')\n", | ||
" # print(f'contexts = {contexts}')\n", | ||
" # print('=' * 80)\n", | ||
" answer_list.append(answer)\n", | ||
" contexts_list.append(contexts)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"outputs": [], | ||
"source": [ | ||
"from ragas import evaluate\n", | ||
"from ragas.metrics import answer_relevancy, faithfulness, context_recall, context_precision, answer_similarity\n", | ||
"\n", | ||
"ds = Dataset.from_dict({\"question\": question_list,\n", | ||
" \"contexts\": contexts_list,\n", | ||
" \"answer\": answer_list,\n", | ||
" \"ground_truths\": ground_truth_list})\n", | ||
"\n", | ||
"result = evaluate(\n", | ||
" ds,\n", | ||
" metrics=[\n", | ||
" context_precision,\n", | ||
" # context_recall,\n", | ||
" # faithfulness,\n", | ||
" # answer_relevancy,\n", | ||
" # answer_similarity,\n", | ||
" # answer_correctness,\n", | ||
" ],\n", | ||
"\n", | ||
")\n", | ||
"print(result)" | ||
], | ||
"metadata": { | ||
"collapsed": false, | ||
"pycharm": { | ||
"name": "#%%\n" | ||
} | ||
} | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |