feat(plot-confusion-matrix): add function to plot a confusion matrix
The confusion matrix is computed inside the plot function.
axelfahy committed Apr 8, 2020
1 parent e91e4ca commit f8f262c
Showing 15 changed files with 266 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -50,6 +50,9 @@ As of *v0.2*, plots are not yet tested in the travis build.

## Release History

* 0.2.7
  * ADD: Function ``plot_pie`` to plot a counter as a pie chart.
  * ADD: Function ``plot_confusion_matrix`` to calculate and plot a confusion matrix.
* 0.2.6
  * CHANGE: Use of ``Optional`` keyword from ``typing`` for optional arguments.
  * ADD: Function ``plot_pca_explained_variance_ratio`` to plot the explained variance of PCA.
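A minimal usage sketch of the two 0.2.7 additions (illustrative, not part of the diff; assumes ``bff`` is installed, data taken from the tests and docstring below):

import bff.plot as bplt

# Counter-style data as a pie chart.
bplt.plot_pie({'Red': 15, 'Green': 50, 'Blue': 24})

# Confusion matrix computed and plotted in one call.
bplt.plot_confusion_matrix(['dog', 'cat', 'bird', 'cat', 'dog', 'dog'],
                           ['cat', 'cat', 'bird', 'dog', 'bird', 'dog'])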
2 changes: 2 additions & 0 deletions bff/plot/__init__.py
@@ -1,6 +1,7 @@
"""Plot module of bff."""

from .plot import (
    plot_confusion_matrix,
    plot_correlation,
    plot_counter,
    plot_history,
@@ -14,6 +15,7 @@

# Public object of the module.
__all__ = [
    'plot_confusion_matrix',
    'plot_correlation',
    'plot_counter',
    'plot_history',
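With the export above, the function can also be imported directly (sketch):

from bff.plot import plot_confusion_matrix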
131 changes: 130 additions & 1 deletion bff/plot/plot.py
@@ -5,7 +5,7 @@
"""
import logging
from collections import Counter
-from typing import Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
@@ -65,6 +65,135 @@ def callback(axes):
    return ax


def plot_confusion_matrix(y_true: Union[np.array, pd.Series, Sequence],
                          y_pred: Union[np.array, pd.Series, Sequence],
                          labels_filter: Optional[Union[np.array, Sequence]] = None,
                          ticklabels: Any = 'auto',
                          sample_weight: Optional[Union[np.array, Sequence]] = None,
                          normalize: Optional[str] = None,
                          stats: Optional[str] = None,
                          title: str = 'Confusion matrix',
                          ax: Optional[plt.axes] = None,
                          rotation_xticks: Optional[float] = 90,
                          rotation_yticks: Optional[float] = None,
                          figsize: Tuple[int, int] = (13, 10),
                          dpi: int = 80,
                          style: str = 'white') -> plt.axes:
"""
Plot the confusion matrix.
The confusion matrix is computed in the function.
Parameters
----------
y_true : np.array, pd.Series or Sequence
Actual values.
y_pred : np.array, pd.Series or Sequence
Predicted values by the model.
labels_filter : array-like of shape (n_classes,), default None
List of labels to index the matrix. This may be used to reorder or
select a subset of labels. If `None` is given, those that appear at
least once in `y_true` or `y_pred` are used in sorted order.
ticklabels : 'auto', bool, list-like, or int, default 'auto'
If True, plot the column names of the dataframe. If False, don’t plot the column names.
If list-like, plot these alternate labels as the xticklabels. If an integer,
use the column names but plot only every n label. If “auto”,
try to densely plot non-overlapping labels.
sample_weight : array-like of shape (n_samples,), optional
Sample weights.
normalize : str {'true', 'pred', 'all'}, optional
Normalizes confusion matrix over the true (rows), predicted (columns)
conditions or all the population. If None, confusion matrix will not be
normalized.
stats : str {'accuracy', 'precision', 'recall', 'f1-score'}, optional
Calculate and display the wanted statistic below the figure.
title : str, default 'Confusion matrix'
Title for the plot (axis level).
ax : plt.axes, optional
Axes from matplotlib, if None, new figure and axes will be created.
rotation_xticks : float or None, default 90
Rotation of x ticks if any.
rotation_yticks : float, optional
Rotation of x ticks if any.
Set to 90 to put them vertically.
figsize : Tuple[int, int], default (13, 10)
Size of the figure to plot.
dpi : int, default 80
Resolution of the figure.
style : str, default 'white'
Style to use for seaborn.axes_style.
The style is use only in this context and not applied globally.
Returns
-------
plt.axes
Axes returned by the `plt.subplots` function.
Examples
--------
>>> y_true = ['dog', 'cat', 'bird', 'cat', 'dog', 'dog']
>>> y_pred = ['cat', 'cat', 'bird', 'dog', 'bird', 'dog']
>>> plot_confusion_matrix(y_true, y_pred, stats='accuracy')
"""
    bff.fancy._check_sklearn_support('plot_confusion_matrix')
    from sklearn.metrics import classification_report, confusion_matrix

    # Compute the confusion matrix.
    cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight,
                          labels=labels_filter, normalize=normalize)
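    # Note on `normalize` (a sketch of the sklearn semantics documented above):
    #   'true' -> each row is divided by the number of true samples of the
    #             class, so rows sum to 1.
    #   'pred' -> each column is divided by the number of predicted samples
    #             of the class, so columns sum to 1.
    #   'all'  -> every cell is divided by the total number of samples.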

    with sns.axes_style(style):
        if ax is None:
            __, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)

        if ticklabels in (True, 'auto') and labels_filter is None:
            ticklabels = sorted(set(list(y_true) + list(y_pred)))

        # Draw the heatmap with the mask and correct aspect ratio.
        sns.heatmap(cm, cmap=plt.cm.Blues, ax=ax, annot=True, square=True,
                    linewidths=0.5, cbar_kws={"shrink": 0.75},
                    xticklabels=ticklabels, yticklabels=ticklabels)

        if stats:
            report = classification_report(y_true, y_pred, labels=labels_filter,
                                           sample_weight=sample_weight, output_dict=True)
            if stats == 'accuracy':
                ax.text(1.05, 0.05, f'{report[stats]:.2f}', horizontalalignment='left',
                        verticalalignment='center', transform=ax.transAxes)
            else:
                if ticklabels in (None, 'auto'):
                    # The last three keys of the report ('accuracy', 'macro avg',
                    # 'weighted avg') are not class labels, so drop them.
                    ticklabels = list(report.keys())[:-3]

                # Depending on the metric, there is one value by class.
                # For each class, print the value of the metric.
                for i, label in enumerate(ticklabels):
                    # Label might be an integer, cast to be sure.
                    label = str(label)

                    if stats in report[label].keys():
                        ax.text(1.05, 0.05 - (0.03 * i),
                                f'{label}: {report[label][stats]:.2f}',
                                horizontalalignment='left',
                                verticalalignment='center', transform=ax.transAxes)
                    else:
                        LOGGER.error(f'Wrong key {stats}, possible values: '
                                     f'{list(report[label].keys())}.')
            # Print the metric used.
            if stats in report.keys() or stats in report[str(ticklabels[0])].keys():
                ax.text(1.05, 0.08, f'{stats.capitalize()}', fontweight='bold',
                        horizontalalignment='left',
                        verticalalignment='center', transform=ax.transAxes)

        ax.set_xlabel('Predicted label', fontsize=12)
        ax.set_ylabel('True label', fontsize=12)
        ax.set_title(title, fontsize=14)
        # Style of the ticks.
        plt.xticks(fontsize=12, alpha=1, rotation=rotation_xticks)
        plt.yticks(fontsize=12, alpha=1, rotation=rotation_yticks)

    return ax
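
# For reference, a sketch of the ``classification_report(..., output_dict=True)``
# mapping that the ``stats`` branch above iterates over (structure per
# scikit-learn; shown for the docstring example, where accuracy is 3/6 = 0.5):
#
#     {'bird': {'precision': ..., 'recall': ..., 'f1-score': ..., 'support': ...},
#      'cat': {...},
#      'dog': {...},
#      'accuracy': 0.5,
#      'macro avg': {...},
#      'weighted avg': {...}}
#
# The last three keys are aggregates rather than class labels, hence the
# ``[:-3]`` slice when deriving ``ticklabels`` from the report.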


def plot_correlation(df: pd.DataFrame,
already_computed: bool = False,
method: str = 'pearson',
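A few usage sketches for the options above (illustrative; assumes ``bff.plot`` is imported as ``bplt``, as in the test suite):

import bff.plot as bplt

y_true = ['dog', 'cat', 'bird', 'cat', 'dog', 'dog']
y_pred = ['cat', 'cat', 'bird', 'dog', 'bird', 'dog']

# Raw counts, with the overall accuracy printed next to the figure.
bplt.plot_confusion_matrix(y_true, y_pred, stats='accuracy')

# Row-normalized matrix (each row sums to 1), with per-class recall.
bplt.plot_confusion_matrix(y_true, y_pred, normalize='true', stats='recall')

# Keep only every second tick label.
bplt.plot_confusion_matrix(y_true, y_pred, ticklabels=2)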
Binary file added tests/baseline/test_plot_confusion_matrix.png
(The ten remaining changed files are binary baseline images added under tests/baseline/ for the new image-comparison tests; previews are not available.)
133 changes: 131 additions & 2 deletions tests/test_plot.py
Expand Up @@ -7,6 +7,7 @@
"""
from collections import Counter
import unittest
import unittest.mock
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
@@ -52,12 +53,17 @@ class TestPlot(unittest.TestCase):
              9.65912261, 2.54053964, 7.31815866, 5.91692937, 2.78676838,
              7.92586481, 2.31337877, 1.78432016, 9.55596989, 6.64471696,
              3.33907423, 7.49321025, 7.14822795, 4.11686499, 2.40202043]

    y_pred = [1.85161709, 1.33317135, 9.45246137, 7.91986758, 7.54877922,
              9.71532022, 3.56777447, 7.88673475, 5.56090322, 2.78851836,
              6.70636033, 2.67531555, 1.13061356, 8.29287223, 6.27275223,
              2.49572863, 7.14305019, 8.53578604, 3.99890533, 2.35510298]

    y_true_matrix = [1, 2, 3, 1, 2, 3, 1, 2, 3]
    y_pred_matrix = [1, 2, 3, 2, 3, 1, 1, 2, 3]

    y_true_matrix_cat = ['dog', 'cat', 'bird', 'dog', 'cat', 'bird', 'dog', 'cat', 'bird']
    y_pred_matrix_cat = ['dog', 'cat', 'bird', 'cat', 'bird', 'dog', 'dog', 'cat', 'bird']

    # Timeseries for testing.
    AXIS = {'x': 'darkorange', 'y': 'green', 'z': 'steelblue'}

@@ -80,6 +86,129 @@ class TestPlot(unittest.TestCase):
    dict_to_plot = {'Red': 15, 'Green': 50, 'Blue': 24}
    dict_to_plot_numerical = {1: 1_798, 2: 12_000, 3: 2_933}

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix(self):
        """
        Test of the `plot_confusion_matrix` function.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix)
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_labels_filter(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `labels_filter` option.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        labels_filter=[3, 1])
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_normalize(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `normalize` option.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        normalize='all')
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_sample_weight(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `sample_weight` option.
        """
        weights = range(1, len(self.y_true_matrix) + 1)
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        sample_weight=weights)
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_stats_acc(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `stats` option.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        stats='accuracy')
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_stats_error(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `stats` option when the key of the stat is wrong.
        """
        # Check the error message using a mock.
        with unittest.mock.patch('logging.Logger.error') as mock_logging:
            ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat,
                                            stats='acc')
            mock_logging.assert_called_with("Wrong key acc, possible values: "
                                            "['precision', 'recall', 'f1-score', 'support'].")
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_stats_prec(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `stats` option with precision for all classes.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        stats='precision')
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_stats_fscore(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `stats` option with categorical values.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat,
                                        stats='f1-score')
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_ticklabels_cat(self):
        """
        Test of the `plot_confusion_matrix` function.
        Use categorical predictions.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix_cat, self.y_pred_matrix_cat)
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_ticklabels_false(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `ticklabels` option set to False.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        ticklabels=False)
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_confusion_matrix_ticklabels_n_labels(self):
        """
        Test of the `plot_confusion_matrix` function.
        Check with the `ticklabels` option and print every 2 labels.
        """
        ax = bplt.plot_confusion_matrix(self.y_true_matrix, self.y_pred_matrix,
                                        ticklabels=2)
        return ax.figure

    @pytest.mark.mpl_image_compare
    def test_plot_correlation(self):
        """
@@ -208,7 +337,7 @@ def test_plot_pie_counter(self):
"""
Test of the `plot_pie` function.
"""
data = {k: v for k, v in self.counter_pie.most_common(4)}
data = dict(self.counter_pie.most_common(4))
ax = bplt.plot_pie(data, explode=0.01, title='', startangle=10)
return ax.figure

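The image-based tests use the `mpl_image_compare` marker from pytest-mpl, which compares each returned figure against the baseline PNGs added in this commit. A sketch of the typical workflow, assuming pytest-mpl is installed (the flags below are pytest-mpl's, not part of this commit):

# Regenerate the baseline images.
pytest --mpl-generate-path=tests/baseline tests/test_plot.py

# Compare the plots against the baselines.
pytest --mpl tests/test_plot.py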
