Skip to content

Commit

Permalink
Merge pull request #108 from jeiros/add-trace2d-plot
Browse files Browse the repository at this point in the history
[WIP] Add trace2d plot
  • Loading branch information
cxhernandez authored Sep 29, 2017
2 parents d947387 + daa9c18 commit d62694c
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ New Features
- ``plot_free_energy`` now accepts a ``return_data`` flag that will return
the data used for the free energy plot(#78).

- Added a new function ``plot_trace2d`` that plots the time evolution of a 2D numpy array
using a colorbar to map the time (#108).

Improvements
~~~~~~~~~~~~

Expand Down
50 changes: 50 additions & 0 deletions examples/plot_trace2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Two dimensional trace plot
===============
"""
from msmbuilder.example_datasets import FsPeptide
from msmbuilder.featurizer import DihedralFeaturizer
from msmbuilder.decomposition import tICA
from msmbuilder.cluster import MiniBatchKMeans
from msmbuilder.msm import MarkovStateModel
from matplotlib import pyplot as pp
import numpy as np

import msmexplorer as msme

rs = np.random.RandomState(42)

# Load Fs Peptide Data
trajs = FsPeptide().get().trajectories

# Extract Backbone Dihedrals
featurizer = DihedralFeaturizer(types=['phi', 'psi'])
diheds = featurizer.fit_transform(trajs)

# Perform Dimensionality Reduction
tica_model = tICA(lag_time=2, n_components=2)
tica_trajs = tica_model.fit_transform(diheds)

# Plot free 2D free energy (optional)
txx = np.concatenate(tica_trajs, axis=0)
ax = msme.plot_free_energy(
txx, obs=(0, 1), n_samples=100000,
random_state=rs,
shade=True,
clabel=True,
clabel_kwargs={'fmt': '%.1f'},
cbar=True,
cbar_kwargs={'format': '%.1f', 'label': 'Free energy (kcal/mol)'}
)
# Now plot the first trajectory on top of it to inspect it's movement
msme.plot_trace2d(
data=tica_trajs[0], ts=0.2, ax=ax,
scatter_kwargs={'s': 2},
cbar_kwargs={'format': '%d', 'label': 'Time (ns)',
'orientation': 'horizontal'},
xlabel='tIC 1', ylabel='tIC 2'
)
# Finally, let's plot every trajectory to see the individual sampled regions
f, ax = pp.subplots()
msme.plot_trace2d(tica_trajs, ax=ax, xlabel='tIC 1', ylabel='tIC 2')
pp.show()
79 changes: 78 additions & 1 deletion msmexplorer/plots/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..utils import msme_colors
from .. import palettes

__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace']
__all__ = ['plot_chord', 'plot_stackdist', 'plot_trace', 'plot_trace2d']


def plot_chord(data, ax=None, cmap=None, labels=None, labelsize=12, norm=True,
Expand Down Expand Up @@ -303,3 +303,80 @@ def plot_trace(data, label=None, window=1, ax=None, side_ax=None,
side_ax.set_title('')

return ax, side_ax


@msme_colors
def plot_trace2d(data, obs=(0, 1), ts=1.0, cbar=True, ax=None, xlabel=None,
ylabel=None, labelsize=14,
cbar_kwargs=None, scatter_kwargs=None, plot_kwargs=None):
"""
Plot a 2D trace of time-series data.
Parameters
----------
data : array-like (nsamples, 2) or list thereof
The samples. This should be a single 2-D time-series array or a list of 2-D
time-series arrays.
If it is a single 2D np.array, the elements will be scatter plotted and
color mapped to their values.
If it is a list of 2D np.arrays, each will be plotted with a single color on
the same axis.
obs: tuple, optional (default: (0,1))
Observables to plot.
ts: float, optional (default: 1.0)
Step in units of time between each data point in data
cbar: bool, optional (default: True)
Adds a colorbar that maps the evolution of points in data
ax : matplotlib axis, optional
main matplotlib figure axis for trace.
xlabel : str, optional
x-axis label
ylabel : str, optional
y-axis label
labelsize : int, optional (default: 14)
Font side for axes labels.
cbar_kwargs: dict, optional
Arguments to pass to matplotlib cbar
scatter_kwargs: dict, optional
Arguments to pass to matplotlib scatter
plot_kwargs: dict, optional
Arguments to pass to matplotlib plot
Returns
-------
ax : matplotlib axis
main matplotlib figure axis for 2D trace.
"""

if ax is None:
ax = pp.gca()
if scatter_kwargs is None:
scatter_kwargs = {}
if plot_kwargs is None:
plot_kwargs = {}

if not isinstance(obs, tuple):
raise ValueError('obs must be a tuple')

if isinstance(data, list):
# Plot each item in the list with a single color and join with lines
for item in data:
prune = item[:, obs]
ax.plot(prune[:, 0], prune[:, 1], **plot_kwargs)
else:
# A single array of data is passed, so we scatter plot
prune = data[:, obs]
c = ax.scatter(prune[:, 0], prune[:, 1],
c=np.linspace(0, data.shape[0] * ts, data.shape[0]),
**scatter_kwargs)
if cbar:
# Map the time evolution between the data points to a colorbar
if cbar_kwargs is None:
cbar_kwargs = {}
pp.colorbar(c, **cbar_kwargs)

if xlabel:
ax.set_xlabel(xlabel, size=labelsize)
if ylabel:
ax.set_ylabel(ylabel, size=labelsize)

return ax
10 changes: 9 additions & 1 deletion msmexplorer/tests/test_misc_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from matplotlib.axes import SubplotBase
from seaborn.apionly import FacetGrid

from ..plots import plot_chord, plot_stackdist, plot_trace
from ..plots import plot_chord, plot_stackdist, plot_trace, plot_trace2d
from . import PlotTestCase

rs = np.random.RandomState(42)
data = rs.rand(12, 12)
ts = rs.rand(100000, 1)
ts2 = rs.rand(100000, 2)


class TestChordPlot(PlotTestCase):
Expand Down Expand Up @@ -38,3 +39,10 @@ def test_plot_trace(self):

assert isinstance(ax, SubplotBase)
assert isinstance(side_ax, SubplotBase)

def test_plot_trace2d(self):
ax1 = plot_trace2d(ts2)
ax2 = plot_trace2d([ts2, ts2])

assert isinstance(ax1, SubplotBase)
assert isinstance(ax2, SubplotBase)

0 comments on commit d62694c

Please sign in to comment.