diff --git a/nutils/export.py b/nutils/export.py index c1494e958..d6d1b795e 100644 --- a/nutils/export.py +++ b/nutils/export.py @@ -47,34 +47,59 @@ def plotlines_(ax, xy, lines, **kwargs): return lc -def triplot(name, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None): - if (tri is None) != (values is None): +def _triplot_1d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None): + if plabel: + ax.set_xlabel(plabel) + if vlabel: + ax.set_ylabel(vlabel) + if hull is not None: + for x in points[hull[:, 0], 0]: + ax.axvline(x, color=linecolor, linewidth=linewidth) + if tri is not None: + plotlines_(ax, [points[:, 0], values], tri) + ax.autoscale(enable=True, axis='x', tight=True) + if clim is None: + ax.autoscale(enable=True, axis='y', tight=False) + else: + ax.set_ylim(clim) + + +def _triplot_2d(ax, points, values=None, *, tri=None, hull=None, cmap=None, clim=None, linewidth=.1, linecolor='k', plabel=None, vlabel=None): + ax.set_aspect('equal') + if plabel: + ax.set_xlabel(plabel) + ax.set_ylabel(plabel) + if tri is not None: + im = ax.tripcolor(*points.T, tri, values, shading='gouraud', cmap=cmap, rasterized=True) + if clim is not None: + im.set_clim(clim) + else: + im = None + if hull is not None: + plotlines_(ax, points.T, hull, colors=linecolor, linewidths=linewidth, alpha=1 if tri is None else .5) + ax.autoscale(enable=True, axis='both', tight=True) + return im + + +def triplot(name, points, values=None, **kwargs): + + if points.shape[1] == 1: + _triplot = _triplot_1d + elif points.shape[1] == 2: + _triplot = _triplot_2d + else: + raise Exception(f'invalid spatial dimension: {nd}') + + if (kwargs.get('tri') is None) != (values is None): raise Exception('tri and values can only be specified jointly') + + if not isinstance(name, str): + return _triplot(name, points, values, **kwargs) + with mplfigure(name) as fig: - if points.shape[1] == 1: - ax = fig.add_subplot(111, xlabel=plabel, ylabel=vlabel) - if tri is not None: - plotlines_(ax, [points[:, 0], values], tri) - if hull is not None: - for x in points[hull[:, 0], 0]: - ax.axvline(x, color=linecolor, linewidth=linewidth) - ax.autoscale(enable=True, axis='x', tight=True) - if clim is None: - ax.autoscale(enable=True, axis='y', tight=False) - else: - ax.set_ylim(clim) - elif points.shape[1] == 2: - ax = fig.add_subplot(111, xlabel=plabel, ylabel=plabel, aspect='equal') - if tri is not None: - im = ax.tripcolor(*points.T, tri, values, shading='gouraud', cmap=cmap, rasterized=True) - if clim is not None: - im.set_clim(clim) - fig.colorbar(im, label=vlabel) - if hull is not None: - plotlines_(ax, points.T, hull, colors=linecolor, linewidths=linewidth, alpha=1 if tri is None else .5) - ax.autoscale(enable=True, axis='both', tight=True) - else: - raise Exception('invalid spatial dimension: {}'.format(points.shape[1])) + im = _triplot(fig.add_subplot(111), points, values, **kwargs) + if im: + fig.colorbar(im, label=kwargs.get('vlabel')) @util.positional_only