Skip to content

Commit

Permalink
feat(plot-tsne): add function to plot t-SNE
Browse files Browse the repository at this point in the history
  • Loading branch information
axelfahy committed Apr 8, 2020
1 parent 171d093 commit 251d432
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ As of *v0.2*, plots are not yet tested in the travis build.
* 0.2.7
* ADD: Function ``plot_pie`` to plot counter as a pie chart.
* ADD: Function ``plot_confusion_matrix`` to calculate and plot a confusion matrix.
* ADD: Function ``plot_tsne`` to plot t-SNE results.
* 0.2.6
* CHANGE: Use of ``Optional`` keyword from ``typing`` for optional arguments.
* ADD: Function ``plot_pca_explained_variance_ratio`` to plot the explained variance of PCA.
Expand Down
2 changes: 2 additions & 0 deletions bff/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
plot_predictions,
plot_series,
plot_true_vs_pred,
plot_tsne,
set_thousands_separator,
)

Expand All @@ -24,5 +25,6 @@
'plot_predictions',
'plot_series',
'plot_true_vs_pred',
'plot_tsne',
'set_thousands_separator',
]
189 changes: 184 additions & 5 deletions bff/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import Counter
from typing import Any, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -161,9 +162,6 @@ def plot_confusion_matrix(y_true: Union[np.array, pd.Series, Sequence],
ax.text(1.05, 0.05, f'{report[stats]:.2f}', horizontalalignment='left',
verticalalignment='center', transform=ax.transAxes)
else:
if ticklabels in (None, 'auto'):
ticklabels = report.keys()[:-3]

# Depending on the metric, there is one value by class.
# For each class, print the value of the metric.
for i, label in enumerate(ticklabels):
Expand Down Expand Up @@ -294,7 +292,7 @@ def plot_counter(counter: Union[Counter, dict],
style: str = 'default',
**kwargs) -> plt.axes:
"""
Plot the values of a counter as an bar plot.
Plot the values of a counter as a bar plot.
Values above the ratio are written as text on top of the bar.
Expand Down Expand Up @@ -717,7 +715,8 @@ def format_label(percent, values):
return f'{percent:.1f}%\n({absolute:,})'

if colors is None:
colors = list(plt.cm.rainbow(np.linspace(0, 1, len(data))))
cmap = cm.get_cmap('rainbow')
colors = list(cmap(np.linspace(0, 1, len(data))))
else:
assert len(colors) == len(data), (
'The number of colors does not match the number of labels.')
Expand Down Expand Up @@ -1264,6 +1263,186 @@ def get_limit(limit, data, percent=5):
return ax_main if not with_histograms else (ax_main, ax_right, ax_bottom)


def plot_tsne(df: pd.DataFrame,
tsne_col_1: str = 'tsne_1',
tsne_col_2: str = 'tsne_2',
label_col: Optional[str] = None,
colors: Optional[Sequence[str]] = None,
labels: Optional[Union[str, Sequence[str]]] = None,
label_x: str = 'Dimension 1',
label_y: str = 'Dimension 2',
title: str = 't-SNE',
ax: Optional[plt.axes] = None,
loc: Union[str, int] = 'best',
s: Optional[Union[TNum, Sequence]] = mpl.rcParams['lines.markersize'] * 2,
figsize: Tuple[int, int] = (14, 7),
dpi: int = 80,
style: str = 'default',
**kwargs) -> plt.axes:
"""
Plot t-SNE clustering.
T-SNE must be already computed and stored inside two separate column of the DataFrame.
See the example for more information.
If there are some labels (and the `label_col` is given), there will be one color by label,
if there are no label, you can specify a list of colors to be applied for each point of data.
If no label and no color are provided, a default color will be used for all points.
Parameters
----------
df : pd.DataFrame
DataFrame with the data and the classes to plot.
tsne_col_1 : str, default 'tsne_1'
First column of the DataFrame containing the tsne values.
tsne_col_2 : str, default 'tsne_2'
Second column of the DataFrame containing the tsne values.
label_col : str, optional
Column of the DataFrame containing the labels of the data.
If given, there will be one color by label. Colors can be
provided with the `colors` argument.
colors : sequence of str, optional
Colors for each classes to plot or for each point of data if there is no label.
labels : str or sequence of str, optional
Labels of the plotted classes, must be in the same order as the real labels
and the same length. If no class, can be a single value.
label_x : str, default 'Dimension 1'
Label for x axis.
label_y : str, default 'Dimension 2'
Label for y axis.
title : str, default 't-SNE'
Title for the plot (axis level).
ax : plt.axes, optional
Axes from matplotlib, if None, new figure and axes will be created.
loc : str or int, default 'best'
Location of the legend on the plot.
Either the legend string or legend code are possible.
s : number or sequence
Size of the points on the graph. Default is the matplotlib markersize * 2.
figsize : Tuple[int, int], default (14, 7)
Size of the figure to plot.
dpi : int, default 80
Resolution of the figure.
style : str, default 'default'
Style to use for matplotlib.pyplot.
The style is use only in this context and not applied globally.
**kwargs
Additional keyword arguments to be passed to the
`plt.scatter` function from matplotlib.
Returns
-------
plt.axes
Axes returned by the `plt.subplots` function.
Examples
--------
>>> from sklearn import datasets, manifold
>>> X, y = datasets.make_circles(n_samples=300, factor=.5, noise=.05)
>>> df = pd.DataFrame(X).assign(label=y)
>>> tsne = manifold.TSNE()
>>> tsne_results = tsne.fit_transform(df.drop('label', axis='columns'))
>>> df_tsne = df[['label']].assign(tsne_1=tsne_results[:, 0], tsne_2=tsne_results[:, 1])
>>> plot_tsne(df_tsne, label_col='label', colors=['r', 'b'])
"""
assert tsne_col_1 in df.columns, (
f'DataFrame does not contain column: {tsne_col_1}')
assert tsne_col_2 in df.columns, (
f'DataFrame does not contain column: {tsne_col_2}')
if label_col is not None:
assert label_col in df.columns, (
f'DataFrame does not contain column: {label_col}')

with plt.style.context(style):
if ax is None:
__, ax = plt.subplots(figsize=figsize, dpi=dpi)

# Flag to determine if the legend must be printed.
print_legend = True

# If there is a label, use one color by label.
# If colors are not provided, or wrong number, create them.
if label_col is not None:
labels_unique = df[label_col].unique()
if colors is None:
cmap = cm.get_cmap('rainbow')
colors = list(cmap(np.linspace(0, 1, len(labels_unique))))
else:
colors = bff.value_2_list(colors)
if len(colors) != len(labels_unique):
LOGGER.warning(f'Number of colors does not match the number of labels '
f'({len(colors)}/{len(labels_unique)}), '
f'using last color for missing ones.')
colors = colors + [colors[-1]] * \
(len(labels_unique) - len(colors)) # type: ignore

if labels is not None:
labels = bff.value_2_list(labels)
if len(labels) < len(labels_unique):
LOGGER.warning(f'Not enough labels ({len(labels)}/{len(labels_unique)}).')
labels = labels + [None] * (len(labels_unique) - len(labels)) # type: ignore
else:
print_legend = False
labels = [None] * len(labels_unique) # type: ignore

for i, l in enumerate(df[label_col].unique()):
data_1 = df.query(f'{label_col} == @l')[tsne_col_1].values
data_2 = df.query(f'{label_col} == @l')[tsne_col_2].values
ax.scatter(
data_1,
data_2,
s=s,
c=np.array([colors[i]]), # type: ignore
label=labels[i], # type: ignore
lw=0.1,
alpha=1,
**kwargs
)
# If there is no label, plot all the points.
# If there are some colors, uses the colors, else plot all with the same color.
else:
print_legend = False
if colors is None or len(colors) != df.shape[0]:
colors = None
data_1 = df[tsne_col_1].values
data_2 = df[tsne_col_2].values
ax.scatter(
data_1,
data_2,
s=s,
c=colors,
label=labels,
lw=0.1,
alpha=1,
**kwargs
)

ax.set_xlabel(label_x, fontsize=12)
ax.set_ylabel(label_y, fontsize=12)
ax.set_title(title, fontsize=14)

# Style.
# Remove border on the top and right.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Set alpha on remaining borders.
ax.spines['left'].set_alpha(0.4)
ax.spines['bottom'].set_alpha(0.4)

# Only show ticks on the left and bottom spines
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
# Style of ticks.
plt.xticks(fontsize=10, alpha=0.7)
plt.yticks(fontsize=10, alpha=0.7)

if print_legend:
ax.legend(loc=loc)

return ax


def set_thousands_separator(axes: plt.axes, which: str = 'both',
nb_decimals: int = 1) -> plt.axes:
"""
Expand Down
Binary file added tests/baseline/test_plot_tsne.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_tsne_warning_colors.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_tsne_warning_labels.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_tsne_with_colors.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/baseline/test_plot_tsne_without_label.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
77 changes: 77 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from collections import Counter
import unittest
import unittest.mock
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import bff.plot as bplt

Expand Down Expand Up @@ -86,6 +89,13 @@ class TestPlot(unittest.TestCase):
dict_to_plot = {'Red': 15, 'Green': 50, 'Blue': 24}
dict_to_plot_numerical = {1: 1_798, 2: 12_000, 3: 2_933}

# Data for tsne.
X, y = datasets.make_circles(n_samples=30, factor=.5, noise=.05, random_state=42)
df = pd.DataFrame(X).assign(label=y)
tsne = TSNE(n_iter=250)
tsne_results = tsne.fit_transform(df.drop('label', axis='columns'))
df_tsne = df[['label']].assign(tsne_1=tsne_results[:, 0], tsne_2=tsne_results[:, 1])

@pytest.mark.mpl_image_compare
def test_plot_confusion_matrix(self):
"""
Expand Down Expand Up @@ -556,6 +566,73 @@ def test_plot_true_vs_pred_with_identity(self):
with_identity=True, marker='.', c='r')
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne(self):
"""
Test of the `plot_tsne` function.
"""
ax = bplt.plot_tsne(self.df_tsne, label_col='label', labels=['Ok', 'Ko'])
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne_with_colors(self):
"""
Test of the `plot_tsne` function.
Test with custom colors.
"""
ax = bplt.plot_tsne(self.df_tsne, label_col='label', colors=['r', 'b'])
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne_without_label(self):
"""
Test of the `plot_tsne` function.
Test without label.
"""
ax = bplt.plot_tsne(self.df_tsne)
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne_without_label_with_color(self):
"""
Test of the `plot_tsne` function.
Test without label but with custom colors.
"""
cmap = cm.get_cmap('rainbow')
colors = list(cmap(np.linspace(0, 1, self.df_tsne.shape[0])))
ax = bplt.plot_tsne(self.df_tsne, colors=colors, label_x='Dim x', label_y='Dim_y')
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne_warning_colors(self):
"""
Test of the `plot_tsne` function.
Test the warning regarding the colors.
"""
# Check the error message using a mock.
with unittest.mock.patch('logging.Logger.warning') as mock_logging:
ax = bplt.plot_tsne(self.df_tsne, label_col='label', colors=['r'])
mock_logging.assert_called_with('Number of colors does not match the number '
'of labels (1/2), using last color for missing ones.')
return ax.figure

@pytest.mark.mpl_image_compare
def test_plot_tsne_warning_labels(self):
"""
Test of the `plot_tsne` function.
Test the warning regarding the labels.
"""
# Check the error message using a mock.
with unittest.mock.patch('logging.Logger.warning') as mock_logging:
ax = bplt.plot_tsne(self.df_tsne, label_col='label', labels='True')
mock_logging.assert_called_with('Not enough labels (1/2).')
return ax.figure

@pytest.mark.mpl_image_compare
def test_set_thousands_separator_both(self):
"""
Expand Down

0 comments on commit 251d432

Please sign in to comment.