diff --git a/pylops/basicoperators/spread.py b/pylops/basicoperators/spread.py index fe1784f9..2244ffb6 100644 --- a/pylops/basicoperators/spread.py +++ b/pylops/basicoperators/spread.py @@ -6,10 +6,13 @@ import numpy as np from pylops import LinearOperator +from pylops.utils import deps from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray -try: +jit_message = deps.numba_import("the spread module") + +if jit_message is None: from numba import jit from ._spread_numba import ( @@ -18,12 +21,6 @@ _rmatvec_numba_onthefly, _rmatvec_numba_table, ) -except ModuleNotFoundError: - jit = None - jit_message = "Numba not available, reverting to numpy." -except Exception as e: - jit = None - jit_message = "Failed to import numba (error:%s), use numpy." % e logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -183,10 +180,10 @@ def __init__( if engine not in ["numpy", "numba"]: raise KeyError("engine must be numpy or numba") - if engine == "numba" and jit is not None: + if engine == "numba" and jit_message is None: self.engine = "numba" else: - if engine == "numba" and jit is None: + if engine == "numba" and jit is not None: logging.warning(jit_message) self.engine = "numpy" diff --git a/pylops/optimization/cls_sparsity.py b/pylops/optimization/cls_sparsity.py index 19292704..1eb31cc7 100644 --- a/pylops/optimization/cls_sparsity.py +++ b/pylops/optimization/cls_sparsity.py @@ -15,18 +15,15 @@ normal_equations_inversion, regularized_inversion, ) +from pylops.utils import deps from pylops.utils.backend import get_array_module, get_module_name from pylops.utils.decorators import disable_ndarray_multiplication from pylops.utils.typing import InputDimsLike, NDArray, SamplingLike -try: +spgl1_message = deps.spgl1_import("the spgl1 solver") + +if spgl1_message is None: from spgl1 import spgl1 as ext_spgl1 -except ModuleNotFoundError: - ext_spgl1 = None - spgl1_message = "Spgl1 not installed. " 'Run "pip install spgl1".' -except Exception as e: - ext_spgl1 = None - spgl1_message = f"Failed to import spgl1 (error:{e})." def _hardthreshold(x: NDArray, thresh: float) -> NDArray: @@ -1763,7 +1760,7 @@ def setup( Display setup log """ - if ext_spgl1 is None: + if spgl1_message is not None: raise ModuleNotFoundError(spgl1_message) self.y = y diff --git a/pylops/signalprocessing/chirpradon3d.py b/pylops/signalprocessing/chirpradon3d.py index f8884662..1f544926 100755 --- a/pylops/signalprocessing/chirpradon3d.py +++ b/pylops/signalprocessing/chirpradon3d.py @@ -5,26 +5,18 @@ import numpy as np from pylops import LinearOperator +from pylops.utils import deps from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, NDArray from ._chirpradon3d import _chirp_radon_3d -try: +pyfftw_message = deps.pyfftw_import("the chirpradon3d module") + +if pyfftw_message is None: import pyfftw from ._chirpradon3d import _chirp_radon_3d_fftw -except ModuleNotFoundError: - pyfftw = None - pyfftw_message = ( - "Pyfftw not installed, use numpy or run " - '"pip install pyFFTW" or ' - '"conda install -c conda-forge pyfftw".' - ) -except Exception as e: - pyfftw = None - pyfftw_message = f"Failed to import pyfftw (error:{e}), use numpy." - logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) diff --git a/pylops/signalprocessing/dwt.py b/pylops/signalprocessing/dwt.py index aa698c21..8785f1ac 100644 --- a/pylops/signalprocessing/dwt.py +++ b/pylops/signalprocessing/dwt.py @@ -8,21 +8,14 @@ from pylops import LinearOperator from pylops.basicoperators import Pad +from pylops.utils import deps from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray -try: +pywt_message = deps.pywt_import("the dwt module") + +if pywt_message is None: import pywt -except ModuleNotFoundError: - pywt = None - pywt_message = ( - "Pywt package not installed. " - 'Run "pip install PyWavelets" or ' - 'conda install pywavelets".' - ) -except Exception as e: - pywt = None - pywt_message = f"Failed to import pywt (error:{e})." logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -113,7 +106,7 @@ def __init__( dtype: DTypeLike = "float64", name: str = "D", ) -> None: - if pywt is None: + if pywt_message is not None: raise ModuleNotFoundError(pywt_message) _checkwavelet(wavelet) diff --git a/pylops/signalprocessing/dwt2d.py b/pylops/signalprocessing/dwt2d.py index 398b1a8a..3116cfdf 100644 --- a/pylops/signalprocessing/dwt2d.py +++ b/pylops/signalprocessing/dwt2d.py @@ -7,22 +7,15 @@ from pylops import LinearOperator from pylops.basicoperators import Pad +from pylops.utils import deps from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray from .dwt import _adjointwavelet, _checkwavelet -try: +pywt_message = deps.pywt_import("the dwt2d module") + +if pywt_message is None: import pywt -except ModuleNotFoundError: - pywt = None - pywt_message = ( - "Pywt package not installed. " - 'Run "pip install PyWavelets" or ' - 'conda install pywavelets".' - ) -except Exception as e: - pywt = None - pywt_message = f"Failed to import pywt (error:{e})." logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -90,7 +83,7 @@ def __init__( dtype: DTypeLike = "float64", name: str = "D", ) -> None: - if pywt is None: + if pywt_message is not None: raise ModuleNotFoundError(pywt_message) _checkwavelet(wavelet) diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py index fbbedd7d..64444bcd 100644 --- a/pylops/signalprocessing/fft.py +++ b/pylops/signalprocessing/fft.py @@ -10,21 +10,14 @@ from pylops import LinearOperator from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms +from pylops.utils import deps from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray -try: +pyfftw_message = deps.pyfftw_import("the fft module") + +if pyfftw_message is None: import pyfftw -except ModuleNotFoundError: - pyfftw = None - pyfftw_message = ( - "Pyfftw not installed, use numpy or run " - '"pip install pyFFTW" or ' - '"conda install -c conda-forge pyfftw".' - ) -except Exception as e: - pyfftw = None - pyfftw_message = f"Failed to import pyfftw (error:{e}), use numpy." logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -544,7 +537,7 @@ def FFT( signals. """ - if engine == "fftw" and pyfftw is not None: + if engine == "fftw" and pyfftw_message is None: f = _FFT_fftw( dims, axis=axis, @@ -557,8 +550,8 @@ def FFT( dtype=dtype, **kwargs_fftw, ) - elif engine == "numpy" or (engine == "fftw" and pyfftw is None): - if engine == "fftw" and pyfftw is None: + elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None): + if engine == "fftw" and pyfftw_message is not None: logging.warning(pyfftw_message) f = _FFT_numpy( dims, diff --git a/pylops/signalprocessing/radon2d.py b/pylops/signalprocessing/radon2d.py index 4e3d4185..69abb857 100644 --- a/pylops/signalprocessing/radon2d.py +++ b/pylops/signalprocessing/radon2d.py @@ -6,9 +6,12 @@ import numpy as np from pylops.basicoperators import Spread +from pylops.utils import deps from pylops.utils.typing import DTypeLike, NDArray -try: +jit_message = deps.numba_import("the radon2d module") + +if jit_message is None: from numba import jit from ._radon2d_numba import ( @@ -18,8 +21,6 @@ _linear_numba, _parabolic_numba, ) -except ModuleNotFoundError: - jit = None logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -246,7 +247,7 @@ def Radon2D( # engine if engine not in ["numpy", "numba"]: raise KeyError("engine must be numpy or numba") - if engine == "numba" and jit is None: + if engine == "numba" and jit_message is not None: engine = "numpy" # axes nt, nh, npx = taxis.size, haxis.size, pxaxis.size diff --git a/pylops/signalprocessing/radon3d.py b/pylops/signalprocessing/radon3d.py index 1577840b..0277b764 100644 --- a/pylops/signalprocessing/radon3d.py +++ b/pylops/signalprocessing/radon3d.py @@ -6,9 +6,12 @@ import numpy as np from pylops.basicoperators import Spread +from pylops.utils import deps from pylops.utils.typing import DTypeLike, NDArray -try: +jit_message = deps.numba_import("the radon3d module") + +if jit_message is None: from numba import jit from ._radon3d_numba import ( @@ -18,8 +21,6 @@ _linear_numba, _parabolic_numba, ) -except ModuleNotFoundError: - jit = None logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -270,7 +271,7 @@ def Radon3D( # engine if engine not in ["numpy", "numba"]: raise KeyError("engine must be numpy or numba") - if engine == "numba" and jit is None: + if engine == "numba" and jit_message is not None: engine = "numpy" # axes diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 61fb9bd9..7fad2838 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,18 +1,158 @@ __all__ = [ "cupy_enabled", "cusignal_enabled", + "devito_enabled", "numba_enabled", + "pyfftw_enabled", + "pywt_enabled", + "skfmm_enabled", + "spgl1_enabled", + "sympy_enabled", "torch_enabled", ] import os from importlib import util +# check package availability cupy_enabled = ( util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1 ) cusignal_enabled = ( util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1 ) +devito_enabled = util.find_spec("devito") is not None numba_enabled = util.find_spec("numba") is not None +pyfftw_enabled = util.find_spec("pyfftw") is not None +pywt_enabled = util.find_spec("pywt") is not None +skfmm_enabled = util.find_spec("skfmm") is not None +spgl1_enabled = util.find_spec("spgl1") is not None +sympy_enabled = util.find_spec("sympy") is not None torch_enabled = util.find_spec("torch") is not None + + +# error message at import of available package +def devito_import(message): + if devito_enabled: + try: + import devito # noqa: F401 + + devito_message = None + except Exception as e: + devito_message = f"Failed to import devito (error:{e})." + else: + devito_message = ( + f"Devito not available. " + f"In order to be able to use " + f'{message} run "pip install devito".' + ) + return devito_message + + +def numba_import(message): + if numba_enabled: + try: + import numba # noqa: F401 + + numba_message = None + except Exception as e: + numba_message = f"Failed to import numba (error:{e}), use numpy." + else: + numba_message = ( + "Numba not available, reverting to numpy. " + "In order to be able to use " + f"{message} run " + f'"pip install numba" or ' + f'"conda install numba".' + ) + return numba_message + + +def pyfftw_import(message): + if pyfftw_enabled: + try: + import pyfftw # noqa: F401 + + pyfftw_message = None + except Exception as e: + pyfftw_message = f"Failed to import pyfftw (error:{e}), use numpy." + else: + pyfftw_message = ( + "Pyfftw not available, reverting to numpy. " + "In order to be able to use " + f"{message} run " + f'"pip install pyFFTW" or ' + f'"conda install -c conda-forge pyfftw".' + ) + return pyfftw_message + + +def pywt_import(message): + if pywt_enabled: + try: + import pywt # noqa: F401 + + pywt_message = None + except Exception as e: + pywt_message = f"Failed to import pywt (error:{e})." + else: + pywt_message = ( + "Pywt not available. " + "In order to be able to use " + f"{message} run " + f'"pip install PyWavelets" or ' + f'"conda install pywavelets".' + ) + return pywt_message + + +def skfmm_import(message): + if skfmm_enabled: + try: + import skfmm # noqa: F401 + + skfmm_message = None + except Exception as e: + skfmm_message = f"Failed to import skfmm (error:{e})." + else: + skfmm_message = ( + f"Skfmm package not installed. In order to be able to use " + f"{message} run " + f'"pip install scikit-fmm" or ' + f'"conda install -c conda-forge scikit-fmm".' + ) + return skfmm_message + + +def spgl1_import(message): + if spgl1_enabled: + try: + import spgl1 # noqa: F401 + + spgl1_message = None + except Exception as e: + spgl1_message = f"Failed to import spgl1 (error:{e})." + else: + spgl1_message = ( + f"Spgl1 package not installed. In order to be able to use " + f"{message} run " + f'"pip install spgl1".' + ) + return spgl1_message + + +def sympy_import(message): + if sympy_enabled: + try: + import sympy # noqa: F401 + + sympy_message = None + except Exception as e: + sympy_message = f"Failed to import sympy (error:{e})." + else: + sympy_message = ( + f"Sympy package not installed. In order to be able to use " + f"{message} run " + f'"pip install sympy".' + ) + return sympy_message diff --git a/pylops/utils/describe.py b/pylops/utils/describe.py index 40612318..1ed0d185 100644 --- a/pylops/utils/describe.py +++ b/pylops/utils/describe.py @@ -27,15 +27,13 @@ from pylops import LinearOperator from pylops.basicoperators import BlockDiag, HStack, VStack from pylops.linearoperator import _ScaledLinearOperator, _SumLinearOperator +from pylops.utils import deps -try: +sympy_message = deps.sympy_import("the describe module") + +if sympy_message is None: from sympy import BlockDiagMatrix, BlockMatrix, MatrixSymbol -except ModuleNotFoundError: - raise ModuleNotFoundError( - "Sympy package not installed. In order to use " - "the describe method run " - "install sympy." - ) + compositeops = ( LinearOperator, @@ -299,6 +297,9 @@ def describe(Op) -> None: Linear Operator to describe """ + if sympy_message is not None: + raise NotImplementedError(sympy_message) + # Describe the operator Ops = {} names = set() diff --git a/pylops/waveeqprocessing/twoway.py b/pylops/waveeqprocessing/twoway.py index 3f0fdfd0..9b74d726 100644 --- a/pylops/waveeqprocessing/twoway.py +++ b/pylops/waveeqprocessing/twoway.py @@ -5,23 +5,15 @@ import numpy as np from pylops import LinearOperator +from pylops.utils import deps from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray, SamplingLike -try: - import devito +devito_message = deps.devito_import("the twoway module") +if devito_message is None: from examples.seismic import AcquisitionGeometry, Model from examples.seismic.acoustic import AcousticWaveSolver -except ModuleNotFoundError: - devito = None - devito_message = ( - "Devito package not installed. In order to be able to use" - 'the twoway module run "pip install devito".' - ) -except Exception as e: - devito = None - devito_message = f"Failed to import devito (error:{e})." class AcousticWave2D(LinearOperator): @@ -96,8 +88,9 @@ def __init__( dtype: DTypeLike = "float32", name: str = "A", ) -> None: - if not devito: + if devito_message is not None: raise NotImplementedError(devito_message) + # create model self._create_model(shape, origin, spacing, vp, space_order, nbl) self._create_geometry(src_x, src_z, rec_x, rec_z, t0, tn, src_type, f0=f0)