Skip to content

Commit

Permalink
tests: more for draw_hyperedges
Browse files Browse the repository at this point in the history
  • Loading branch information
maximelucas committed Oct 16, 2023
1 parent c3e6c1d commit 29df5a3
Showing 1 changed file with 100 additions and 1 deletion.
101 changes: 100 additions & 1 deletion tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
import seaborn as sb

import xgi
from xgi.exception import XGIError
Expand Down Expand Up @@ -32,6 +33,31 @@ def test_draw(edgelist8):

plt.close()

# simplicial complex
S = xgi.SimplicialComplex(edgelist8)

fig, ax = plt.subplots()
ax, collections = xgi.draw(S, ax=ax)
(node_collection, dyad_collection, edge_collection) = collections

# number of elements
assert len(ax.lines) == 0
assert len(ax.patches) == 0
offsets = node_collection.get_offsets()
assert offsets.shape[0] == S.num_nodes # nodes
assert len(ax.collections) == 3

assert len(dyad_collection.get_paths()) == 16 # dyads
assert len(edge_collection.get_paths()) == 3 # other hyperedges

# zorder
for line in ax.lines: # dyads
assert line.get_zorder() == 3
for patch, z in zip(ax.patches, [0, 2, 2]): # hyperedges
assert patch.get_zorder() == z

plt.close()


def test_draw_nodes(edgelist8):

Expand Down Expand Up @@ -204,8 +230,10 @@ def test_draw_hyperedges(edgelist8):

fig, ax = plt.subplots()
ax, collections = xgi.draw_hyperedges(H, ax=ax)

(dyad_collection, edge_collection) = collections
fig2, ax2 = plt.subplots()
ax2, collections2 = xgi.draw_hyperedges(H, ax=ax2, dyad_color="r", edge_fc="r", dyad_lw=3, dyad_style="--")
(dyad_collection2, edge_collection2) = collections2

# number of elements
assert len(ax.lines) == 0
Expand All @@ -220,6 +248,77 @@ def test_draw_hyperedges(edgelist8):
for patch, z in zip(ax.patches, [2, 2, 0, 2, 2]): # hyperedges
assert patch.get_zorder() == z

# dyad_style
dyad_collection.get_linestyle() == [(0.0, None)]
dyad_collection2.get_linestyle() == [(0.0, [5.550000000000001, 2.4000000000000004])]

# dyad_fc
assert np.all(
dyad_collection.get_color() == np.array([[0, 0, 0, 1]])
) # black
assert np.all(
dyad_collection2.get_color() == np.array([[1, 0, 0, 1]])
) # black

# edge_fc
assert np.all(
edge_collection.get_facecolor()[:, -1] == np.array([0.4, 0.4, 0.4, 0.4, 0.4, 0.4])
)
assert np.all(
edge_collection2.get_facecolor() == np.array([[1., 0., 0., 0.4]])
)

# edge_lw
assert np.all(dyad_collection.get_linewidth() == np.array([1.5]))
assert np.all(dyad_collection2.get_linewidth() == np.array([3]))
assert np.all(edge_collection.get_linewidth() == np.array([1.]))

# negative node_lw or node_size
with pytest.raises(ValueError):
ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_lw=-1)
(dyad_collection, edge_collection) = collections
plt.close()


plt.close("all")


def test_draw_hyperedges_fc_cmap(edgelist8):

H = xgi.Hypergraph(edgelist8)

# default cmap
fig, ax = plt.subplots()
ax, collections = xgi.draw_hyperedges(H, ax=ax)
(dyad_collection, edge_collection) = collections
assert dyad_collection.get_cmap() == plt.cm.Greys
assert edge_collection.get_cmap() == sb.color_palette("crest_r", as_cmap=True)
plt.close()

# set cmap
fig, ax = plt.subplots()
dyad_colors = [1, 3, 5]
ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_color=dyad_colors, dyad_color_cmap="Greens", edge_fc_cmap="Blues")
(dyad_collection, edge_collection) = collections
assert dyad_collection.get_cmap() == plt.cm.Greens
assert edge_collection.get_cmap() == plt.cm.Blues

plt.colorbar(dyad_collection)
plt.colorbar(edge_collection)

assert (min(dyad_colors), max(dyad_colors)) == dyad_collection.get_clim()
assert (3, 5) == edge_collection.get_clim()
plt.close()

# vmin/vmax
fig, ax = plt.subplots()
ax, collections = xgi.draw_hyperedges(H, ax=ax, dyad_color=dyad_colors, dyad_vmin=5, dyad_vmax=6, edge_vmin=14, edge_vmax=19)
(dyad_collection, edge_collection) = collections
plt.colorbar(dyad_collection)
plt.colorbar(edge_collection)
assert (14, 19) == edge_collection.get_clim()
assert (5, 6) == dyad_collection.get_clim()

plt.close()


Expand Down

0 comments on commit 29df5a3

Please sign in to comment.