From 91a5a1c4e0c7e94dd74c48c68bf5ff715f3bc917 Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 26 Apr 2017 16:42:27 -0700 Subject: [PATCH] Fixing highlight issue in heatmap --- gneiss/plot/_heatmap.py | 1 + gneiss/plot/_plot.py | 3 ++- gneiss/plot/tests/test_plot.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/gneiss/plot/_heatmap.py b/gneiss/plot/_heatmap.py index 4491441..afb6878 100644 --- a/gneiss/plot/_heatmap.py +++ b/gneiss/plot/_heatmap.py @@ -169,6 +169,7 @@ def _plot_highlights_dendrogram(ax_highlights, table, t, highlights): hcoords = [] for i, n in enumerate(highlights.index): node = t.find(n) + k, l, r = node._k, node._l, node._r ax_highlights.add_patch( diff --git a/gneiss/plot/_plot.py b/gneiss/plot/_plot.py index d7ead42..f36009f 100644 --- a/gneiss/plot/_plot.py +++ b/gneiss/plot/_plot.py @@ -457,7 +457,8 @@ def dendrogram_heatmap(output_dir: str, table: pd.DataFrame, tree: TreeNode, metadata: MetadataCategory, ndim=10, method='clr', color_map='viridis'): - nodes = [n.name for n in tree.levelorder()] + nodes = [n.name for n in tree.levelorder() if not n.is_tip()] + nlen = min(ndim, len(nodes)) highlights = pd.DataFrame([['#00FF00', '#FF0000']] * nlen, index=nodes[:nlen]) diff --git a/gneiss/plot/tests/test_plot.py b/gneiss/plot/tests/test_plot.py index 396cb53..0ed6273 100644 --- a/gneiss/plot/tests/test_plot.py +++ b/gneiss/plot/tests/test_plot.py @@ -317,6 +317,36 @@ def test_visualization(self): self.assertIn('

Dendrogram heatmap

', html) + def test_visualization_small(self): + # tests the scenario where ndim > number of tips + np.random.seed(0) + num_otus = 11 # otus + table = pd.DataFrame(np.random.random((num_otus, 5)), + index=np.arange(num_otus).astype(np.str)).T + + x = np.random.rand(num_otus) + dm = DistanceMatrix.from_iterable(x, lambda x, y: np.abs(x-y)) + lm = ward(dm.condensed_form()) + t = TreeNode.from_linkage_matrix(lm, np.arange(len(x)).astype(np.str)) + + for i, n in enumerate(t.postorder()): + if not n.is_tip(): + n.name = "y%d" % i + n.length = np.random.rand()*3 + + md = MetadataCategory( + pd.Series(['a', 'a', 'a', 'b', 'b'])) + + dendrogram_heatmap(self.results, table, t, md) + + index_fp = os.path.join(self.results, 'index.html') + self.assertTrue(os.path.exists(index_fp)) + + with open(index_fp, 'r') as fh: + html = fh.read() + self.assertIn('

Dendrogram heatmap

', + html) + if __name__ == "__main__": unittest.main()