Skip to content

Commit

Permalink
CLN: Dynamically import matplotlib and pyplot
Browse files Browse the repository at this point in the history
Normally dynamically importing modules is fairly straight forward: you
can just import them in the function or method they are used in.
However, by default matplotlib will look for display settings on the
machine it's being run on and compute cluster nodes do not have this
set.

The suggested solution to this is to import matplotlib before anything
else is imported, and set its backend to `use("Agg")` which doesn't look
for display information. xtgeo implemented this by importing matplotlib
in the root __init__, and it generally complicates the dynamic loading
situation.

This solution tries to ensure that using the Agg backend will still be
triggered if it xtgeo believes it is in batch mode and wraps a getting
around `sys.modules` after importing it. It also tries not to repeat the
import logic if it is already imported.
  • Loading branch information
mferrera committed Nov 22, 2023
1 parent 7ea1098 commit f4b48be
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 101 deletions.
26 changes: 7 additions & 19 deletions src/xtgeo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# flake8: noqa
# pylint: skip-file
# type: ignore
Expand Down Expand Up @@ -50,34 +49,23 @@ def _xprint(msg):
except Exception:
ROXAR = False


# to avoid problems in batch runs when no DISPLAY is set:
_xprint("Import matplotlib etc...")
if not ROXAR:
import matplotlib as mplib

display = os.environ.get("DISPLAY", "")
host1 = os.environ.get("HOSTNAME", "")
host2 = os.environ.get("HOST", "")
dhost = host1 + host2 + display
_display = os.environ.get("DISPLAY", "")
_hostname = os.environ.get("HOSTNAME", "")
_host = os.environ.get("HOST", "")

ertbool = "LSB_JOBID" in os.environ
_dhost = _hostname + _host + _display
_lsf_job = "LSB_JOBID" in os.environ

if display == "" or "grid" in dhost or "lgc" in dhost or ertbool:
if _display == "" or "grid" in _dhost or "lgc" in _dhost or _lsf_job:
_xprint("")
_xprint("=" * 79)

_xprint(
"XTGeo info: No display found or a batch (e.g. ERT) server. "
"Using non-interactive Agg backend for matplotlib"
)
mplib.use("Agg")
_xprint("=" * 79)

#
# Order matters!
#
_xprint("Import matplotlib etc...DONE")
os.environ["MPLBACKEND"] = "Agg"

from xtgeo._cxtgeo import XTGeoCLibError
from xtgeo.common import XTGeoDialog
Expand Down
3 changes: 0 additions & 3 deletions src/xtgeo/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""The XTGeo plot package"""


Expand All @@ -7,5 +6,3 @@
# flake8: noqa
from xtgeo.plot.xsection import XSection
from xtgeo.plot.xtmap import Map

# from ._colortables import random, random40, xtgeocolors, colorsfromfile
53 changes: 43 additions & 10 deletions src/xtgeo/plot/baseplot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""The baseplot module."""
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from packaging.version import parse as versionparse

from xtgeo.common import XTGeoDialog, null_logger
Expand All @@ -14,7 +11,11 @@

def _get_colormap(name):
"""For matplotlib compatibility."""
import matplotlib as mpl

if versionparse(mpl.__version__) < versionparse("3.6"):
import matplotlib.plt as plt

return plt.cm.get_cmap(name)
else:
return mpl.colormaps[name]
Expand Down Expand Up @@ -56,7 +57,9 @@ def colormap(self):

@colormap.setter
def colormap(self, cmap):
if isinstance(cmap, LinearSegmentedColormap):
import matplotlib as mpl

if isinstance(cmap, mpl.colors.LinearSegmentedColormap):
self._colormap = cmap
elif isinstance(cmap, str):
logger.info("Definition of a colormap from string name: %s", cmap)
Expand Down Expand Up @@ -85,6 +88,9 @@ def define_any_colormap(cfile, colorlist=None):
from 0 index. Default is just keep the linear sequence as is.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt

valid_maps = sorted(m for m in plt.cm.datad)

logger.info("Valid color maps: %s", valid_maps)
Expand All @@ -99,21 +105,37 @@ def define_any_colormap(cfile, colorlist=None):

elif cfile == "xtgeo":
colors = _ctable.xtgeocolors()
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "xtgeo"
elif cfile == "random40":
colors = _ctable.random40()
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "random40"

elif cfile == "randomc":
colors = _ctable.randomc(256)
cmap = LinearSegmentedColormap.from_list(cfile, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
cfile,
colors,
N=len(colors),
)
cmap.name = "randomc"

elif isinstance(cfile, str) and "rms" in cfile:
colors = _ctable.colorsfromfile(cfile)
cmap = LinearSegmentedColormap.from_list("rms", colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
"rms",
colors,
N=len(colors),
)
cmap.name = cfile
elif cfile in valid_maps:
cmap = _get_colormap(cfile)
Expand All @@ -138,7 +160,11 @@ def define_any_colormap(cfile, colorlist=None):
logger.warning("Color list out of range")
ctable.append(colors[0])

cmap = LinearSegmentedColormap.from_list(ctable, colors, N=len(colors))
cmap = mpl.colors.LinearSegmentedColormap.from_list(
ctable,
colors,
N=len(colors),
)
cmap.name = "user"

return cmap
Expand Down Expand Up @@ -182,7 +208,8 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0):
"""
# self._fig, (ax1, ax2) = plt.subplots(2, figsize=(11.69, 8.27))
import matplotlib.pyplot as plt

self._fig, self._ax = plt.subplots(
figsize=(11.69 * figscaling, 8.27 * figscaling)
)
Expand All @@ -204,6 +231,8 @@ def show(self):
self._fig.tight_layout()

if self._showok:
import matplotlib.pyplot as plt

logger.info("Calling plt show method...")
plt.show()
return True
Expand All @@ -218,6 +247,8 @@ def close(self):
After close is called, no more operations can be performed on the plot.
"""
import matplotlib.pyplot as plt

for fig in self._allfigs:
plt.close(fig)

Expand Down Expand Up @@ -247,6 +278,8 @@ def savefig(self, filename, fformat="png", last=True, **kwargs):
self._fig.tight_layout()

if self._showok:
import matplotlib.pyplot as plt

plt.savefig(filename, format=fformat, **kwargs)
if last:
self.close()
Expand Down
58 changes: 6 additions & 52 deletions src/xtgeo/plot/grid3d_slice.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
"""Module for 3D Grid slice plots, using matplotlib."""


import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon

from xtgeo.common import null_logger
from xtgeo.plot.baseplot import BasePlot

Expand Down Expand Up @@ -105,51 +100,6 @@ def plot_gridslice(
else:
self._plot_layer()

# def _plot_row(self):

# geomlist = self._geomlist

# if self._window is None:
# xmin = geomlist[3] - 0.05 * (abs(geomlist[4] - geomlist[3]))
# xmax = geomlist[4] + 0.05 * (abs(geomlist[4] - geomlist[3]))
# zmin = geomlist[7] - 0.05 * (abs(geomlist[8] - geomlist[7]))
# zmax = geomlist[8] + 0.05 * (abs(geomlist[8] - geomlist[7]))
# else:
# xmin, xmax, zmin, zmax = self._window

# # now some numpy operations, numbering is intended
# clist = self._clist
# xz0 = np.column_stack((clist[0].values1d, clist[2].values1d))
# xz1 = np.column_stack((clist[3].values1d, clist[5].values1d))
# xz2 = np.column_stack((clist[15].values1d, clist[17].values1d))
# xz3 = np.column_stack((clist[12].values1d, clist[14].values1d))

# xyc = np.column_stack((xz0, xz1, xz2, xz3))
# xyc = xyc.reshape(self._grid.nlay, self._grid.ncol * self._grid.nrow, 4, 2)

# patches = []

# for pos in range(self._grid.nrow * self._grid.nlay):
# nppol = xyc[self._index - 1, pos, :, :]
# if nppol.mean() > 0.0:
# polygon = Polygon(nppol, True)
# patches.append(polygon)

# black = (0, 0, 0, 1)
# patchcoll = PatchCollection(patches, edgecolors=(black,), cmap=self.colormap)

# # patchcoll.set_array(np.array(pvalues))

# # patchcoll.set_clim([minvalue, maxvalue])

# im = self._ax.add_collection(patchcoll)
# self._ax.set_xlim((xmin, xmax))
# self._ax.set_ylim((zmin, zmax))
# self._ax.invert_yaxis()
# self._fig.colorbar(im)

# # plt.gca().set_aspect("equal", adjustable="box")

def _plot_layer(self):
xyc, ibn = self._grid.get_layer_slice(self._index, activeonly=self._active)

Expand All @@ -171,13 +121,15 @@ def _plot_layer(self):

patches = []

import matplotlib as mpl

for pos in range(len(ibn)):
nppol = xyc[pos, :, :]
if nppol.mean() > 0.0:
polygon = Polygon(nppol)
polygon = mpl.patches.Polygon(nppol)
patches.append(polygon)

patchcoll = PatchCollection(
patchcoll = mpl.collections.PatchCollection(
patches, edgecolors=(self._linecolor,), cmap=self.colormap
)

Expand All @@ -203,4 +155,6 @@ def _plot_layer(self):
self._ax.set_ylim((ymin, ymax))
self._fig.colorbar(im)

import matplotlib.pyplot as plt

plt.gca().set_aspect("equal", adjustable="box")
22 changes: 15 additions & 7 deletions src/xtgeo/plot/xsection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
import warnings
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import numpy.ma as ma
import pandas as pd
from matplotlib import collections as mc
from matplotlib.lines import Line2D
from scipy.ndimage import gaussian_filter

from xtgeo.common import XTGeoDialog, null_logger
Expand Down Expand Up @@ -265,6 +262,7 @@ def canvas(self, title=None, subtitle=None, infotext=None, figscaling=1.0):
"""
# overriding the base class canvas
import matplotlib.pyplot as plt

plt.rcParams["axes.xmargin"] = 0 # fill the plot margins

Expand Down Expand Up @@ -445,6 +443,8 @@ def set_xaxis_md(self, gridlines=False):
md_start_round = int(math.floor(md_start / 100.0)) * 100
md_start_delta = md_start - md_start_round

import matplotlib.pyplot as plt

auto_ticks = plt.xticks()
auto_ticks_delta = auto_ticks[0][1] - auto_ticks[0][0]

Expand Down Expand Up @@ -566,7 +566,9 @@ def _plot_well_zlog(self, df, ax, bba, zonelogname, logwidth=4, legend=False):
df, idx_zshift, ctable, zonelogname, fillnavalue
)

lc = mc.LineCollection(
import matplotlib as mpl

lc = mpl.collections.LineCollection(
segments, colors=segments_colors, linewidth=logwidth, zorder=202
)

Expand Down Expand Up @@ -610,7 +612,9 @@ def _plot_well_faclog(self, df, ax, bba, facieslogname, logwidth=9, legend=True)
df, idx, ctable, facieslogname, fillnavalue
)

lc = mc.LineCollection(
import matplotlib as mpl

lc = mpl.collections.LineCollection(
segments, colors=segments_colors, linewidth=logwidth, zorder=201
)

Expand Down Expand Up @@ -656,7 +660,9 @@ def _plot_well_perflog(self, df, ax, bba, perflogname, logwidth=12, legend=True)
df, idx, ctable, perflogname, fillnavalue
)

lc = mc.LineCollection(
import matplotlib as mpl

lc = mpl.collections.LineCollection(
segments, colors=segments_colors, linewidth=logwidth, zorder=200
)

Expand Down Expand Up @@ -769,9 +775,11 @@ def _drawproxylegend(self, ax, bba, items, title=None):
proxies = []
labels = []

import matplotlib as mpl

for item in items:
color = items[item]
proxies.append(Line2D([0, 1], [0, 1], color=color, linewidth=5))
proxies.append(mpl.lines.Line2D([0, 1], [0, 1], color=color, linewidth=5))
labels.append(item)

ax.legend(
Expand Down
Loading

0 comments on commit f4b48be

Please sign in to comment.