From 8b2ad09b1f1b5353f88e085dc2e3cf7c96bad435 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Thu, 2 Dec 2021 16:35:14 +0100 Subject: [PATCH] restructure export.triplot This patch cleans up export.triplot without changing any functionality. --- nutils/export.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/nutils/export.py b/nutils/export.py index 474dbe65f..f7bce9f29 100644 --- a/nutils/export.py +++ b/nutils/export.py @@ -54,15 +54,20 @@ def mplfigure(name, kwargs=...): finally: fig.set_canvas(None) # break circular reference +def plotlines_(ax, xy, lines, **kwargs): + from matplotlib import collections + lc = collections.LineCollection(numpy.asarray(xy).T[lines], **kwargs) + ax.add_collection(lc) + return lc + def triplot(name, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k'): if (tri is None) != (values is None): raise Exception('tri and values can only be specified jointly') with mplfigure(name) as fig: - ax = fig.add_subplot(111) if points.shape[1] == 1: + ax = fig.add_subplot(111) if tri is not None: - import matplotlib.collections - ax.add_collection(matplotlib.collections.LineCollection(numpy.array([points[:,0], values]).T[tri])) + plotlines_(ax, [points[:,0], values], tri) if hull is not None: for x in points[hull[:,0],0]: ax.axvline(x, color=linecolor, linewidth=linewidth) @@ -72,15 +77,14 @@ def triplot(name, points, values=None, *, tri=None, hull=None, cmap=None, clim=N else: ax.set_ylim(clim) elif points.shape[1] == 2: - ax.set_aspect('equal') + ax = fig.add_subplot(111, aspect='equal') if tri is not None: - im = ax.tripcolor(points[:,0], points[:,1], tri, values, shading='gouraud', cmap=cmap, rasterized=True) + im = ax.tripcolor(*points.T, tri, values, shading='gouraud', cmap=cmap, rasterized=True) if clim is not None: im.set_clim(clim) fig.colorbar(im) if hull is not None: - import matplotlib.collections - ax.add_collection(matplotlib.collections.LineCollection(points[hull], colors=linecolor, linewidths=linewidth, alpha=1 if tri is None else .5)) + plotlines_(ax, points.T, hull, colors=linecolor, linewidths=linewidth, alpha=1 if tri is None else .5) ax.autoscale(enable=True, axis='both', tight=True) else: raise Exception('invalid spatial dimension: {}'.format(points.shape[1]))