diff --git a/src/xtgeo/__init__.py b/src/xtgeo/__init__.py index 8b37d7bb9..0f3c411c1 100644 --- a/src/xtgeo/__init__.py +++ b/src/xtgeo/__init__.py @@ -51,37 +51,15 @@ def _xprint(msg): ROXAR = False -# to avoid problems in batch runs when no DISPLAY is set: -_xprint("Import matplotlib etc...") -if not ROXAR: - import matplotlib as mplib - - display = os.environ.get("DISPLAY", "") - host1 = os.environ.get("HOSTNAME", "") - host2 = os.environ.get("HOST", "") - dhost = host1 + host2 + display - - ertbool = "LSB_JOBID" in os.environ - - if display == "" or "grid" in dhost or "lgc" in dhost or ertbool: - _xprint("") - _xprint("=" * 79) - - _xprint( - "XTGeo info: No display found or a batch (e.g. ERT) server. " - "Using non-interactive Agg backend for matplotlib" - ) - mplib.use("Agg") - _xprint("=" * 79) - -# -# Order matters! -# -_xprint("Import matplotlib etc...DONE") - from xtgeo._cxtgeo import XTGeoCLibError from xtgeo.common import XTGeoDialog -from xtgeo.common.constants import UNDEF, UNDEF_INT, UNDEF_INT_LIMIT, UNDEF_LIMIT +from xtgeo.common.constants import ( + UNDEF, + UNDEF_INT, + UNDEF_INT_LIMIT, + UNDEF_LIMIT, + XTG_BATCH, +) from xtgeo.common.exceptions import ( BlockedWellsNotFoundError, DateNotFoundError, @@ -178,6 +156,25 @@ def _xprint(msg): polygons_from_wells, ) +if not ROXAR: + _display = os.environ.get("DISPLAY", "") + _hostname = os.environ.get("HOSTNAME", "") + _host = os.environ.get("HOST", "") + _dhost = _hostname + _host + _display + + _lsf_queue = "LSB_JOBID" in os.environ + + if _display == "" or "grid" in _dhost or "lgc" in _dhost or _lsf_queue: + _xprint("") + _xprint("=" * 79) + _xprint( + "XTGeo info: No display found or a batch (e.g. ERT) server. " + "Using non-interactive Agg backend for matplotlib" + ) + _xprint("=" * 79) + os.environ[XTG_BATCH] = "1" + + warnings.filterwarnings("default", category=DeprecationWarning, module="xtgeo") _xprint("XTGEO __init__ done") diff --git a/src/xtgeo/common/_import_mpl.py b/src/xtgeo/common/_import_mpl.py new file mode 100644 index 000000000..b053a5047 --- /dev/null +++ b/src/xtgeo/common/_import_mpl.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import os +import sys +from typing import Any + +from xtgeo.common.constants import XTG_BATCH + + +def get_matplotlib() -> Any: + if "matplotlib" not in sys.modules: + import matplotlib + + if XTG_BATCH in os.environ: + matplotlib.use("Agg") + return sys.modules["matplotlib"] + + +def get_pyplot() -> Any: + if "matplotlib" not in sys.modules: + # Set `use("Agg")` + get_matplotlib() + + if "matplotlib.pyplot" not in sys.modules: + # Must be imported separately + import matplotlib.pyplot + + matplotlib.pyplot # Let flake8 ignore unused import + return sys.modules["matplotlib.pyplot"] diff --git a/src/xtgeo/common/constants.py b/src/xtgeo/common/constants.py index 751035b5c..2f0fe843e 100644 --- a/src/xtgeo/common/constants.py +++ b/src/xtgeo/common/constants.py @@ -25,3 +25,5 @@ # for XYZ data, restricted to float32 and int32 UNDEF_CONT = UNDEF UNDEF_DISC = UNDEF_INT + +XTG_BATCH = "XTG_BATCH" diff --git a/src/xtgeo/plot/__init__.py b/src/xtgeo/plot/__init__.py index b967ebbef..f3b9f0557 100644 --- a/src/xtgeo/plot/__init__.py +++ b/src/xtgeo/plot/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """The XTGeo plot package""" @@ -7,5 +6,3 @@ # flake8: noqa from xtgeo.plot.xsection import XSection from xtgeo.plot.xtmap import Map - -# from ._colortables import random, random40, xtgeocolors, colorsfromfile diff --git a/src/xtgeo/plot/baseplot.py b/src/xtgeo/plot/baseplot.py index a95886f28..a577a5a11 100644 --- a/src/xtgeo/plot/baseplot.py +++ b/src/xtgeo/plot/baseplot.py @@ -1,10 +1,8 @@ """The baseplot module.""" -import matplotlib as mpl -import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap from packaging.version import parse as versionparse from xtgeo.common import XTGeoDialog, null_logger +from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from . import _colortables as _ctable @@ -14,7 +12,9 @@ def _get_colormap(name): """For matplotlib compatibility.""" + mpl = get_matplotlib() if versionparse(mpl.__version__) < versionparse("3.6"): + plt = get_pyplot() return plt.cm.get_cmap(name) else: return mpl.colormaps[name] @@ -56,7 +56,8 @@ def colormap(self): @colormap.setter def colormap(self, cmap): - if isinstance(cmap, LinearSegmentedColormap): + mpl = get_matplotlib() + if isinstance(cmap, mpl.colors.LinearSegmentedColormap): self._colormap = cmap elif isinstance(cmap, str): logger.info("Definition of a colormap from string name: %s", cmap) @@ -85,6 +86,8 @@ def define_any_colormap(cfile, colorlist=None): from 0 index. Default is just keep the linear sequence as is. """ + mpl = get_matplotlib() + plt = get_matplotlib() valid_maps = sorted(m for m in plt.cm.datad) logger.info("Valid color maps: %s", valid_maps) @@ -99,21 +102,37 @@ def define_any_colormap(cfile, colorlist=None): elif cfile == "xtgeo": colors = _ctable.xtgeocolors() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "xtgeo" elif cfile == "random40": colors = _ctable.random40() - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "random40" elif cfile == "randomc": colors = _ctable.randomc(256) - cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + cfile, + colors, + N=len(colors), + ) cmap.name = "randomc" elif isinstance(cfile, str) and "rms" in cfile: colors = _ctable.colorsfromfile(cfile) - cmap = LinearSegmentedColormap.from_list("rms", colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + "rms", + colors, + N=len(colors), + ) cmap.name = cfile elif cfile in valid_maps: cmap = _get_colormap(cfile) @@ -138,7 +157,11 @@ def define_any_colormap(cfile, colorlist=None): logger.warning("Color list out of range") ctable.append(colors[0]) - cmap = LinearSegmentedColormap.from_list(ctable, colors, N=len(colors)) + cmap = mpl.colors.LinearSegmentedColormap.from_list( + ctable, + colors, + N=len(colors), + ) cmap.name = "user" return cmap @@ -182,7 +205,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ - # self._fig, (ax1, ax2) = plt.subplots(2, figsize=(11.69, 8.27)) + plt = get_pyplot() self._fig, self._ax = plt.subplots( figsize=(11.69 * figscaling, 8.27 * figscaling) ) @@ -205,6 +228,7 @@ def show(self): if self._showok: logger.info("Calling plt show method...") + plt = get_pyplot() plt.show() return True @@ -218,6 +242,7 @@ def close(self): After close is called, no more operations can be performed on the plot. """ + plt = get_pyplot() for fig in self._allfigs: plt.close(fig) @@ -247,6 +272,7 @@ def savefig(self, filename, fformat="png", last=True, **kwargs): self._fig.tight_layout() if self._showok: + plt = get_pyplot() plt.savefig(filename, format=fformat, **kwargs) if last: self.close() diff --git a/src/xtgeo/plot/grid3d_slice.py b/src/xtgeo/plot/grid3d_slice.py index d52a998a7..91a38a661 100644 --- a/src/xtgeo/plot/grid3d_slice.py +++ b/src/xtgeo/plot/grid3d_slice.py @@ -1,11 +1,7 @@ """Module for 3D Grid slice plots, using matplotlib.""" - -import matplotlib.pyplot as plt -from matplotlib.collections import PatchCollection -from matplotlib.patches import Polygon - from xtgeo.common import null_logger +from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from xtgeo.plot.baseplot import BasePlot logger = null_logger(__name__) @@ -105,51 +101,6 @@ def plot_gridslice( else: self._plot_layer() - # def _plot_row(self): - - # geomlist = self._geomlist - - # if self._window is None: - # xmin = geomlist[3] - 0.05 * (abs(geomlist[4] - geomlist[3])) - # xmax = geomlist[4] + 0.05 * (abs(geomlist[4] - geomlist[3])) - # zmin = geomlist[7] - 0.05 * (abs(geomlist[8] - geomlist[7])) - # zmax = geomlist[8] + 0.05 * (abs(geomlist[8] - geomlist[7])) - # else: - # xmin, xmax, zmin, zmax = self._window - - # # now some numpy operations, numbering is intended - # clist = self._clist - # xz0 = np.column_stack((clist[0].values1d, clist[2].values1d)) - # xz1 = np.column_stack((clist[3].values1d, clist[5].values1d)) - # xz2 = np.column_stack((clist[15].values1d, clist[17].values1d)) - # xz3 = np.column_stack((clist[12].values1d, clist[14].values1d)) - - # xyc = np.column_stack((xz0, xz1, xz2, xz3)) - # xyc = xyc.reshape(self._grid.nlay, self._grid.ncol * self._grid.nrow, 4, 2) - - # patches = [] - - # for pos in range(self._grid.nrow * self._grid.nlay): - # nppol = xyc[self._index - 1, pos, :, :] - # if nppol.mean() > 0.0: - # polygon = Polygon(nppol, True) - # patches.append(polygon) - - # black = (0, 0, 0, 1) - # patchcoll = PatchCollection(patches, edgecolors=(black,), cmap=self.colormap) - - # # patchcoll.set_array(np.array(pvalues)) - - # # patchcoll.set_clim([minvalue, maxvalue]) - - # im = self._ax.add_collection(patchcoll) - # self._ax.set_xlim((xmin, xmax)) - # self._ax.set_ylim((zmin, zmax)) - # self._ax.invert_yaxis() - # self._fig.colorbar(im) - - # # plt.gca().set_aspect("equal", adjustable="box") - def _plot_layer(self): xyc, ibn = self._grid.get_layer_slice(self._index, activeonly=self._active) @@ -171,13 +122,14 @@ def _plot_layer(self): patches = [] + mpl = get_matplotlib() for pos in range(len(ibn)): nppol = xyc[pos, :, :] if nppol.mean() > 0.0: - polygon = Polygon(nppol) + polygon = mpl.patches.Polygon(nppol) patches.append(polygon) - patchcoll = PatchCollection( + patchcoll = mpl.collections.PatchCollection( patches, edgecolors=(self._linecolor,), cmap=self.colormap ) @@ -203,4 +155,5 @@ def _plot_layer(self): self._ax.set_ylim((ymin, ymax)) self._fig.colorbar(im) + plt = get_pyplot() plt.gca().set_aspect("equal", adjustable="box") diff --git a/src/xtgeo/plot/xsection.py b/src/xtgeo/plot/xsection.py index d439aa1ec..07bc5b33b 100644 --- a/src/xtgeo/plot/xsection.py +++ b/src/xtgeo/plot/xsection.py @@ -5,15 +5,13 @@ import warnings from collections import OrderedDict -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma import pandas as pd -from matplotlib import collections as mc -from matplotlib.lines import Line2D from scipy.ndimage import gaussian_filter from xtgeo.common import XTGeoDialog, null_logger +from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from xtgeo.well import Well from xtgeo.xyz import Polygons @@ -265,6 +263,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ # overriding the base class canvas + plt = get_pyplot() plt.rcParams["axes.xmargin"] = 0 # fill the plot margins @@ -445,6 +444,7 @@ def set_xaxis_md(self, gridlines=False): md_start_round = int(math.floor(md_start / 100.0)) * 100 md_start_delta = md_start - md_start_round + plt = get_pyplot() auto_ticks = plt.xticks() auto_ticks_delta = auto_ticks[0][1] - auto_ticks[0][0] @@ -566,7 +566,8 @@ def _plot_well_zlog(self, df, ax, bba, zonelogname, logwidth=4, legend=False): df, idx_zshift, ctable, zonelogname, fillnavalue ) - lc = mc.LineCollection( + mpl = get_matplotlib() + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=202 ) @@ -610,7 +611,8 @@ def _plot_well_faclog(self, df, ax, bba, facieslogname, logwidth=9, legend=True) df, idx, ctable, facieslogname, fillnavalue ) - lc = mc.LineCollection( + mpl = get_matplotlib() + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=201 ) @@ -656,7 +658,8 @@ def _plot_well_perflog(self, df, ax, bba, perflogname, logwidth=12, legend=True) df, idx, ctable, perflogname, fillnavalue ) - lc = mc.LineCollection( + mpl = get_matplotlib() + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=200 ) @@ -769,9 +772,10 @@ def _drawproxylegend(self, ax, bba, items, title=None): proxies = [] labels = [] + mpl = get_matplotlib() for item in items: color = items[item] - proxies.append(Line2D([0, 1], [0, 1], color=color, linewidth=5)) + proxies.append(mpl.lines.Line2D([0, 1], [0, 1], color=color, linewidth=5)) labels.append(item) ax.legend( diff --git a/src/xtgeo/plot/xtmap.py b/src/xtgeo/plot/xtmap.py index 9416b2280..5bd5ce3d7 100644 --- a/src/xtgeo/plot/xtmap.py +++ b/src/xtgeo/plot/xtmap.py @@ -1,13 +1,10 @@ """Module for map plots of surfaces, using matplotlib.""" - -import matplotlib.patches as mplp -import matplotlib.pyplot as plt import numpy as np import numpy.ma as ma -from matplotlib import ticker from xtgeo.common import null_logger +from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from .baseplot import BasePlot @@ -121,6 +118,7 @@ def plot_surface( levels = np.linspace(minvalue, maxvalue, self.contourlevels) logger.debug("Number of contour levels: %s", levels) + plt = get_pyplot() plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=xlabelrotation) # zi = ma.masked_where(zimask, zi) @@ -143,7 +141,8 @@ def plot_surface( else: logger.info("use LogLocator") - locator = ticker.LogLocator() + mpl = get_matplotlib() + locator = mpl.ticker.LogLocator() ticks = None uselevels = None im = self._ax.contourf(xi, yi, zi, locator=locator, cmap=self.colormap) @@ -177,6 +176,7 @@ def plot_faults( .. _Matplotlib: http://matplotlib.org/api/colors_api.html """ aff = fpoly.dataframe.groupby(idname) + mpl = get_matplotlib() for name, _group in aff: # make a dataframe sorted on faults (groupname) @@ -185,7 +185,13 @@ def plot_faults( # make a list [(X,Y) ...]; af = list(zip(myfault["X_UTME"].values, myfault["Y_UTMN"].values)) - px = mplp.Polygon(af, alpha=alpha, color=color, ec=edgecolor, lw=linewidth) + px = mpl.patches.Polygon( + af, + alpha=alpha, + color=color, + ec=edgecolor, + lw=linewidth, + ) if px.get_closed(): self._ax.add_artist(px) diff --git a/src/xtgeo/surface/_regsurf_oper.py b/src/xtgeo/surface/_regsurf_oper.py index deeb22564..d41c9aa57 100644 --- a/src/xtgeo/surface/_regsurf_oper.py +++ b/src/xtgeo/surface/_regsurf_oper.py @@ -7,11 +7,11 @@ import numpy as np import numpy.ma as ma -from matplotlib.path import Path as MPath import xtgeo from xtgeo import XTGeoCLibError, _cxtgeo from xtgeo.common import XTGeoDialog, null_logger +from xtgeo.common._import_mpl import get_matplotlib from xtgeo.xyz import Polygons xtg = XTGeoDialog() @@ -565,11 +565,12 @@ def _proxy_map_polygons(surf, poly, inside=True): xvals, yvals = proxy.get_xy_values(asmasked=False) points = np.array([xvals.ravel(), yvals.ravel()]).T + mpl = get_matplotlib() for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) is_inside = is_inside.reshape(proxy.ncol, proxy.nrow) proxy.values = np.where(is_inside, inside_value, proxy.values) diff --git a/src/xtgeo/xyz/_xyz_oper.py b/src/xtgeo/xyz/_xyz_oper.py index 1a68d020a..7b1a9760b 100644 --- a/src/xtgeo/xyz/_xyz_oper.py +++ b/src/xtgeo/xyz/_xyz_oper.py @@ -5,12 +5,12 @@ import numpy as np import pandas as pd import shapely.geometry as sg -from matplotlib.path import Path as MPath from scipy.interpolate import UnivariateSpline, interp1d import xtgeo from xtgeo import _cxtgeo from xtgeo.common import XTGeoDialog, null_logger +from xtgeo.common._import_mpl import get_matplotlib xtg = XTGeoDialog() logger = null_logger(__name__) @@ -40,11 +40,12 @@ def mark_in_polygons_mpl(self, poly, name, inside_value, outside_value): self.dataframe[name] = outside_value + mpl = get_matplotlib() for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: singlepoly = np.array([grp[pol.xname].values, grp[pol.yname].values]).T - poly_path = MPath(singlepoly) + poly_path = mpl.path.Path(singlepoly) is_inside = poly_path.contains_points(points) self.dataframe.loc[is_inside, name] = inside_value