From 29df5a3a6d9050a3688b1f019106e8ee2a36e760 Mon Sep 17 00:00:00 2001 From: Maxime Lucas Date: Mon, 16 Oct 2023 11:50:08 +0200 Subject: [PATCH] tests: more for draw_hyperedges --- tests/drawing/test_draw.py | 101 ++++++++++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/tests/drawing/test_draw.py b/tests/drawing/test_draw.py index 24a3b7755..a8599a914 100644 --- a/tests/drawing/test_draw.py +++ b/tests/drawing/test_draw.py @@ -1,6 +1,7 @@ import matplotlib.pyplot as plt import numpy as np import pytest +import seaborn as sb import xgi from xgi.exception import XGIError @@ -32,6 +33,31 @@ def test_draw(edgelist8): plt.close() + # simplicial complex + S = xgi.SimplicialComplex(edgelist8) + + fig, ax = plt.subplots() + ax, collections = xgi.draw(S, ax=ax) + (node_collection, dyad_collection, edge_collection) = collections + + # number of elements + assert len(ax.lines) == 0 + assert len(ax.patches) == 0 + offsets = node_collection.get_offsets() + assert offsets.shape[0] == S.num_nodes # nodes + assert len(ax.collections) == 3 + + assert len(dyad_collection.get_paths()) == 16 # dyads + assert len(edge_collection.get_paths()) == 3 # other hyperedges + + # zorder + for line in ax.lines: # dyads + assert line.get_zorder() == 3 + for patch, z in zip(ax.patches, [0, 2, 2]): # hyperedges + assert patch.get_zorder() == z + + plt.close() + def test_draw_nodes(edgelist8): @@ -204,8 +230,10 @@ def test_draw_hyperedges(edgelist8): fig, ax = plt.subplots() ax, collections = xgi.draw_hyperedges(H, ax=ax) - (dyad_collection, edge_collection) = collections + fig2, ax2 = plt.subplots() + ax2, collections2 = xgi.draw_hyperedges(H, ax=ax2, dyad_color="r", edge_fc="r", dyad_lw=3, dyad_style="--") + (dyad_collection2, edge_collection2) = collections2 # number of elements assert len(ax.lines) == 0 @@ -220,6 +248,77 @@ def test_draw_hyperedges(edgelist8): for patch, z in zip(ax.patches, [2, 2, 0, 2, 2]): # hyperedges assert patch.get_zorder() == z + # dyad_style + dyad_collection.get_linestyle() == [(0.0, None)] + dyad_collection2.get_linestyle() == [(0.0, [5.550000000000001, 2.4000000000000004])] + + # dyad_fc + assert np.all( + dyad_collection.get_color() == np.array([[0, 0, 0, 1]]) + ) # black + assert np.all( + dyad_collection2.get_color() == np.array([[1, 0, 0, 1]]) + ) # black + + # edge_fc + assert np.all( + edge_collection.get_facecolor()[:, -1] == np.array([0.4, 0.4, 0.4, 0.4, 0.4, 0.4]) + ) + assert np.all( + edge_collection2.get_facecolor() == np.array([[1., 0., 0., 0.4]]) + ) + + # edge_lw + assert np.all(dyad_collection.get_linewidth() == np.array([1.5])) + assert np.all(dyad_collection2.get_linewidth() == np.array([3])) + assert np.all(edge_collection.get_linewidth() == np.array([1.])) + + # negative node_lw or node_size + with pytest.raises(ValueError): + ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_lw=-1) + (dyad_collection, edge_collection) = collections + plt.close() + + + plt.close("all") + + +def test_draw_hyperedges_fc_cmap(edgelist8): + + H = xgi.Hypergraph(edgelist8) + + # default cmap + fig, ax = plt.subplots() + ax, collections = xgi.draw_hyperedges(H, ax=ax) + (dyad_collection, edge_collection) = collections + assert dyad_collection.get_cmap() == plt.cm.Greys + assert edge_collection.get_cmap() == sb.color_palette("crest_r", as_cmap=True) + plt.close() + + # set cmap + fig, ax = plt.subplots() + dyad_colors = [1, 3, 5] + ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_color=dyad_colors, dyad_color_cmap="Greens", edge_fc_cmap="Blues") + (dyad_collection, edge_collection) = collections + assert dyad_collection.get_cmap() == plt.cm.Greens + assert edge_collection.get_cmap() == plt.cm.Blues + + plt.colorbar(dyad_collection) + plt.colorbar(edge_collection) + + assert (min(dyad_colors), max(dyad_colors)) == dyad_collection.get_clim() + assert (3, 5) == edge_collection.get_clim() + plt.close() + + # vmin/vmax + fig, ax = plt.subplots() + ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_color=dyad_colors, dyad_vmin=5, dyad_vmax=6, edge_vmin=14, edge_vmax=19) + (dyad_collection, edge_collection) = collections + plt.colorbar(dyad_collection) + plt.colorbar(edge_collection) + assert (14, 19) == edge_collection.get_clim() + assert (5, 6) == dyad_collection.get_clim() + plt.close()