
Commit

add agreement analysis with scoreddocs
sfc-gh-dhuang committed Nov 25, 2024
1 parent cf5979a commit 17941b0
Showing 1 changed file with 137 additions and 1 deletion.
@@ -59,6 +59,143 @@
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Human annotation quality ananlysis and check the agreement with generated scores from `scoreddocs`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ir_datasets\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"\n",
"def scoreddocs_qrels_confusion_matrix(\n",
" dataset_path=\"msmarco-passage-v2/trec-dl-2022/judged\",\n",
" aggregate=False,\n",
" normalize=False,\n",
"):\n",
" # Load the dataset\n",
" dataset = ir_datasets.load(dataset_path)\n",
" qrels = dataset.qrels_dict()\n",
" scoreddocs = list(dataset.scoreddocs_iter())\n",
"\n",
" # Prepare a DataFrame for analysis\n",
" data = []\n",
" for scored_doc in scoreddocs:\n",
" query_id = scored_doc.query_id\n",
" doc_id = scored_doc.doc_id\n",
" score = scored_doc.score\n",
" qrel_score = qrels.get(query_id, {}).get(doc_id, None)\n",
" if qrel_score is not None:\n",
" data.append({\n",
" \"query_id\": query_id,\n",
" \"doc_id\": doc_id,\n",
" \"score\": score,\n",
" \"qrel_score\": qrel_score,\n",
" })\n",
"\n",
" df = pd.DataFrame(data)\n",
"\n",
" # Analyze per query or aggregate\n",
" all_data = []\n",
" confusion_matrices = {}\n",
" for query_id, group in df.groupby(\"query_id\"):\n",
" min_score = group[\"score\"].min()\n",
" max_score = group[\"score\"].max()\n",
" interval_size = (max_score - min_score) / 4\n",
" intervals = [\n",
" min_score + i * interval_size for i in range(5)\n",
" ] # 4 intervals\n",
"\n",
" # Assign each passage to an interval\n",
" group[\"interval\"] = pd.cut(\n",
" group[\"score\"],\n",
" bins=intervals,\n",
" include_lowest=True,\n",
" labels=[0, 1, 2, 3],\n",
" )\n",
"\n",
" if aggregate:\n",
" # Append all data for aggregation\n",
" all_data.append(group)\n",
" else:\n",
" # Create confusion matrix for each query\n",
" confusion_matrix = pd.crosstab(\n",
" group[\"qrel_score\"], group[\"interval\"]\n",
" )\n",
"\n",
" # Normalize across rows (qrels scores)\n",
" if normalize:\n",
" confusion_matrix = confusion_matrix.div(\n",
" confusion_matrix.sum(axis=1), axis=0\n",
" )\n",
"\n",
" confusion_matrices[query_id] = confusion_matrix\n",
"\n",
" # Visualize the confusion matrix\n",
" plt.figure(figsize=(8, 6))\n",
" sns.heatmap(\n",
" confusion_matrix, annot=True, fmt=\".2f\", cmap=\"Blues\", cbar=True\n",
" )\n",
" plt.title(\n",
" f\"Confusion Matrix for Query {query_id} (Normalized by Qrels)\"\n",
" )\n",
" plt.xlabel(\"Scoreddocs Intervals\")\n",
" plt.ylabel(\"Qrels Scores\")\n",
" plt.show()\n",
"\n",
" if aggregate:\n",
" # Combine all groups into a single DataFrame\n",
" aggregated_df = pd.concat(all_data, ignore_index=True)\n",
"\n",
" # Create an aggregate confusion matrix\n",
" aggregate_confusion_matrix = pd.crosstab(\n",
" aggregated_df[\"qrel_score\"], aggregated_df[\"interval\"]\n",
" )\n",
"\n",
" # Normalize across rows (qrels scores)\n",
" if normalize:\n",
" aggregate_confusion_matrix = aggregate_confusion_matrix.div(\n",
" aggregate_confusion_matrix.sum(axis=1), axis=0\n",
" )\n",
"\n",
" # Visualize the aggregate confusion matrix\n",
" plt.figure(figsize=(10, 8))\n",
" sns.heatmap(\n",
" aggregate_confusion_matrix,\n",
" annot=True,\n",
" fmt=\".2f\",\n",
" cmap=\"Blues\",\n",
" cbar=True,\n",
" )\n",
" plt.title(\n",
" \"Aggregate Confusion Matrix Across All Queries (Normalized by Qrels)\"\n",
" )\n",
" plt.xlabel(\"Scoreddocs Intervals\")\n",
" plt.ylabel(\"Qrels Scores\")\n",
" plt.show()\n",
"\n",
" return aggregate_confusion_matrix\n",
"\n",
" return confusion_matrices\n",
"\n",
"\n",
"# Run the analysis for individual queries\n",
"# confusion_matrices = scoreddocs_qrels_confusion_matrix(aggregate=False)\n",
"\n",
"# Run the analysis for aggregate across all queries\n",
"aggregate_confusion_matrix = scoreddocs_qrels_confusion_matrix(aggregate=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -162,7 +299,6 @@
"# trec_2022_samples = list(generate_trec_dl_benchmark(max_samples_per_query_per_score=3, dataset_path=\"msmarco-passage-v2/trec-dl-2022/judged\"))\n",
"# trec_combined = trec_2021_samples + trec_2022_samples\n",
"\n",
"import pandas as pd\n",
"\n",
"# trec_combined_df = pd.DataFrame(trec_combined)\n",
"# trec_combined_df.to_csv(\"trec_dl_2021_2022_benchmark.csv\", index=False)\n",
