diff --git a/README.md b/README.md index c9cfc77..29b4c64 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,9 @@ As of *v0.2*, plots are not yet tested in the travis build. ## Release History +* 0.2.7 + * ADD: Function ``plot_pie`` to plot counter as a pie chart. + * ADD: Function ``plot_confusion_matrix`` to calculate and plot a confusion matrix. * 0.2.6 * CHANGE: Use of ``Optional`` keyword from ``typing`` for optional arguments. * ADD: Function ``plot_pca_explained_variance_ratio`` to plot the explained variance of PCA. diff --git a/bff/plot/__init__.py b/bff/plot/__init__.py index 8aeedfd..2bf7213 100644 --- a/bff/plot/__init__.py +++ b/bff/plot/__init__.py @@ -1,6 +1,7 @@ """Plot module of bff.""" from .plot import ( + plot_confusion_matrix, plot_correlation, plot_counter, plot_history, @@ -14,6 +15,7 @@ # Public object of the module. __all__ = [ + 'plot_confusion_matrix', 'plot_correlation', 'plot_counter', 'plot_history', diff --git a/bff/plot/plot.py b/bff/plot/plot.py index d3483b8..8390c82 100644 --- a/bff/plot/plot.py +++ b/bff/plot/plot.py @@ -5,7 +5,7 @@ """ import logging from collections import Counter -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import matplotlib as mpl import matplotlib.lines as mlines import matplotlib.patches as mpatches @@ -65,6 +65,135 @@ def callback(axes): return ax +def plot_confusion_matrix(y_true: Union[np.array, pd.Series, Sequence], + y_pred: Union[np.array, pd.Series, Sequence], + labels_filter: Optional[Union[np.array, Sequence]] = None, + ticklabels: Any = 'auto', + sample_weight: Optional[str] = None, + normalize: Optional[str] = None, + stats: Optional[str] = None, + title: str = 'Confusion matrix', + ax: Optional[plt.axes] = None, + rotation_xticks: Union[float, None] = 90, + rotation_yticks: Optional[float] = None, + figsize: Tuple[int, int] = (13, 10), + dpi: int = 80, + style: str = 'white') -> plt.axes: + """ + Plot the confusion matrix. + + The confusion matrix is computed in the function. + + Parameters + ---------- + y_true : np.array, pd.Series or Sequence + Actual values. + y_pred : np.array, pd.Series or Sequence + Predicted values by the model. + labels_filter : array-like of shape (n_classes,), default None + List of labels to index the matrix. This may be used to reorder or + select a subset of labels. If `None` is given, those that appear at + least once in `y_true` or `y_pred` are used in sorted order. + ticklabels : 'auto', bool, list-like, or int, default 'auto' + If True, plot the column names of the dataframe. If False, don’t plot the column names. + If list-like, plot these alternate labels as the xticklabels. If an integer, + use the column names but plot only every n label. If “auto”, + try to densely plot non-overlapping labels. + sample_weight : array-like of shape (n_samples,), optional + Sample weights. + normalize : str {'true', 'pred', 'all'}, optional + Normalizes confusion matrix over the true (rows), predicted (columns) + conditions or all the population. If None, confusion matrix will not be + normalized. + stats : str {'accuracy', 'precision', 'recall', 'f1-score'}, optional + Calculate and display the wanted statistic below the figure. + title : str, default 'Confusion matrix' + Title for the plot (axis level). + ax : plt.axes, optional + Axes from matplotlib, if None, new figure and axes will be created. + rotation_xticks : float or None, default 90 + Rotation of x ticks if any. + rotation_yticks : float, optional + Rotation of x ticks if any. + Set to 90 to put them vertically. + figsize : Tuple[int, int], default (13, 10) + Size of the figure to plot. + dpi : int, default 80 + Resolution of the figure. + style : str, default 'white' + Style to use for seaborn.axes_style. + The style is use only in this context and not applied globally. + + Returns + ------- + plt.axes + Axes returned by the `plt.subplots` function. + + Examples + -------- + >>> y_true = ['dog', 'cat', 'bird', 'cat', 'dog', 'dog'] + >>> y_pred = ['cat', 'cat', 'bird', 'dog', 'bird', 'dog'] + >>> plot_confusion_matrix(y_true, y_pred, stats='accuracy') + """ + bff.fancy._check_sklearn_support('plot_confusion_matrix') + from sklearn.metrics import classification_report, confusion_matrix + + # Compute the confusion matrix. + cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight, + labels=labels_filter, normalize=normalize) + + with sns.axes_style(style): + if ax is None: + __, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi) + + if ticklabels in (True, 'auto') and labels_filter is None: + ticklabels = sorted(set(list(y_true) + list(y_pred))) + + # Draw the heatmap with the mask and correct aspect ratio. + sns.heatmap(cm, cmap=plt.cm.Blues, ax=ax, annot=True, square=True, + linewidths=0.5, cbar_kws={"shrink": 0.75}, + xticklabels=ticklabels, yticklabels=ticklabels) + + if stats: + report = classification_report(y_true, y_pred, labels=labels_filter, + sample_weight=sample_weight, output_dict=True) + if stats == 'accuracy': + ax.text(1.05, 0.05, f'{report[stats]:.2f}', horizontalalignment='left', + verticalalignment='center', transform=ax.transAxes) + else: + if ticklabels in (None, 'auto'): + ticklabels = report.keys()[:-3] + + # Depending on the metric, there is one value by class. + # For each class, print the value of the metric. + for i, label in enumerate(ticklabels): + # Label might be an integer, cast to be sure. + label = str(label) + + if stats in report[label].keys(): + ax.text(1.05, 0.05 - (0.03 * i), + f'{label}: {report[label][stats]:.2f}', + horizontalalignment='left', + verticalalignment='center', transform=ax.transAxes) + else: + LOGGER.error(f'Wrong key {stats}, possible values: ' + f'{list(report[label].keys())}.') + # Print the metric used. + if stats in report.keys() or stats in report[str(ticklabels[0])].keys(): + ax.text(1.05, 0.08, f'{stats.capitalize()}', fontweight='bold', + horizontalalignment='left', + verticalalignment='center', transform=ax.transAxes) + + ax.set_xlabel('Predicted label', fontsize=12) + ax.set_ylabel('True label', fontsize=12) + ax.set_title(title, fontsize=14) + # Style of the ticks. + plt.xticks(fontsize=12, alpha=1, rotation=rotation_xticks) + plt.yticks(fontsize=12, alpha=1, rotation=rotation_yticks) + + return ax + + def plot_correlation(df: pd.DataFrame, already_computed: bool = False, method: str = 'pearson', diff --git a/tests/baseline/test_plot_confusion_matrix.png b/tests/baseline/test_plot_confusion_matrix.png new file mode 100644 index 0000000..1832e91 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_labels_filter.png b/tests/baseline/test_plot_confusion_matrix_labels_filter.png new file mode 100644 index 0000000..43f87e2 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_labels_filter.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_normalize.png b/tests/baseline/test_plot_confusion_matrix_normalize.png new file mode 100644 index 0000000..c09387a Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_normalize.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_sample_weight.png b/tests/baseline/test_plot_confusion_matrix_sample_weight.png new file mode 100644 index 0000000..c5ce9bc Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_sample_weight.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_stats_acc.png b/tests/baseline/test_plot_confusion_matrix_stats_acc.png new file mode 100644 index 0000000..1a6749e Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_stats_acc.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_stats_error.png b/tests/baseline/test_plot_confusion_matrix_stats_error.png new file mode 100644 index 0000000..e341165 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_stats_error.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_stats_fscore.png b/tests/baseline/test_plot_confusion_matrix_stats_fscore.png new file mode 100644 index 0000000..433047e Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_stats_fscore.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_stats_prec.png b/tests/baseline/test_plot_confusion_matrix_stats_prec.png new file mode 100644 index 0000000..f809a97 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_stats_prec.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_ticklabels_cat.png b/tests/baseline/test_plot_confusion_matrix_ticklabels_cat.png new file mode 100644 index 0000000..e341165 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_ticklabels_cat.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_ticklabels_false.png b/tests/baseline/test_plot_confusion_matrix_ticklabels_false.png new file mode 100644 index 0000000..3a19ca5 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_ticklabels_false.png differ diff --git a/tests/baseline/test_plot_confusion_matrix_ticklabels_n_labels.png b/tests/baseline/test_plot_confusion_matrix_ticklabels_n_labels.png new file mode 100644 index 0000000..4439306 Binary files /dev/null and b/tests/baseline/test_plot_confusion_matrix_ticklabels_n_labels.png differ diff --git a/tests/test_plot.py b/tests/test_plot.py index d0083a1..4b8973f 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -7,6 +7,7 @@ """ from collections import Counter import unittest +import unittest.mock import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -52,12 +53,17 @@ class TestPlot(unittest.TestCase): 9.65912261, 2.54053964, 7.31815866, 5.91692937, 2.78676838, 7.92586481, 2.31337877, 1.78432016, 9.55596989, 6.64471696, 3.33907423, 7.49321025, 7.14822795, 4.11686499, 2.40202043] - y_pred = [1.85161709, 1.33317135, 9.45246137, 7.91986758, 7.54877922, 9.71532022, 3.56777447, 7.88673475, 5.56090322, 2.78851836, 6.70636033, 2.67531555, 1.13061356, 8.29287223, 6.27275223, 2.49572863, 7.14305019, 8.53578604, 3.99890533, 2.35510298] + y_true_matrix = [1, 2, 3, 1, 2, 3, 1, 2, 3] + y_pred_matrix = [1, 2, 3, 2, 3, 1, 1, 2, 3] + + y_true_matrix_cat = ['dog', 'cat', 'bird', 'dog', 'cat', 'bird', 'dog', 'cat', 'bird'] + y_pred_matrix_cat = ['dog', 'cat', 'bird', 'cat', 'bird', 'dog', 'dog', 'cat', 'bird'] + # Timeseries for testing. AXIS = {'x': 'darkorange', 'y': 'green', 'z': 'steelblue'} @@ -80,6 +86,129 @@ class TestPlot(unittest.TestCase): dict_to_plot = {'Red': 15, 'Green': 50, 'Blue': 24} dict_to_plot_numerical = {1: 1_798, 2: 12_000, 3: 2_933} + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix(self): + """ + Test of the `plot_confusion_matrix` function. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix) + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_labels_filter(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `labels_filter` option. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + labels_filter=[3, 1]) + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_normalize(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `normalize` option. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + normalize='all') + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_sample_weight(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `sample_weigth` option. + """ + weights = range(1, len(self.y_true_matrix) + 1) + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + sample_weight=weights) + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_stats_acc(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `stats` option. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + stats='accuracy') + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_stats_error(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `stats` option when the key of the stat is wrong. + """ + # Check the error message using a mock. + with unittest.mock.patch('logging.Logger.error') as mock_logging: + ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat, + stats='acc') + mock_logging.assert_called_with("Wrong key acc, possible values: " + "['precision', 'recall', 'f1-score', 'support'].") + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_stats_prec(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `stats` option with precision for all classes. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + stats='precision') + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_stats_fscore(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `stats` option with categorical values. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat, + stats='f1-score') + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_ticklabels_cat(self): + """ + Test of the `plot_confusion_matrix` function. + + Use categorical predictions. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat) + + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_ticklabels_false(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `ticklabels` option set to False. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + ticklabels=False) + return ax.figure + + @pytest.mark.mpl_image_compare + def test_plot_confusion_matrix_ticklabels_n_labels(self): + """ + Test of the `plot_confusion_matrix` function. + + Check with the `ticklabels` option and print every 2 labels. + """ + ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix, + ticklabels=2) + return ax.figure + @pytest.mark.mpl_image_compare def test_plot_correlation(self): """ @@ -208,7 +337,7 @@ def test_plot_pie_counter(self): """ Test of the `plot_pie` function. """ - data = {k: v for k, v in self.counter_pie.most_common(4)} + data = dict(self.counter_pie.most_common(4)) ax = bplt.plot_pie(data, explode=0.01, title='', startangle=10) return ax.figure