Skip to content

Commit

Permalink
CLN: Dynamically import matplotlib and pyplot
Browse files Browse the repository at this point in the history
Normally dynamically importing modules is fairly straight forward: you
can just import them in the function or method they are used in.
However, by default matplotlib will look for display settings on the
machine it's being run on and compute cluster nodes do not have this
set.

The suggested solution to this is to import matplotlib before anything
else is imported, and set its backend to `use("Agg")` which doesn't look
for display information. xtgeo implemented this by importing matplotlib
in the root __init__, and it generally complicates the dynamic loading
situation.

This solution tries to ensure that using the Agg backend will still be
triggered if it xtgeo believes it is in batch mode and wraps a getting
around `sys.modules` after importing it. It also tries not to repeat the
import logic if it is already imported.
  • Loading branch information
mferrera committed Nov 22, 2023
1 parent 7ea1098 commit 8976250
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 111 deletions.
55 changes: 26 additions & 29 deletions src/xtgeo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,37 +51,15 @@ def _xprint(msg):
ROXAR = False


# to avoid problems in batch runs when no DISPLAY is set:
_xprint("Import matplotlib etc...")
if not ROXAR:
import matplotlib as mplib

display = os.environ.get("DISPLAY", "")
host1 = os.environ.get("HOSTNAME", "")
host2 = os.environ.get("HOST", "")
dhost = host1 + host2 + display

ertbool = "LSB_JOBID" in os.environ

if display == "" or "grid" in dhost or "lgc" in dhost or ertbool:
_xprint("")
_xprint("=" * 79)

_xprint(
"XTGeo info: No display found or a batch (e.g. ERT) server. "
"Using non-interactive Agg backend for matplotlib"
)
mplib.use("Agg")
_xprint("=" * 79)

#
# Order matters!
#
_xprint("Import matplotlib etc...DONE")

from xtgeo._cxtgeo import XTGeoCLibError
from xtgeo.common import XTGeoDialog
from xtgeo.common.constants import UNDEF, UNDEF_INT, UNDEF_INT_LIMIT, UNDEF_LIMIT
from xtgeo.common.constants import (
UNDEF,
UNDEF_INT,
UNDEF_INT_LIMIT,
UNDEF_LIMIT,
XTG_BATCH,
)
from xtgeo.common.exceptions import (
BlockedWellsNotFoundError,
DateNotFoundError,
Expand Down Expand Up @@ -178,6 +156,25 @@ def _xprint(msg):
polygons_from_wells,
)

if not ROXAR:
_display = os.environ.get("DISPLAY", "")
_hostname = os.environ.get("HOSTNAME", "")
_host = os.environ.get("HOST", "")
_dhost = _hostname + _host + _display

_lsf_queue = "LSB_JOBID" in os.environ

if _display == "" or "grid" in _dhost or "lgc" in _dhost or _lsf_queue:
_xprint("")
_xprint("=" * 79)
_xprint(
"XTGeo info: No display found or a batch (e.g. ERT) server. "
"Using non-interactive Agg backend for matplotlib"
)
_xprint("=" * 79)
os.environ[XTG_BATCH] = "1"


warnings.filterwarnings("default", category=DeprecationWarning, module="xtgeo")

_xprint("XTGEO __init__ done")
29 changes: 29 additions & 0 deletions src/xtgeo/common/_import_mpl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

import os
import sys
from typing import Any

from xtgeo.common.constants import XTG_BATCH


def get_matplotlib() -> Any:
if "matplotlib" not in sys.modules:
import matplotlib

if XTG_BATCH in os.environ:
matplotlib.use("Agg")
return sys.modules["matplotlib"]


def get_pyplot() -> Any:
if "matplotlib" not in sys.modules:
# Set `use("Agg")`
get_matplotlib()

if "matplotlib.pyplot" not in sys.modules:
# Must be imported separately
import matplotlib.pyplot

matplotlib.pyplot # Let flake8 ignore unused import
return sys.modules["matplotlib.pyplot"]
2 changes: 2 additions & 0 deletions src/xtgeo/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
# for XYZ data, restricted to float32 and int32
UNDEF_CONT = UNDEF
UNDEF_DISC = UNDEF_INT

XTG_BATCH = "XTG_BATCH"
3 changes: 0 additions & 3 deletions src/xtgeo/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""The XTGeo plot package"""


Expand All @@ -7,5 +6,3 @@
# flake8: noqa
from xtgeo.plot.xsection import XSection
from xtgeo.plot.xtmap import Map

# from ._colortables import random, random40, xtgeocolors, colorsfromfile
46 changes: 36 additions & 10 deletions src/xtgeo/plot/baseplot.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""The baseplot module."""
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from packaging.version import parse as versionparse

from xtgeo.common import XTGeoDialog, null_logger
from xtgeo.common._import_mpl import get_matplotlib, get_pyplot

from . import _colortables as _ctable

Expand All @@ -14,7 +12,9 @@

def _get_colormap(name):
"""For matplotlib compatibility."""
mpl = get_matplotlib()
if versionparse(mpl.__version__) < versionparse("3.6"):
plt = get_pyplot()
return plt.cm.get_cmap(name)
else:
return mpl.colormaps[name]
Expand Down Expand Up @@ -56,7 +56,8 @@ def colormap(self):

@colormap.setter
def colormap(self, cmap):
if isinstance(cmap, LinearSegmentedColormap):
mpl = get_matplotlib()
if isinstance(cmap, mpl.colors.LinearSegmentedColormap):
self._colormap = cmap
elif isinstance(cmap, str):
logger.info("Definition of a colormap from string name: %s", cmap)
Expand Down Expand Up @@ -85,6 +86,8 @@ def define_any_colormap(cfile, colorlist=None):
from 0 index. Default is just keep the linear sequence as is.
"""
mpl = get_matplotlib()
plt = get_matplotlib()
valid_maps = sorted(m for m in plt.cm.datad)

logger.info("Valid color maps: %s", valid_maps)
Expand All @@ -99,21 +102,37 @@ def define_any_colormap(cfile, colorlist=None):

elif cfile == "xtgeo":
colors = _ctable.xtgeocolors()
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "xtgeo"
elif cfile == "random40":
colors = _ctable.random40()
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "random40"

elif cfile == "randomc":
colors = _ctable.randomc(256)
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "randomc"

elif isinstance(cfile, str) and "rms" in cfile:
colors = _ctable.colorsfromfile(cfile)
cmap = LinearSegmentedColormap.from_list("rms", colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
"rms",
colors,
N=len(colors),
)
cmap.name = cfile
elif cfile in valid_maps:
cmap = _get_colormap(cfile)
Expand All @@ -138,7 +157,11 @@ def define_any_colormap(cfile, colorlist=None):
logger.warning("Color list out of range")
ctable.append(colors[0])

cmap = LinearSegmentedColormap.from_list(ctable, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
ctable,
colors,
N=len(colors),
)
cmap.name = "user"

return cmap
Expand Down Expand Up @@ -182,7 +205,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0):
"""
# self._fig, (ax1, ax2) = plt.subplots(2, figsize=(11.69, 8.27))
plt = get_pyplot()
self._fig, self._ax = plt.subplots(
figsize=(11.69 * figscaling, 8.27 * figscaling)
)
Expand All @@ -205,6 +228,7 @@ def show(self):

if self._showok:
logger.info("Calling plt show method...")
plt = get_pyplot()
plt.show()
return True

Expand All @@ -218,6 +242,7 @@ def close(self):
After close is called, no more operations can be performed on the plot.
"""
plt = get_pyplot()
for fig in self._allfigs:
plt.close(fig)

Expand Down Expand Up @@ -247,6 +272,7 @@ def savefig(self, filename, fformat="png", last=True, **kwargs):
self._fig.tight_layout()

if self._showok:
plt = get_pyplot()
plt.savefig(filename, format=fformat, **kwargs)
if last:
self.close()
Expand Down
57 changes: 5 additions & 52 deletions src/xtgeo/plot/grid3d_slice.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
"""Module for 3D Grid slice plots, using matplotlib."""


import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon

from xtgeo.common import null_logger
from xtgeo.common._import_mpl import get_matplotlib, get_pyplot
from xtgeo.plot.baseplot import BasePlot

logger = null_logger(__name__)
Expand Down Expand Up @@ -105,51 +101,6 @@ def plot_gridslice(
else:
self._plot_layer()

# def _plot_row(self):

# geomlist = self._geomlist

# if self._window is None:
# xmin = geomlist[3] - 0.05 * (abs(geomlist[4] - geomlist[3]))
# xmax = geomlist[4] + 0.05 * (abs(geomlist[4] - geomlist[3]))
# zmin = geomlist[7] - 0.05 * (abs(geomlist[8] - geomlist[7]))
# zmax = geomlist[8] + 0.05 * (abs(geomlist[8] - geomlist[7]))
# else:
# xmin, xmax, zmin, zmax = self._window

# # now some numpy operations, numbering is intended
# clist = self._clist
# xz0 = np.column_stack((clist[0].values1d, clist[2].values1d))
# xz1 = np.column_stack((clist[3].values1d, clist[5].values1d))
# xz2 = np.column_stack((clist[15].values1d, clist[17].values1d))
# xz3 = np.column_stack((clist[12].values1d, clist[14].values1d))

# xyc = np.column_stack((xz0, xz1, xz2, xz3))
# xyc = xyc.reshape(self._grid.nlay, self._grid.ncol * self._grid.nrow, 4, 2)

# patches = []

# for pos in range(self._grid.nrow * self._grid.nlay):
# nppol = xyc[self._index - 1, pos, :, :]
# if nppol.mean() > 0.0:
# polygon = Polygon(nppol, True)
# patches.append(polygon)

# black = (0, 0, 0, 1)
# patchcoll = PatchCollection(patches, edgecolors=(black,), cmap=self.colormap)

# # patchcoll.set_array(np.array(pvalues))

# # patchcoll.set_clim([minvalue, maxvalue])

# im = self._ax.add_collection(patchcoll)
# self._ax.set_xlim((xmin, xmax))
# self._ax.set_ylim((zmin, zmax))
# self._ax.invert_yaxis()
# self._fig.colorbar(im)

# # plt.gca().set_aspect("equal", adjustable="box")

def _plot_layer(self):
xyc, ibn = self._grid.get_layer_slice(self._index, activeonly=self._active)

Expand All @@ -171,13 +122,14 @@ def _plot_layer(self):

patches = []

mpl = get_matplotlib()
for pos in range(len(ibn)):
nppol = xyc[pos, :, :]
if nppol.mean() > 0.0:
polygon = Polygon(nppol)
polygon = mpl.patches.Polygon(nppol)
patches.append(polygon)

patchcoll = PatchCollection(
patchcoll = mpl.collections.PatchCollection(
patches, edgecolors=(self._linecolor,), cmap=self.colormap
)

Expand All @@ -203,4 +155,5 @@ def _plot_layer(self):
self._ax.set_ylim((ymin, ymax))
self._fig.colorbar(im)

plt = get_pyplot()
plt.gca().set_aspect("equal", adjustable="box")
Loading

0 comments on commit 8976250

Please sign in to comment.