diff --git a/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/trec_dl_relevance_judges.ipynb b/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/trec_dl_relevance_judges.ipynb
index 752f1785a..97c57ec44 100644
--- a/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/trec_dl_relevance_judges.ipynb
+++ b/src/benchmark/trulens/benchmark/benchmark_frameworks/experiments/trec_dl_relevance_judges.ipynb
@@ -59,6 +59,143 @@
     "\"\"\""
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Human annotation quality analysis: checking agreement with the generated scores from `scoreddocs`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import ir_datasets\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "import seaborn as sns\n",
+    "\n",
+    "\n",
+    "def scoreddocs_qrels_confusion_matrix(\n",
+    "    dataset_path=\"msmarco-passage-v2/trec-dl-2022/judged\",\n",
+    "    aggregate=False,\n",
+    "    normalize=False,\n",
+    "):\n",
+    "    # Load the dataset\n",
+    "    dataset = ir_datasets.load(dataset_path)\n",
+    "    qrels = dataset.qrels_dict()\n",
+    "    scoreddocs = list(dataset.scoreddocs_iter())\n",
+    "\n",
+    "    # Prepare a DataFrame for analysis\n",
+    "    data = []\n",
+    "    for scored_doc in scoreddocs:\n",
+    "        query_id = scored_doc.query_id\n",
+    "        doc_id = scored_doc.doc_id\n",
+    "        score = scored_doc.score\n",
+    "        qrel_score = qrels.get(query_id, {}).get(doc_id, None)\n",
+    "        if qrel_score is not None:\n",
+    "            data.append({\n",
+    "                \"query_id\": query_id,\n",
+    "                \"doc_id\": doc_id,\n",
+    "                \"score\": score,\n",
+    "                \"qrel_score\": qrel_score,\n",
+    "            })\n",
+    "\n",
+    "    df = pd.DataFrame(data)\n",
+    "\n",
+    "    # Analyze per query or aggregate\n",
+    "    all_data = []\n",
+    "    confusion_matrices = {}\n",
+    "    for query_id, group in df.groupby(\"query_id\"):\n",
+    "        min_score = group[\"score\"].min()\n",
+    "        max_score = group[\"score\"].max()\n",
+    "        interval_size = (max_score - min_score) / 4\n",
+    "        intervals = [\n",
+    "            min_score + i * interval_size for i in range(5)\n",
+    "        ]  # 5 bin edges -> 4 equal-width intervals\n",
+    "\n",
+    "        # Assign each passage to an interval\n",
+    "        group[\"interval\"] = pd.cut(\n",
+    "            group[\"score\"],\n",
+    "            bins=intervals,\n",
+    "            include_lowest=True,\n",
+    "            labels=[0, 1, 2, 3],\n",
+    "        )\n",
+    "\n",
+    "        if aggregate:\n",
+    "            # Append all data for aggregation\n",
+    "            all_data.append(group)\n",
+    "        else:\n",
+    "            # Create confusion matrix for each query\n",
+    "            confusion_matrix = pd.crosstab(\n",
+    "                group[\"qrel_score\"], group[\"interval\"]\n",
+    "            )\n",
+    "\n",
+    "            # Normalize across rows (qrels scores)\n",
+    "            if normalize:\n",
+    "                confusion_matrix = confusion_matrix.div(\n",
+    "                    confusion_matrix.sum(axis=1), axis=0\n",
+    "                )\n",
+    "\n",
+    "            confusion_matrices[query_id] = confusion_matrix\n",
+    "\n",
+    "            # Visualize the confusion matrix\n",
+    "            plt.figure(figsize=(8, 6))\n",
+    "            sns.heatmap(\n",
+    "                confusion_matrix, annot=True, fmt=\".2f\", cmap=\"Blues\", cbar=True\n",
+    "            )\n",
+    "            plt.title(\n",
+    "                f\"Confusion Matrix for Query {query_id} (Normalized by Qrels)\"\n",
+    "            )\n",
+    "            plt.xlabel(\"Scoreddocs Intervals\")\n",
+    "            plt.ylabel(\"Qrels Scores\")\n",
+    "            plt.show()\n",
+    "\n",
+    "    if aggregate:\n",
+    "        # Combine all groups into a single DataFrame\n",
+    "        aggregated_df = pd.concat(all_data, ignore_index=True)\n",
+    "\n",
+    "        # Create an aggregate confusion matrix\n",
+    "        aggregate_confusion_matrix = pd.crosstab(\n",
+    "            aggregated_df[\"qrel_score\"], aggregated_df[\"interval\"]\n",
+    "        )\n",
+    "\n",
+    "        # Normalize across rows (qrels scores)\n",
+    "        if normalize:\n",
+    "            aggregate_confusion_matrix = aggregate_confusion_matrix.div(\n",
+    "                aggregate_confusion_matrix.sum(axis=1), axis=0\n",
+    "            )\n",
+    "\n",
+    "        # Visualize the aggregate confusion matrix\n",
+    "        plt.figure(figsize=(10, 8))\n",
+    "        sns.heatmap(\n",
+    "            aggregate_confusion_matrix,\n",
+    "            annot=True,\n",
+    "            fmt=\".2f\",\n",
+    "            cmap=\"Blues\",\n",
+    "            cbar=True,\n",
+    "        )\n",
+    "        plt.title(\n",
+    "            \"Aggregate Confusion Matrix Across All Queries (Normalized by Qrels)\"\n",
+    "        )\n",
+    "        plt.xlabel(\"Scoreddocs Intervals\")\n",
+    "        plt.ylabel(\"Qrels Scores\")\n",
+    "        plt.show()\n",
+    "\n",
+    "        return aggregate_confusion_matrix\n",
+    "\n",
+    "    return confusion_matrices\n",
+    "\n",
+    "\n",
+    "# Run the analysis for individual queries\n",
+    "# confusion_matrices = scoreddocs_qrels_confusion_matrix(aggregate=False)\n",
+    "\n",
+    "# Run the analysis for aggregate across all queries\n",
+    "aggregate_confusion_matrix = scoreddocs_qrels_confusion_matrix(aggregate=True)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -162,7 +299,6 @@
     "# trec_2022_samples = list(generate_trec_dl_benchmark(max_samples_per_query_per_score=3, dataset_path=\"msmarco-passage-v2/trec-dl-2022/judged\"))\n",
     "# trec_combined = trec_2021_samples + trec_2022_samples\n",
-    "import pandas as pd\n",
     "\n",
     "# trec_combined_df = pd.DataFrame(trec_combined)\n",
     "# trec_combined_df.to_csv(\"trec_dl_2021_2022_benchmark.csv\", index=False)\n",