From 0960c45243147cd348391e266020f21d8b1e6570 Mon Sep 17 00:00:00 2001 From: Joshua We <80349780+joshuawe@users.noreply.github.com> Date: Wed, 6 Dec 2023 21:28:19 +0100 Subject: [PATCH] add multiclass prob histogram (#22) --- notebooks/multiclass_classification.ipynb | 131 ++++++++++++++++++++++ plotsandgraphs/__init__.py | 1 + plotsandgraphs/multiclass_classifier.py | 73 ++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 notebooks/multiclass_classification.ipynb create mode 100644 plotsandgraphs/multiclass_classifier.py diff --git a/notebooks/multiclass_classification.ipynb b/notebooks/multiclass_classification.ipynb new file mode 100644 index 0000000..194af92 --- /dev/null +++ b/notebooks/multiclass_classification.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multiclass Classification\n", + "\n", + "This notebook explores multiclass classification" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import plotsandgraphs as pandg\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create some dummy data" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[9.09443337e-01, 9.03393597e-01, 8.31210285e-01],\n", + " [2.21902873e-02, 2.76896844e-01, 1.17129033e-02],\n", + " [3.17239348e-01, 9.98593024e-01, 1.63289150e-02],\n", + " ...,\n", + " [2.58836187e-02, 2.28105168e-01, 5.37207598e-01],\n", + " [2.17134178e-01, 5.08693900e-01, 2.65985367e-01],\n", + " [2.86897406e-15, 7.97772016e-01, 3.77128950e-02]])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# True labels\n", + "y_true = np.random.choice([0, 1, 2], p=[0.3, 0.5, 0.2], size=1000)\n", + "# one hot encoding\n", + 
"y_true_one_hot = np.eye(3)[y_true] \n", + "\n", + "# Predicted labels\n", + "y_pred = np.ones(y_true_one_hot.shape)\n", + "\n", + "a0, b0 = [0.1, 0.6, 0.3, 0.4], [0.4, 1.2, 0.8, 1]\n", + "a1, b1 = [0.9, 0.8, 0.9, 1.2], [0.4, 0.1, 0.5, 0.3]\n", + "# iterate through all the columns/labels\n", + "for i in range(y_pred.shape[1]):\n", + " y = y_pred[:, i]\n", + " y_t = y_true_one_hot[:, i]\n", + " y[y_t==0] = np.random.beta(a0[i], b0[i], size=y[y_t==0].shape)\n", + " y[y_t==1] = np.random.beta(a1[i], b1[i], size=y[y_t==1].shape)\n", + "y_pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Histogram of probabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
def plot_y_prob_histogram(
    y_true: np.ndarray,
    y_prob: Optional[np.ndarray] = None,
    save_fig_path=None,
) -> "Figure":
    """Plot one predicted-probability histogram per class.

    A grid of subplots is created with one subplot per column of ``y_true``.
    In each subplot the values of the matching ``y_prob`` column are drawn as
    two overlaid histograms: one for the rows where ``y_true`` equals 0 and
    one for the rows where it equals 1.

    Parameters
    ----------
    y_true : np.ndarray
        2-D array with one column per class. Rows are split into the two
        histogram groups by comparing each column against 0 and 1.
    y_prob : np.ndarray, optional
        Array with the same shape as ``y_true`` holding the values to
        histogram. If ``None``, the titled and labelled axes are drawn
        without any histograms.
    save_fig_path : str or path-like, optional
        If given, the figure is saved to this path; missing parent
        directories are created first.

    Returns
    -------
    Figure
        The matplotlib figure containing all subplots.

    Raises
    ------
    ValueError
        If ``y_true`` is not 2-D, has no columns, or if ``y_prob`` does not
        have the same shape as ``y_true``.
    """
    # Validate before creating a figure so bad input leaves no stray figure.
    y_true = np.asarray(y_true)
    if y_true.ndim != 2 or y_true.shape[1] == 0:
        raise ValueError(
            "y_true must be a 2-D array with at least one class column, "
            f"got shape {y_true.shape}."
        )
    if y_prob is not None:
        y_prob = np.asarray(y_prob)
        if y_prob.shape != y_true.shape:
            raise ValueError(
                "y_prob must have the same shape as y_true, "
                f"got {y_prob.shape} and {y_true.shape}."
            )

    n_classes = y_true.shape[1]
    # Near-square grid, but only as many rows as are actually needed so that
    # no row consists solely of switched-off axes.
    n_cols = int(np.ceil(np.sqrt(n_classes)))
    n_rows = (n_classes + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of axes, so the single-class
    # case (1x1 grid) can be iterated like any other.
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(n_cols * 4 + 1, n_rows * 4),
        sharey=True,
        squeeze=False,
    )
    fig.suptitle("Predicted probability histogram")

    hist_style = dict(
        bins=10,
        alpha=0.6,
        edgecolor="midnightblue",
        linewidth=2,
        rwidth=1,
    )
    # NOTE(review): the groups are selected by the *true* label although the
    # legend text uses y-hat; labels are kept as they were -- confirm the
    # intended notation.
    groups = ((0, "$\\hat{y} = 0$"), (1, "$\\hat{y} = 1$"))

    for i, ax in enumerate(axes.flat):
        # Grid cells beyond the last class stay blank.
        if i >= n_classes:
            ax.axis("off")
            continue

        if y_prob is not None:
            y_true_i = y_true[:, i]
            y_prob_i = y_prob[:, i]
            for value, label in groups:
                ax.hist(y_prob_i[y_true_i == value], label=label, **hist_style)

        ax.set_title(f"Class {i}")
        ax.set_xlim((-0.005, 1.0))
        # y-label only on the first column (the y-axis is shared).
        if i % n_cols == 0:
            ax.set_ylabel("Count [-]")
        # x-label on the bottom-most visible subplot of every column, i.e.
        # whenever there is no class subplot directly below this one.
        if i + n_cols >= n_classes:
            ax.set_xlabel("Predicted probability [-]")
        ax.grid(True, linestyle="-", linewidth=0.5, color="grey", alpha=0.5)
        ax.set_xticks(np.arange(0, 1.1, 0.2))
        # Only the first subplot carries the legend, and only when something
        # was drawn (avoids matplotlib's "no artists with labels" warning).
        if i == 0 and y_prob is not None:
            ax.legend()

    fig.tight_layout()

    if save_fig_path is not None:
        path = Path(save_fig_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, bbox_inches="tight")
    return fig