Skip to content

Commit

Permalink
Merge pull request #811 from OpenFreeEnergy/fix-pymbar-plotting
Browse files Browse the repository at this point in the history
FIx overlap matrix label issue
  • Loading branch information
IAlibay authored Apr 11, 2024
2 parents 796013f + b2f2c5c commit 590a6c1
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 1 deletion.
39 changes: 38 additions & 1 deletion openfe/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy.typing as npt
from openff.units import unit
from typing import Optional, Union
import warnings


def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes:
Expand All @@ -22,8 +23,29 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes:
-------
ax : matplotlib.axes.Axes
An Axes object to plot.
Raises
------
UserWarning
If any row or column exceeds a sum value of 1.01. This indicates
an incorrect overlap/probability matrix.
Notes
-----
Borrowed from `alchemlyb <https://github.com/alchemistry/alchemlyb/blob/master/src/alchemlyb/visualisation/mbar_matrix.py>`_
which itself borrows from `alchemical-analysis <https://github.com/MobleyLab/alchemical-analysis>`_.
"""
num_states = len(matrix)

# Check if any row or column isn't close to 1.0
# Throw a warning if it's the case
if (not np.allclose(matrix.sum(axis=0), 1.0) or
not np.allclose(matrix.sum(axis=1), 1.0)):
wmsg = ("Overlap/probability matrix exceeds a sum of 1.0 in one or "
"more columns or rows of the matrix. This indicates an "
"incorrect overlap/probability matrix.")
warnings.warn(wmsg)

fig, ax = plt.subplots(figsize=(num_states / 2, num_states / 2))
ax.axis('off')
for i in range(num_states):
Expand All @@ -32,7 +54,18 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes:
ax.axhline(y=i, ls="-", lw=0.5, color="k", alpha=0.25)
for j in range(num_states):
val = matrix[i, j]
val_str = "{:.2f}".format(val)[1:]

# Catch if 0.05 from 0 or 1
# https://github.com/OpenFreeEnergy/openfe/issues/806
if matrix[j, i] < 0.005:
# This replicates the same behaviour as alchemical-analysis & alchemlyb
# i.e. near-zero values will just not be annotated
val_str = ""
elif matrix[j, i] > 0.995:
val_str = "{:.2f}".format(matrix[j, i])[:4]
else:
val_str = "{:.2f}".format(matrix[j, i])[1:]

rel_prob = val / matrix.max()

# shade box
Expand Down Expand Up @@ -100,6 +133,10 @@ def plot_convergence(
-------
ax : matplotlib.axes.Axes
An Axes object to plot.
Notes
-----
Modified from `alchemical analysis <<https://github.com/MobleyLab/alchemical-analysis>>`_
"""
known_units = {
'kilojoule_per_mole': 'kJ/mol',
Expand Down
Empty file.
152 changes: 152 additions & 0 deletions openfe/tests/analysis/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import numpy as np
import matplotlib
import pytest
from openfe.analysis.plotting import (
plot_lambda_transition_matrix,
)


MBAR_HIGH_FLOAT_PREC = np.array([
[4.04963280e-01, 2.64851626e-01, 1.55960834e-01,
8.70071466e-02, 4.65819362e-02, 2.21166590e-02,
5.28613476e-19, 1.21332039e-39, 9.53847574e-47,
1.08409130e-49, 1.00129930e-50],
[2.64851626e-01, 2.38999336e-01, 1.90795611e-01,
1.39227718e-01, 9.35391162e-02, 5.40680743e-02,
4.14462223e-18, 9.51310340e-39, 7.47869291e-46,
8.49987577e-49, 7.85074074e-50],
[1.55960834e-01, 1.90795611e-01, 1.99215052e-01,
1.82450137e-01, 1.49240018e-01, 1.03819831e-01,
2.12518612e-17, 4.87791510e-38, 3.83475587e-45,
4.35837503e-48, 4.02552618e-49],
[8.70071466e-02, 1.39227718e-01, 1.82450137e-01,
2.03476122e-01, 2.00318604e-01, 1.69001754e-01,
7.63228831e-17, 1.75183030e-37, 1.37719525e-44,
1.56524525e-47, 1.44570756e-48],
[4.65819362e-02, 9.35391162e-02, 1.49240018e-01,
2.00318604e-01, 2.39174175e-01, 2.52627633e-01,
2.01399730e-16, 4.62270467e-37, 3.63412308e-44,
4.13034672e-47, 3.81491237e-48],
[2.21166590e-02, 5.40680743e-02, 1.03819831e-01,
1.69001754e-01, 2.52627633e-01, 3.79847531e-01,
3.76232213e-16, 8.63561442e-37, 6.78885803e-44,
7.71584694e-47, 7.12658815e-48],
[5.28613476e-19, 4.14462223e-18, 2.12518612e-17,
7.63228831e-17, 2.01399730e-16, 3.76232213e-16,
1.01326575e+00, 8.95334208e-03, 5.08819552e-07,
9.41453345e-09, 2.77171908e-09],
[1.21332039e-39, 9.51310340e-39, 4.87791510e-38,
1.75183030e-37, 4.62270467e-37, 8.63561442e-37,
8.95334208e-03, 8.92862646e-01, 1.07782005e-01,
9.27776485e-03, 3.34696054e-03],
[9.53847574e-47, 7.47869291e-46, 3.83475587e-45,
1.37719525e-44, 3.63412308e-44, 6.78885803e-44,
5.08819552e-07, 1.07782005e-01, 5.54361284e-01,
2.25806184e-01, 1.34272920e-01],
[1.08409130e-49, 8.49987577e-49, 4.35837503e-48,
1.56524525e-47, 4.13034672e-47, 7.71584694e-47,
9.41453345e-09, 9.27776485e-03, 2.25806184e-01,
3.94054662e-01, 3.93084315e-01],
[1.00129930e-50, 7.85074074e-50, 4.02552618e-49,
1.44570756e-48, 3.81491237e-48, 7.12658815e-48,
2.77171908e-09, 3.34696054e-03, 1.34272920e-01,
3.93084315e-01, 4.91518742e-01],
])


MBAR_HIGH_FLOAT_ABNORMAL = np.array([
[2.34151792e-001, 1.39888121e-001, 7.84874621e-002,
4.18336072e-002, 2.05969028e-002, 7.26433700e-003,
8.59588359e-069, 9.82810344e-154, 2.86854631e-174,
3.61851838e-179, 2.03347234e-180],
[1.39888121e-001, 1.33087532e-001, 1.06283463e-001,
7.48596022e-002, 4.71944621e-002, 2.09090412e-002,
2.61632960e-069, 9.87687147e-154, 3.04564890e-174,
1.01323581e-178, 5.71180108e-180],
[7.84874621e-002, 1.06283463e-001, 1.12085643e-001,
1.01164236e-001, 8.02528014e-002, 4.39486161e-002,
4.65970677e-070, 5.76131562e-154, 2.05138486e-174,
1.68711540e-178, 9.52136287e-180],
[4.18336072e-002, 7.48596022e-002, 1.01164236e-001,
1.15030824e-001, 1.12925630e-001, 7.64083221e-002,
5.56609447e-071, 2.25203208e-154, 1.10970982e-174,
1.88729763e-178, 1.06553350e-179],
[2.05969028e-002, 4.71944621e-002, 8.02528014e-002,
1.12925630e-001, 1.37008472e-001, 1.24243953e-001,
4.63825760e-072, 6.26556483e-155, 5.47382223e-175,
1.47688013e-178, 8.33933879e-180],
[7.26433700e-003, 2.09090412e-002, 4.39486161e-002,
7.64083221e-002, 1.24243953e-001, 2.49447953e-001,
2.15384511e-073, 9.78751994e-156, 1.89651311e-175,
6.46075564e-179, 3.64830252e-180],
[8.59588359e-069, 2.61632960e-069, 4.65970677e-070,
5.56609447e-071, 4.63825760e-072, 2.15384511e-073,
1.54873255e+000, 2.33585672e-002, 2.78272533e-006,
4.61271012e-008, 9.81085527e-009],
[9.82810344e-154, 9.87687147e-154, 5.76131562e-154,
2.25203208e-154, 6.26556483e-155, 9.78751994e-156,
2.33585672e-002, 1.36721112e+000, 1.64391352e-001,
1.39600025e-002, 4.59238353e-003],
[2.86854631e-174, 3.04564890e-174, 2.05138486e-174,
1.10970982e-174, 5.47382223e-175, 1.89651311e-175,
2.78272533e-006, 1.64391352e-001, 8.87029135e-001,
3.45990080e-001, 1.76251997e-001],
[3.61851838e-179, 1.01323581e-178, 1.68711540e-178,
1.88729763e-178, 1.47688013e-178, 6.46075564e-179,
4.61271012e-008, 1.39600025e-002, 3.45990080e-001,
6.27268856e-001, 5.86475486e-001],
[2.03347234e-180, 5.71180108e-180, 9.52136287e-180,
1.06553350e-179, 8.33933879e-180, 3.64830252e-180,
9.81085527e-009, 4.59238353e-003, 1.76251997e-001,
5.86475486e-001, 8.06379597e-001],
])

@pytest.mark.parametrize("matrix",
[MBAR_HIGH_FLOAT_PREC, MBAR_HIGH_FLOAT_ABNORMAL],
)
def test_mbar_overlap_plot_high_warn(matrix):
wmsg = "Overlap/probability matrix exceeds"
with pytest.warns(match=wmsg):
ax = plot_lambda_transition_matrix(matrix)
assert isinstance(ax, matplotlib.axes.Axes)


MBAR_OVERLAP_NORMAL = np.array([
[5.40693861e-01, 3.01639682e-01, 1.21556611e-01, 3.23719204e-02,
3.58389041e-03, 1.48665588e-04, 5.24585724e-06, 1.21058499e-07,
1.77358642e-09, 1.47315799e-11, 5.95535941e-14],
[3.01639682e-01, 3.31228861e-01, 2.43846627e-01, 1.08013742e-01,
1.43267020e-02, 8.83652857e-04, 5.81705578e-05, 2.49870089e-06,
6.30913374e-08, 8.22822509e-10, 5.45494512e-12],
[1.21556611e-01, 2.43846627e-01, 3.25955156e-01, 2.56738695e-01,
4.67483791e-02, 4.53599666e-03, 5.71365908e-04, 4.52123401e-05,
1.91542154e-06, 4.11083861e-08, 4.89327222e-10],
[3.23719204e-02, 1.08013742e-01, 2.56738695e-01, 4.09261028e-01,
1.58278950e-01, 2.79859170e-02, 6.41857330e-03, 8.67983163e-04,
6.08822529e-05, 2.26213546e-06, 4.64070520e-08],
[3.58389041e-03, 1.43267020e-02, 4.67483791e-02, 1.58278950e-01,
4.25994251e-01, 2.44226615e-01, 8.69755675e-02, 1.78548481e-02,
1.90728937e-03, 1.00948754e-04, 2.55919445e-06],
[1.48665588e-04, 8.83652857e-04, 4.53599666e-03, 2.79859170e-02,
2.44226615e-01, 3.71598157e-01, 2.43792854e-01, 8.89756349e-02,
1.63090400e-02, 1.47914160e-03, 6.43253565e-05],
[5.24585724e-06, 5.81705578e-05, 5.71365908e-04, 6.41857330e-03,
8.69755675e-02, 2.43792854e-01, 3.28134403e-01, 2.34652614e-01,
8.31489505e-02, 1.49645600e-02, 1.27769523e-03],
[1.21058499e-07, 2.49870089e-06, 4.52123401e-05, 8.67983163e-04,
1.78548481e-02, 8.89756349e-02, 2.34652614e-01, 3.26504455e-01,
2.32499832e-01, 8.43901256e-02, 1.42066757e-02],
[1.77358642e-09, 6.30913374e-08, 1.91542154e-06, 6.08822529e-05,
1.90728937e-03, 1.63090400e-02, 8.31489505e-02, 2.32499832e-01,
3.36922669e-01, 2.44443192e-01, 8.47061646e-02],
[1.47315799e-11, 8.22822509e-10, 4.11083861e-08, 2.26213546e-06,
1.00948754e-04, 1.47914160e-03, 1.49645600e-02, 8.43901256e-02,
2.44443192e-01, 3.65791191e-01, 2.88828537e-01],
[5.95535941e-14, 5.45494512e-12, 4.89327222e-10, 4.64070520e-08,
2.55919445e-06, 6.43253565e-05, 1.27769523e-03, 1.42066757e-02,
8.47061646e-02, 2.88828537e-01, 6.10913996e-01]
])

def test_mbar_overlap_plot():
ax = plot_lambda_transition_matrix(MBAR_OVERLAP_NORMAL)
assert isinstance(ax, matplotlib.axes.Axes)

0 comments on commit 590a6c1

Please sign in to comment.