diff --git a/src/xtgeo/__init__.py b/src/xtgeo/__init__.py index 0f3c411c1..dc82de6a3 100644 --- a/src/xtgeo/__init__.py +++ b/src/xtgeo/__init__.py @@ -53,13 +53,7 @@ def _xprint(msg): from xtgeo._cxtgeo import XTGeoCLibError from xtgeo.common import XTGeoDialog -from xtgeo.common.constants import ( - UNDEF, - UNDEF_INT, - UNDEF_INT_LIMIT, - UNDEF_LIMIT, - XTG_BATCH, -) +from xtgeo.common.constants import UNDEF, UNDEF_INT, UNDEF_INT_LIMIT, UNDEF_LIMIT from xtgeo.common.exceptions import ( BlockedWellsNotFoundError, DateNotFoundError, @@ -172,7 +166,7 @@ def _xprint(msg): "Using non-interactive Agg backend for matplotlib" ) _xprint("=" * 79) - os.environ[XTG_BATCH] = "1" + os.environ["MPLBACKEND"] = "Agg" warnings.filterwarnings("default", category=DeprecationWarning, module="xtgeo") diff --git a/src/xtgeo/common/_import_mpl.py b/src/xtgeo/common/_import_mpl.py deleted file mode 100644 index b053a5047..000000000 --- a/src/xtgeo/common/_import_mpl.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations - -import os -import sys -from typing import Any - -from xtgeo.common.constants import XTG_BATCH - - -def get_matplotlib() -> Any: - if "matplotlib" not in sys.modules: - import matplotlib - - if XTG_BATCH in os.environ: - matplotlib.use("Agg") - return sys.modules["matplotlib"] - - -def get_pyplot() -> Any: - if "matplotlib" not in sys.modules: - # Set `use("Agg")` - get_matplotlib() - - if "matplotlib.pyplot" not in sys.modules: - # Must be imported separately - import matplotlib.pyplot - - matplotlib.pyplot # Let flake8 ignore unused import - return sys.modules["matplotlib.pyplot"] diff --git a/src/xtgeo/common/constants.py b/src/xtgeo/common/constants.py index 2f0fe843e..751035b5c 100644 --- a/src/xtgeo/common/constants.py +++ b/src/xtgeo/common/constants.py @@ -25,5 +25,3 @@ # for XYZ data, restricted to float32 and int32 UNDEF_CONT = UNDEF UNDEF_DISC = UNDEF_INT - -XTG_BATCH = "XTG_BATCH" diff --git a/src/xtgeo/plot/baseplot.py b/src/xtgeo/plot/baseplot.py index a577a5a11..704b7becb 100644 --- a/src/xtgeo/plot/baseplot.py +++ b/src/xtgeo/plot/baseplot.py @@ -2,7 +2,6 @@ from packaging.version import parse as versionparse from xtgeo.common import XTGeoDialog, null_logger -from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from . import _colortables as _ctable @@ -12,9 +11,11 @@ def _get_colormap(name): """For matplotlib compatibility.""" - mpl = get_matplotlib() + import matplotlib as mpl + if versionparse(mpl.__version__) < versionparse("3.6"): - plt = get_pyplot() + import matplotlib.plt as plt + return plt.cm.get_cmap(name) else: return mpl.colormaps[name] @@ -56,7 +57,8 @@ def colormap(self): @colormap.setter def colormap(self, cmap): - mpl = get_matplotlib() + import matplotlib as mpl + if isinstance(cmap, mpl.colors.LinearSegmentedColormap): self._colormap = cmap elif isinstance(cmap, str): @@ -86,8 +88,9 @@ def define_any_colormap(cfile, colorlist=None): from 0 index. Default is just keep the linear sequence as is. """ - mpl = get_matplotlib() - plt = get_matplotlib() + import matplotlib as mpl + import matplotlib.pyplot as plt + valid_maps = sorted(m for m in plt.cm.datad) logger.info("Valid color maps: %s", valid_maps) @@ -205,7 +208,8 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ - plt = get_pyplot() + import matplotlib.pyplot as plt + self._fig, self._ax = plt.subplots( figsize=(11.69 * figscaling, 8.27 * figscaling) ) @@ -227,8 +231,9 @@ def show(self): self._fig.tight_layout() if self._showok: + import matplotlib.pyplot as plt + logger.info("Calling plt show method...") - plt = get_pyplot() plt.show() return True @@ -242,7 +247,8 @@ def close(self): After close is called, no more operations can be performed on the plot. """ - plt = get_pyplot() + import matplotlib.pyplot as plt + for fig in self._allfigs: plt.close(fig) @@ -272,7 +278,8 @@ def savefig(self, filename, fformat="png", last=True, **kwargs): self._fig.tight_layout() if self._showok: - plt = get_pyplot() + import matplotlib.pyplot as plt + plt.savefig(filename, format=fformat, **kwargs) if last: self.close() diff --git a/src/xtgeo/plot/grid3d_slice.py b/src/xtgeo/plot/grid3d_slice.py index 91a38a661..a1d809738 100644 --- a/src/xtgeo/plot/grid3d_slice.py +++ b/src/xtgeo/plot/grid3d_slice.py @@ -1,7 +1,6 @@ """Module for 3D Grid slice plots, using matplotlib.""" from xtgeo.common import null_logger -from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from xtgeo.plot.baseplot import BasePlot logger = null_logger(__name__) @@ -122,7 +121,8 @@ def _plot_layer(self): patches = [] - mpl = get_matplotlib() + import matplotlib as mpl + for pos in range(len(ibn)): nppol = xyc[pos, :, :] if nppol.mean() > 0.0: @@ -155,5 +155,6 @@ def _plot_layer(self): self._ax.set_ylim((ymin, ymax)) self._fig.colorbar(im) - plt = get_pyplot() + import matplotlib.pyplot as plt + plt.gca().set_aspect("equal", adjustable="box") diff --git a/src/xtgeo/plot/xsection.py b/src/xtgeo/plot/xsection.py index 07bc5b33b..8946cc900 100644 --- a/src/xtgeo/plot/xsection.py +++ b/src/xtgeo/plot/xsection.py @@ -11,7 +11,6 @@ from scipy.ndimage import gaussian_filter from xtgeo.common import XTGeoDialog, null_logger -from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from xtgeo.well import Well from xtgeo.xyz import Polygons @@ -263,7 +262,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0): """ # overriding the base class canvas - plt = get_pyplot() + import matplotlib.pyplot as plt plt.rcParams["axes.xmargin"] = 0 # fill the plot margins @@ -444,7 +443,8 @@ def set_xaxis_md(self, gridlines=False): md_start_round = int(math.floor(md_start / 100.0)) * 100 md_start_delta = md_start - md_start_round - plt = get_pyplot() + import matplotlib.pyplot as plt + auto_ticks = plt.xticks() auto_ticks_delta = auto_ticks[0][1] - auto_ticks[0][0] @@ -566,7 +566,8 @@ def _plot_well_zlog(self, df, ax, bba, zonelogname, logwidth=4, legend=False): df, idx_zshift, ctable, zonelogname, fillnavalue ) - mpl = get_matplotlib() + import matplotlib as mpl + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=202 ) @@ -611,7 +612,8 @@ def _plot_well_faclog(self, df, ax, bba, facieslogname, logwidth=9, legend=True) df, idx, ctable, facieslogname, fillnavalue ) - mpl = get_matplotlib() + import matplotlib as mpl + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=201 ) @@ -658,7 +660,8 @@ def _plot_well_perflog(self, df, ax, bba, perflogname, logwidth=12, legend=True) df, idx, ctable, perflogname, fillnavalue ) - mpl = get_matplotlib() + import matplotlib as mpl + lc = mpl.collections.LineCollection( segments, colors=segments_colors, linewidth=logwidth, zorder=200 ) @@ -772,7 +775,8 @@ def _drawproxylegend(self, ax, bba, items, title=None): proxies = [] labels = [] - mpl = get_matplotlib() + import matplotlib as mpl + for item in items: color = items[item] proxies.append(mpl.lines.Line2D([0, 1], [0, 1], color=color, linewidth=5)) diff --git a/src/xtgeo/plot/xtmap.py b/src/xtgeo/plot/xtmap.py index 5bd5ce3d7..cdf4cfb57 100644 --- a/src/xtgeo/plot/xtmap.py +++ b/src/xtgeo/plot/xtmap.py @@ -4,7 +4,6 @@ import numpy.ma as ma from xtgeo.common import null_logger -from xtgeo.common._import_mpl import get_matplotlib, get_pyplot from .baseplot import BasePlot @@ -118,7 +117,8 @@ def plot_surface( levels = np.linspace(minvalue, maxvalue, self.contourlevels) logger.debug("Number of contour levels: %s", levels) - plt = get_pyplot() + import matplotlib.pyplot as plt + plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=xlabelrotation) # zi = ma.masked_where(zimask, zi) @@ -141,7 +141,8 @@ def plot_surface( else: logger.info("use LogLocator") - mpl = get_matplotlib() + import matplotlib as mpl + locator = mpl.ticker.LogLocator() ticks = None uselevels = None @@ -175,8 +176,9 @@ def plot_faults( .. _Matplotlib: http://matplotlib.org/api/colors_api.html """ + import matplotlib as mpl + aff = fpoly.dataframe.groupby(idname) - mpl = get_matplotlib() for name, _group in aff: # make a dataframe sorted on faults (groupname) diff --git a/src/xtgeo/surface/_regsurf_oper.py b/src/xtgeo/surface/_regsurf_oper.py index d41c9aa57..c17d5fe85 100644 --- a/src/xtgeo/surface/_regsurf_oper.py +++ b/src/xtgeo/surface/_regsurf_oper.py @@ -11,7 +11,6 @@ import xtgeo from xtgeo import XTGeoCLibError, _cxtgeo from xtgeo.common import XTGeoDialog, null_logger -from xtgeo.common._import_mpl import get_matplotlib from xtgeo.xyz import Polygons xtg = XTGeoDialog() @@ -565,7 +564,8 @@ def _proxy_map_polygons(surf, poly, inside=True): xvals, yvals = proxy.get_xy_values(asmasked=False) points = np.array([xvals.ravel(), yvals.ravel()]).T - mpl = get_matplotlib() + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: diff --git a/src/xtgeo/xyz/_xyz_oper.py b/src/xtgeo/xyz/_xyz_oper.py index 7b1a9760b..6cf030a2b 100644 --- a/src/xtgeo/xyz/_xyz_oper.py +++ b/src/xtgeo/xyz/_xyz_oper.py @@ -10,7 +10,6 @@ import xtgeo from xtgeo import _cxtgeo from xtgeo.common import XTGeoDialog, null_logger -from xtgeo.common._import_mpl import get_matplotlib xtg = XTGeoDialog() logger = null_logger(__name__) @@ -40,7 +39,8 @@ def mark_in_polygons_mpl(self, poly, name, inside_value, outside_value): self.dataframe[name] = outside_value - mpl = get_matplotlib() + import matplotlib as mpl + for pol in usepolys: idgroups = pol.dataframe.groupby(pol.pname) for _, grp in idgroups: