diff --git a/src/xtgeo/__init__.py b/src/xtgeo/__init__.py index 8b37d7bb9..748e79147 100644 --- a/src/xtgeo/__init__.py +++ b/src/xtgeo/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # flake8: noqa # pylint: skip-file # type: ignore @@ -50,34 +49,23 @@ def _xprint(msg): except Exception: ROXAR = False - -# to avoid problems in batch runs when no DISPLAY is set: -_xprint("Import matplotlib etc...") if not ROXAR: - import matplotlib as mplib - - display = os.environ.get("DISPLAY", "") - host1 = os.environ.get("HOSTNAME", "") - host2 = os.environ.get("HOST", "") - dhost = host1 + host2 + display + _display = os.environ.get("DISPLAY", "") + _hostname = os.environ.get("HOSTNAME", "") + _host = os.environ.get("HOST", "") - ertbool = "LSB_JOBID" in os.environ + _dhost = _hostname + _host + _display + _lsf_job = "LSB_JOBID" in os.environ - if display == "" or "grid" in dhost or "lgc" in dhost or ertbool: + if _display == "" or "grid" in _dhost or "lgc" in _dhost or _lsf_job: _xprint("") _xprint("=" * 79) - _xprint( "XTGeo info: No display found or a batch (e.g. ERT) server. " "Using non-interactive Agg backend for matplotlib" ) - mplib.use("Agg") _xprint("=" * 79) - -# -# Order matters! -# -_xprint("Import matplotlib etc...DONE") + os.environ["MPLBACKEND"] = "Agg" from xtgeo._cxtgeo import XTGeoCLibError from xtgeo.common import XTGeoDialog diff --git a/src/xtgeo/plot/__init__.py b/src/xtgeo/plot/__init__.py index b967ebbef..f3b9f0557 100644 --- a/src/xtgeo/plot/__init__.py +++ b/src/xtgeo/plot/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The XTGeo plot package""" @@ -7,5 +6,3 @@ # flake8: noqa from xtgeo.plot.xsection import XSection from xtgeo.plot.xtmap import Map - -# from ._colortables import random, random40, xtgeocolors, colorsfromfile diff --git a/src/xtgeo/plot/baseplot.py b/src/xtgeo/plot/baseplot.py index a95886f28..704b7becb 100644 --- a/src/xtgeo/plot/baseplot.py +++ b/src/xtgeo/plot/baseplot.py @@ -1,7 +1,4 @@ """The baseplot module.""" -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap from packaging.version import parse as versionparse from xtgeo.common import XTGeoDialog, null_logger @@ -14,7 +11,11 @@ def _get_colormap(name): """For matplotlib compatibility.""" + import matplotlib as mpl + if versionparse(mpl.__version__) < versionparse("3.6"): + import matplotlib.plt as plt + return plt.cm.get_cmap(name) else: return mpl.colormaps[name] @@ -56,7 +57,9 @@ def colormap(self): @colormap.setter def colormap(self, cmap): - if isinstance(cmap, LinearSegmentedColormap): + import matplotlib as mpl + + if isinstance(cmap, mpl.colors.LinearSegmentedColormap): self._colormap = cmap elif isinstance(cmap, str): logger.info("Definition of a colormap from string name: %s", cmap) @@ -85,6 +88,9 @@ def define_any_colormap(cfile, colorlist=None): from 0 index. Default is just keep the linear sequence as is. """ + import matplotlib as mpl + import matplotlib.pyplot as plt + valid_maps = sorted(m for m in plt.cm.datad) logger.info("Valid color maps: %s", valid_maps) @@ -99,21 +105,37 @@ def define_any_colormap(cfile, colorlist=None): elif cfile == "xtgeo": colors = _ctable.xtgeocolors() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "xtgeo" elif cfile == "random40": colors = _ctable.random40() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "random40" elif cfile == "randomc": colors = _ctable.randomc(256) - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "randomc" elif isinstance(cfile, str) and "rms" in cfile: colors = _ctable.colorsfromfile(cfile) - cmap = LinearSegmentedColormap.from_list("rms", colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + "rms", + colors, + N=len(colors), + ) cmap.name = cfile elif cfile in valid_maps: cmap = _get_colormap(cfile) @@ -138,7 +160,11 @@ def define_any_colormap(cfile, colorlist=None): logger.warning("Color list out of range") ctable.append(colors[0]) - cmap = LinearSegmentedColormap.from_list(ctable, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + ctable, + colors, + N=len(colors), + ) cmap.name = "user" return cmap @@ -182,7 +208,8 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ - # self._fig, (ax1, ax2) = plt.subplots(2, figsize=(11.69, 8.27)) + import matplotlib.pyplot as plt + self._fig, self._ax = plt.subplots( figsize=(11.69 * figscaling, 8.27 * figscaling) ) @@ -204,6 +231,8 @@ def show(self): self._fig.tight_layout() if self._showok: + import matplotlib.pyplot as plt + logger.info("Calling plt show method...") plt.show() return True @@ -218,6 +247,8 @@ def close(self): After close is called, no more operations can be performed on the plot. """ + import matplotlib.pyplot as plt + for fig in self._allfigs: plt.close(fig) @@ -247,6 +278,8 @@ def savefig(self, filename, fformat="png", last=True, **kwargs): self._fig.tight_layout() if self._showok: + import matplotlib.pyplot as plt + plt.savefig(filename, format=fformat, **kwargs) if last: self.close() diff --git a/src/xtgeo/plot/grid3d_slice.py b/src/xtgeo/plot/grid3d_slice.py index d52a998a7..a1d809738 100644 --- a/src/xtgeo/plot/grid3d_slice.py +++ b/src/xtgeo/plot/grid3d_slice.py @@ -1,10 +1,5 @@ """Module for 3D Grid slice plots, using matplotlib.""" - -import matplotlib.pyplot as plt -from matplotlib.collections import PatchCollection -from matplotlib.patches import Polygon - from xtgeo.common import null_logger from xtgeo.plot.baseplot import BasePlot @@ -105,51 +100,6 @@ def plot_gridslice( else: self._plot_layer() - # def _plot_row(self): - - # geomlist = self._geomlist - - # if self._window is None: - # xmin = geomlist[3] - 0.05 * (abs(geomlist[4] - geomlist[3])) - # xmax = geomlist[4] + 0.05 * (abs(geomlist[4] - geomlist[3])) - # zmin = geomlist[7] - 0.05 * (abs(geomlist[8] - geomlist[7])) - # zmax = geomlist[8] + 0.05 * (abs(geomlist[8] - geomlist[7])) - # else: - # xmin, xmax, zmin, zmax = self._window - - # # now some numpy operations, numbering is intended - # clist = self._clist - # xz0 = np.column_stack((clist[0].values1d, clist[2].values1d)) - # xz1 = np.column_stack((clist[3].values1d, clist[5].values1d)) - # xz2 = np.column_stack((clist[15].values1d, clist[17].values1d)) - # xz3 = np.column_stack((clist[12].values1d, clist[14].values1d)) - - # xyc = np.column_stack((xz0, xz1, xz2, xz3)) - # xyc = xyc.reshape(self._grid.nlay, self._grid.ncol * self._grid.nrow, 4, 2) - - # patches = [] - - # for pos in range(self._grid.nrow * self._grid.nlay): - # nppol = xyc[self._index - 1, pos, :, :] - # if nppol.mean() > 0.0: - # polygon = Polygon(nppol, True) - # patches.append(polygon) - - # black = (0, 0, 0, 1) - # patchcoll = PatchCollection(patches, edgecolors=(black,), cmap=self.colormap) - - # # patchcoll.set_array(np.array(pvalues)) - - # # patchcoll.set_clim([minvalue, maxvalue]) - - # im = self._ax.add_collection(patchcoll) - # self._ax.set_xlim((xmin, xmax)) - # self._ax.set_ylim((zmin, zmax)) - # self._ax.invert_yaxis() - # self._fig.colorbar(im) - - # # plt.gca().set_aspect("equal", adjustable="box") - def _plot_layer(self): xyc, ibn = self._grid.get_layer_slice(self._index, activeonly=self._active) @@ -171,13 +121,15 @@ def _plot_layer(self): patches = [] + import matplotlib as mpl + for pos in range(len(ibn)): nppol = xyc[pos, :, :] if nppol.mean() > 0.0: - polygon = Polygon(nppol) + polygon = mpl.patches.Polygon(nppol) patches.append(polygon) - patchcoll = PatchCollection( + patchcoll = mpl.collections.PatchCollection( patches, edgecolors=(self._linecolor,), cmap=self.colormap ) @@ -203,4 +155,6 @@ def _plot_layer(self): self._ax.set_ylim((ymin, ymax)) self._fig.colorbar(im) + import matplotlib.pyplot as plt + plt.gca().set_aspect("equal", adjustable="box") diff --git a/src/xtgeo/plot/xsection.py b/src/xtgeo/plot/xsection.py index d439aa1ec..8946cc900 100644 --- a/src/xtgeo/plot/xsection.py +++ b/src/xtgeo/plot/xsection.py @@ -5,12 +5,9 @@ import warnings from collections import OrderedDict -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma import pandas as pd -from matplotlib import collections as mc -from matplotlib.lines import Line2D from scipy.ndimage import gaussian_filter from xtgeo.common import XTGeoDialog, null_logger @@ -265,6 +262,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ # overriding the base class canvas + import matplotlib.pyplot as plt plt.rcParams["axes.xmargin"] = 0 # fill the plot margins @@ -445,6 +443,8 @@ def set_xaxis_md(self, gridlines=False): md_start_round = int(math.floor(md_start / 100.0)) * 100 md_start_delta = md_start - md_start_round + import matplotlib.pyplot as plt + auto_ticks = plt.xticks() auto_ticks_delta = auto_ticks[0][1] - auto_ticks[0][0] @@ -566,7 +566,9 @@ def _plot_well_zlog(self, df, ax, bba, zonelogname, logwidth=4, legend=False): df, idx_zshift, ctable, zonelogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=202 ) @@ -610,7 +612,9 @@ def _plot_well_faclog(self, df, ax, bba, facieslogname, logwidth=9, legend=True) df, idx, ctable, facieslogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=201 ) @@ -656,7 +660,9 @@ def _plot_well_perflog(self, df, ax, bba, perflogname, logwidth=12, legend=True) df, idx, ctable, perflogname, fillnavalue ) - lc = mc.LineCollection( + import matplotlib as mpl + + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=200 ) @@ -769,9 +775,11 @@ def _drawproxylegend(self, ax, bba, items, title=None): proxies = [] labels = [] + import matplotlib as mpl + for item in items: color = items[item] - proxies.append(Line2D([0, 1], [0, 1], color=color, linewidth=5)) + proxies.append(mpl.lines.Line2D([0, 1], [0, 1], color=color, linewidth=5)) labels.append(item) ax.legend( diff --git a/src/xtgeo/plot/xtmap.py b/src/xtgeo/plot/xtmap.py index 9416b2280..cdf4cfb57 100644 --- a/src/xtgeo/plot/xtmap.py +++ b/src/xtgeo/plot/xtmap.py @@ -1,11 +1,7 @@ """Module for map plots of surfaces, using matplotlib.""" - -import matplotlib.patches as mplp -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma -from matplotlib import ticker from xtgeo.common import null_logger @@ -121,6 +117,8 @@ def plot_surface( levels = np.linspace(minvalue, maxvalue, self.contourlevels) logger.debug("Number of contour levels: %s", levels) + import matplotlib.pyplot as plt + plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=xlabelrotation) # zi = ma.masked_where(zimask, zi) @@ -143,7 +141,9 @@ def plot_surface( else: logger.info("use LogLocator") - locator = ticker.LogLocator() + import matplotlib as mpl + + locator = mpl.ticker.LogLocator() ticks = None uselevels = None im = self._ax.contourf(xi, yi, zi, locator=locator, cmap=self.colormap) @@ -176,6 +176,8 @@ def plot_faults( .. _Matplotlib: http://matplotlib.org/api/colors_api.html """ + import matplotlib as mpl + aff = fpoly.dataframe.groupby(idname) for name, _group in aff: @@ -185,7 +187,13 @@ def plot_faults( # make a list [(X,Y) ...]; af = list(zip(myfault["X_UTME"].values, myfault["Y_UTMN"].values)) - px = mplp.Polygon(af, alpha=alpha, color=color, ec=edgecolor, lw=linewidth) + px = mpl.patches.Polygon( + af, + alpha=alpha, + color=color, + ec=edgecolor, + lw=linewidth, + ) if px.get_closed(): self._ax.add_artist(px) diff --git a/src/xtgeo/surface/_regsurf_oper.py b/src/xtgeo/surface/_regsurf_oper.py index deeb22564..c17d5fe85 100644 --- a/src/xtgeo/surface/_regsurf_oper.py +++ b/src/xtgeo/surface/_regsurf_oper.py @@ -7,7 +7,6 @@ import numpy as np import numpy.ma as ma -from matplotlib.path import Path as MPath import xtgeo from xtgeo import XTGeoCLibError, _cxtgeo @@ -565,11 +564,13 @@ def _proxy_map_polygons(surf, poly, inside=True): xvals, yvals = proxy.get_xy_values(asmasked=False) points = np.array([xvals.ravel(), yvals.ravel()]).T + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) is_inside = is_inside.reshape(proxy.ncol, proxy.nrow) proxy.values = np.where(is_inside, inside_value, proxy.values) diff --git a/src/xtgeo/xyz/_xyz_oper.py b/src/xtgeo/xyz/_xyz_oper.py index 1a68d020a..6cf030a2b 100644 --- a/src/xtgeo/xyz/_xyz_oper.py +++ b/src/xtgeo/xyz/_xyz_oper.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import shapely.geometry as sg -from matplotlib.path import Path as MPath from scipy.interpolate import UnivariateSpline, interp1d import xtgeo @@ -40,11 +39,13 @@ def mark_in_polygons_mpl(self, poly, name, inside_value, outside_value): self.dataframe[name] = outside_value + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) self.dataframe.loc[is_inside, name] = inside_value diff --git a/tests/test_plot/test_matplotlib_import.py b/tests/test_plot/test_matplotlib_import.py new file mode 100644 index 000000000..f592cfb8f --- /dev/null +++ b/tests/test_plot/test_matplotlib_import.py @@ -0,0 +1,58 @@ +import os +import sys +from unittest import mock + + +def _clear_state(sys, os): + delete = [] + for module, _ in sys.modules.items(): + if module.startswith(("xtgeo", "matplotlib")): + delete.append(module) + + for module in delete: + del sys.modules[module] + + if "MPLBACKEND" in os.environ: + del os.environ["MPLBACKEND"] + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ) +def test_that_mpl_dynamically_imports(): + _clear_state(sys, os) + import xtgeo # noqa + + assert "matplotlib" not in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + from xtgeo.plot.baseplot import BasePlot + + assert "matplotlib" not in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + baseplot = BasePlot() + + assert "matplotlib" in sys.modules + assert "matplotlib.pyplot" not in sys.modules + + baseplot.close() + + assert "matplotlib.pyplot" in sys.modules + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ, {"LSB_JOBID": "1"}) +def test_that_agg_backend_set_when_lsf_job(): + _clear_state(sys, os) + import xtgeo # noqa + + assert os.environ.get("MPLBACKEND", "") == "Agg" + + +@mock.patch.dict(sys.modules) +@mock.patch.dict(os.environ, {"DISPLAY": "X"}) +def test_that_agg_backend_set_when_display_set(): + _clear_state(sys, os) + import xtgeo # noqa + + assert os.environ.get("MPLBACKEND", "") == ""