From 3076ee1d53ed2c722cb423c9edba8e5294aabbff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dieter=20Werthm=C3=BCller?= Date: Tue, 2 Aug 2022 22:47:12 +0200 Subject: [PATCH] CLI and I/O (#160) * CLI and I/O * Simplify deploy --- .github/workflows/linux.yml | 13 +- .github/workflows/macos_windows.yml | 6 +- .gitignore | 2 + CHANGELOG.rst | 24 ++- Makefile | 4 +- docs/api/index.rst | 1 + docs/api/io.rst | 6 + docs/manual/index.rst | 1 + docs/manual/iocli.rst | 181 ++++++++++++++++++ empymod/__init__.py | 1 + empymod/__main__.py | 123 +++++++++++++ empymod/io.py | 276 ++++++++++++++++++++++++++++ requirements-dev.txt | 1 + setup.py | 5 + tests/test_cli.py | 165 +++++++++++++++++ tests/test_io.py | 127 +++++++++++++ 16 files changed, 923 insertions(+), 13 deletions(-) create mode 100644 docs/api/io.rst create mode 100644 docs/manual/iocli.rst create mode 100644 empymod/__main__.py create mode 100644 empymod/io.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_io.py diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 3a746b56..e57b43ad 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -84,7 +84,7 @@ jobs: - name: Checkout uses: actions/checkout@v2 with: - # Need to fetch more than the last commit so that setuptools_scm can + # Need to fetch more than the last commit so that setuptools-scm can # create the correct version string. If the number of commits since # the last release is greater than this, the version still be wrong. # Increase if necessary. @@ -93,7 +93,7 @@ jobs: # to be able to push to GitHub. persist-credentials: false - # Need the tags so that setuptools_scm can form a valid version number + # Need the tags so that setuptools-scm can form a valid version number - name: Fetch git tags run: git fetch origin 'refs/tags/*:refs/tags/*' @@ -116,7 +116,7 @@ jobs: conda config --set always_yes yes --set changeps1 no conda config --show-sources conda config --show - conda install ${{ matrix.case.conda }} pytest pytest-cov coveralls pytest-flake8 setuptools_scm + conda install ${{ matrix.case.conda }} pytest pytest-cov pytest-console-scripts coveralls pytest-flake8 setuptools-scm conda info -a conda list @@ -157,7 +157,7 @@ jobs: - name: Checkout uses: actions/checkout@v2 with: - # Need to fetch more than the last commit so that setuptools_scm can + # Need to fetch more than the last commit so that setuptools-scm can # create the correct version string. If the number of commits since # the last release is greater than this, the version will still be # wrong. Increase if necessary. @@ -166,7 +166,7 @@ jobs: # to be able to push to GitHub. persist-credentials: false - # Need the tags so that setuptools_scm can form a valid version number + # Need the tags so that setuptools-scm can form a valid version number - name: Fetch git tags run: git fetch origin 'refs/tags/*:refs/tags/*' @@ -178,8 +178,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install wheel - pip install -r requirements-dev.txt + pip install wheel setuptools-scm - name: Build source and wheel distributions if: github.ref == 'refs/heads/main' diff --git a/.github/workflows/macos_windows.yml b/.github/workflows/macos_windows.yml index 53e2aa73..74281ab6 100644 --- a/.github/workflows/macos_windows.yml +++ b/.github/workflows/macos_windows.yml @@ -44,7 +44,7 @@ jobs: - name: Checkout uses: actions/checkout@v2 with: - # Need to fetch more than the last commit so that setuptools_scm can + # Need to fetch more than the last commit so that setuptools-scm can # create the correct version string. If the number of commits since # the last release is greater than this, the version still be wrong. # Increase if necessary. @@ -53,7 +53,7 @@ jobs: # to be able to push to GitHub. persist-credentials: false - # Need the tags so that setuptools_scm can form a valid version number + # Need the tags so that setuptools-scm can form a valid version number - name: Fetch git tags run: git fetch origin 'refs/tags/*:refs/tags/*' @@ -76,7 +76,7 @@ jobs: conda config --set always_yes yes --set changeps1 no conda config --show-sources conda config --show - conda install numba scipy pytest setuptools_scm pytest-flake8 scooby setuptools_scm + conda install numba scipy pytest pytest-console-scripts pytest-flake8 scooby setuptools-scm conda info -a conda list diff --git a/.gitignore b/.gitignore index d85ab750..2bc36fa4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ docs/_build/ docs/api/empymod* docs/savefig/ docs/gallery/*/ +docs/my*.json +docs/my*.txt # Pytest and coverage related htmlcov diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f805545e..ea22075e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,29 @@ Version 2 ~~~~~~~~~ +v2.2.x +"""""" + + +v2.2.0: I/O & CLI +----------------- + +**2022-08-02** + +- I/O & CLI: + + - New Command-Line Interface (CLI) for the top-level modelling functions + ``bipole``, ``dipole``, ``loop``, and ``analytical``. Consult the manual + for its description, or type in your terminal ``empymod --help``. Note that + the CLI is a simple wrapper and currently lacks proper logging. + - New module ``io`` to save and load inputs and data. + +- Maintenance: + + - Improved load time by lazy-loading matplotlib and some scipy submodules. + - Removed the file ``runtests.sh``; uses ``make`` instead. + + v2.1.x """""" @@ -15,7 +38,6 @@ v2.1.4: Squeeze **2022-07-20** - - The main modelling routines ``bipole``, ``dipole``, ``loop``, and ``analytical`` take a new keyword argument ``squeeze``, which is set to ``True`` by default. If true, the output is squeezed (status quo); if false, diff --git a/Makefile b/Makefile index 096ccf3b..d682041d 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ html-noplot: cd docs && make html-noplot html-clean: - cd docs && rm -rf api/empymod* gallery/*/ _build/ && make html + cd docs && rm -rf api/empymod* gallery/*/ _build/ my*.json my*.txt && make html preview: xdg-open docs/_build/html/index.html @@ -50,6 +50,6 @@ clean: rm -rf build/ dist/ .eggs/ empymod.egg-info/ empymod/version.py # build rm -rf */__pycache__/ */*/__pycache__/ # python cache rm -rf .coverage htmlcov/ .pytest_cache/ # tests and coverage - rm -rf docs/gallery/*/ docs/gallery/*.zip docs/_build/ docs/api/empymod* # docs + rm -rf docs/gallery/*/ docs/gallery/*.zip docs/_build/ docs/api/empymod* docs/my*.json docs/my*.txt # docs rm -rf matplotlibrc docs/savefig rm -rf filters/ examples/educational/filters/ diff --git a/docs/api/index.rst b/docs/api/index.rst index f3759876..0be3230a 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -20,6 +20,7 @@ API reference model transform utils + io fdesign tmtemod diff --git a/docs/api/io.rst b/docs/api/io.rst new file mode 100644 index 00000000..9865871c --- /dev/null +++ b/docs/api/io.rst @@ -0,0 +1,6 @@ +I/O +=== + +.. automodapi:: empymod.io + :no-inheritance-diagram: + :no-heading: diff --git a/docs/manual/index.rst b/docs/manual/index.rst index 77852fd1..04d1e880 100644 --- a/docs/manual/index.rst +++ b/docs/manual/index.rst @@ -17,6 +17,7 @@ User Guide installation usage info + iocli transforms credits citation diff --git a/docs/manual/iocli.rst b/docs/manual/iocli.rst new file mode 100644 index 00000000..1f9f2708 --- /dev/null +++ b/docs/manual/iocli.rst @@ -0,0 +1,181 @@ +I/O & CLI +######### + +Starting with ``empymod v2.2.0`` there are some basic saving and loading +routines together with a command line interface. This makes it possible to +model EM responses relatively straight forward from any other code. + + +.. _I/O: + +I/O +--- + +There are currently two saving and two loading routines, one each for inputs +and one for data. «Input» in this context means all the parameters a modelling +routine takes. The saving/loading routines are on one hand good for persistence +and reproducibility, but on the other hand also necessary for the command-line +interface (see section `CLI`_). + +``{save;load}_input`` +~~~~~~~~~~~~~~~~~~~~~ + +Let's look at a simple example. From the start, we collect the input parameters +in a dictionary instead of providing them directly to the function. + +.. ipython:: + + In [1]: import empymod + ...: import numpy as np + ...: + ...: # Define input parameters + ...: inp = { + ...: 'src': [[0, 0], [0, 1000], 250, [0, 90], 0], + ...: 'rec': [np.arange(1, 6)*2000, np.zeros(5), 300, 0, 0], + ...: 'freqtime': np.logspace(-1, 1, 3), + ...: 'depth': [0, 300, 1500, 1600], + ...: 'res': [2e14, 0.3, 1, 100, 1], + ...: } + ...: + ...: # Model it + ...: efield = empymod.bipole(**inp) + +We can now simply save this dictionary to disk via + + +.. ipython:: + + In [1]: empymod.io.save_input('myrun.json', inp) + +This will save the input parameters in the file ``myrun.json`` (you can provide +absolute or relative paths in addition to the file name). The file name must +currently include ``.json``, as it is the only backend implemented so far. The +json-file is a plain text file, so you can open it with your favourite editor. +Let's have a look at it: + +.. ipython:: + + In [1]: !cat myrun.json + +As you can see, it is basically the dictionary written as json. You can +therefore write your input parameter file with any program you want to. + +These input files can then be loaded to run the *exactly* same simulation +again. + +.. ipython:: + + In [1]: inp_loaded = empymod.io.load_input('myrun.json') + ...: efield2 = empymod.bipole(**inp_loaded) + ...: # Let's check if the result is indeed the same. + ...: print(f"Result is identical: {np.allclose(efield, efield2, atol=0)}") + + +``{save;load}_data`` +~~~~~~~~~~~~~~~~~~~~ + +These functions are to store or load data. Using the computation from above, +we can store the data in one of the following two ways, either as json or as +text file: + +.. ipython:: + + In [1]: empymod.io.save_data('mydata.json', efield) + ...: empymod.io.save_data('mydata.txt', efield, info='some data info') + + +Let's have a look at the text file: + +.. ipython:: + + In [1]: !cat mydata.txt + +First is a header with the date, the version of empymod with which it was +generated, and the shape and type of the data. This is followed by a line of +additional information, which you can define by providing a string to the input +parameter ``info``. The columns are the sources (two in this case), and in the +rows there are first all receivers for the first frequency (or time), then all +receivers for the second frequency (or time), and so on. + +The json file is very similar. Here we print just the first twenty lines as an +example: + +.. ipython:: + + In [1]: !head -n 20 mydata.json + +The main difference, besides the structure, is that the json-format does not +support complex data. It lists therefore first all real parts, and then all +imaginary parts. If you load it with another json reader it will therefore +have the dimension ``(2, nfreqtime, nrec, nsrc)``, where the 2 stands for real +and imaginary parts. (Only for frequency-domain data of course, not for +time-domain data.) + +To load it in Python simply use the corresponding functions: + +.. ipython:: + + In [1]: efield_json = empymod.io.load_data('mydata.json') + ...: efield_txt = empymod.io.load_data('mydata.txt') + ...: # Let's check they are the same as the original. + ...: print(f"Json-data: {np.allclose(efield, efield_json, atol=0)}") + ...: print(f"Txt-data : {np.allclose(efield, efield_txt, atol=0)}") + + +Caution +~~~~~~~ + +There is a limitation to the ``save_input``-functionality: The data *must* be +three dimensional, ``(nfreqtime, nrec, nsrc)``. Now, in the above example that +is the case, we have 3 frequencies, 5 receivers, and 2 sources. However, if any +of these three quantities would be 1, empymod would by default squeeze the +dimension. To avoid this, you have to pass the keyword ``squeeze=False`` to the +empymod-routine. + + +.. _CLI: + +CLI +--- + +The command-line interface is a simple wrapper for the main top-level modelling +routines :func:`empymod.model.bipole`, :func:`empymod.model.dipole`, +:func:`empymod.model.loop`, and :func:`empymod.model.analytical`. To call it +you must write a json-file containing all your input parameters as described in +the section `I/O`_. The basic syntax of the CLI is + +.. code-block:: console + + empymod + +You can find some description as well by running the regular help + +.. code-block:: console + + empymod --help + +As an example, to reproduce the example given above in the I/O-section, run + +.. code-block:: console + + empymod bipole myrun.json mydata.txt + +If you do not specify the output file (the last argument) the result will be +printed to the STDOUT. + + +Warning re runtime +~~~~~~~~~~~~~~~~~~ + +A warning with regards to runtime: The CLI has an overhead, as it has to load +Python and empymod with all its dependencies each time (which is cached if +running in Python). Currently, the overhead should be less than 1s, and it will +come down further with changes happening in the dependencies. For doing some +simple forward modelling that should not be significant. However, it would +potentially be a bad idea to use the CLI for a forward modelling kernel in an +inversion. The inversion would spend a significant if not most of its time +starting Python and importing empymod over and over again. + +Consult the following issue if you are interested in the overhead and its +status: `github.com/emsig/empymod/issues/162 +`_. diff --git a/empymod/__init__.py b/empymod/__init__.py index fef7100a..8e664a68 100644 --- a/empymod/__init__.py +++ b/empymod/__init__.py @@ -15,6 +15,7 @@ # the License. # Import all modules +from empymod import io from empymod import model from empymod import utils from empymod import kernel diff --git a/empymod/__main__.py b/empymod/__main__.py new file mode 100644 index 00000000..a6059697 --- /dev/null +++ b/empymod/__main__.py @@ -0,0 +1,123 @@ +""" +Entry point for the command-line interface (CLI). +""" +# Copyright 2016-2022 The emsig community. +# +# This file is part of empymod. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy +# of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + + +import sys +import argparse + +from empymod import io, model, utils + + +def main(args=None): + """Parsing command line inputs of CLI interface.""" + + # If not explicitly called, catch arguments + if args is None: + args = sys.argv[1:] + + # Start CLI-arg-parser and define arguments + parser = argparse.ArgumentParser( + description="3D electromagnetic modeller for 1D VTI media." + ) + + # arg: Modelling routine name + parser.add_argument( + "routine", + nargs="?", + type=str, + choices=['bipole', 'dipole', 'loop', 'analytical'], + help=("name of the modelling routine") + ) + + # arg: Input file name + parser.add_argument( + "input", + nargs="?", + type=str, + help="input file name" + ) + + # arg: Output file name + parser.add_argument( + "output", + nargs="?", + default=None, + type=str, + help="output file name; prints to STDOUT if not provided" + ) + + # arg: Report + parser.add_argument( + "--report", + action="store_true", + default=False, + help="show the empymod report and exit" + ) + + # arg: Version + parser.add_argument( + "--version", + action="store_true", + default=False, + help="show the empymod version and exit" + ) + + # Get command line arguments. + args_dict = vars(parser.parse_args(args)) + + # empymod version info. + if args_dict.pop('version'): # empymod version info. + print(f"empymod v{utils.__version__}") + + # empymod report. + elif args_dict.pop('report'): + print(utils.Report()) + + # Info if not at list routine and input provided. + elif len(sys.argv) < 3: + print(f"{parser.description}\n=> Type `empymod --help` for " + f"more info (empymod v{utils.__version__}).") + + # Actually compute. + else: + try: + run(args_dict) + except (AttributeError, TypeError, ValueError, FileNotFoundError) as e: + return e + + +def run(args_dict): + """Run empymod with provided arguments.""" + + # Run empymod, enforce ``squeeze=False``. + iname = args_dict['input'] + fct = args_dict['routine'] + out = getattr(model, fct)(**{**io.load_input(iname), 'squeeze': False}) + + # Store or print result. + outfile = args_dict.pop('output') + if outfile: + info = f"Generated with from input <{iname}>." + io.save_data(outfile, out, info=info) + else: + print(out) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/empymod/io.py b/empymod/io.py new file mode 100644 index 00000000..aaace861 --- /dev/null +++ b/empymod/io.py @@ -0,0 +1,276 @@ +""" +Utility functions for writing and reading inputs and data. +""" +# Copyright 2016-2022 The emsig community. +# +# This file is part of empymod. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy +# of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +import re +import os +import json +import time + +import numpy as np + + +from empymod import utils + +__all__ = ["save_input", "load_input", "save_data", "load_data"] + + +def save_input(fname, data, **kwargs): + """Save input dict to disk. + + Save the input provided to an empymod modelling routine on disk. + + + Parameters + ---------- + fname : str + File name with absolute or relative path including suffix, which + defines the used data format. Implemented is currently only ``.json``. + + data : dict + Dictionary containing the parameters with their corresponding values + for an empymod modelling routine. + + kwargs : optional + Passed through to the saving method. + + """ + + # Ensure fname is absolute. + fname = os.path.abspath(fname) + + # Save JSON + if fname.endswith(".json"): + + # For brevity yet readability, we create our custom formatted json, + # where each model parameter is on a new line. + out = "{" + for k, v in data.items(): + out += "\n " + out += json.dumps({k: v}, cls=_ComplexNumPyEncoder)[1:-1] + out += "," + out = out[:-1] + "\n}" + + # Write it to disk. + with open(fname, "w") as f: + f.write(out) + + # Unknown, throw error + else: + raise ValueError(f"Unknown extension '.{fname.split('.')[-1]}'.") + + +def load_input(fname): + """Load input from file. + + + Parameters + ---------- + fname : str + File name with absolute or relative path including suffix, which + defines the used data format. Implemented is currently only ``.json``. + + + Returns + ------- + data : dict + Dictionary containing the input that was stored in the file. + + """ + + # Ensure fname is absolute. + fname = os.path.abspath(fname) + + # Save JSON + if fname.endswith(".json"): + with open(fname, "r", encoding="utf-8") as f: + data = json.load(f) + + # Unknown, throw error + else: + raise ValueError(f"Unknown extension '.{fname.split('.')[-1]}'.") + + return data + + +def save_data(fname, data, **kwargs): + """Save results from empymod. + + + Parameters + ---------- + fname : str + File name with absolute or relative path including suffix, which + defines the used data format. Implemented are currently: + + - ``.txt``: Uses numpy to store data to a plain text file. + - ``.json``: Uses json to store inputs to a plain text file. + + data : ndarray + The output from an empymod modelling routine. + Note: You must set ``squeeze=False`` when calling the modelling + routine, to obtain a 3D array (in case any of ``src``, ``rec``, or + ``freqtime`` has only one entry). + + info : str, default: "" + Information (one-line) to put into the header. + + kwargs : optional + Passed through to the saving method. + + """ + # Ensure the right dimensionality. + if data.ndim != 3: + raise ValueError( + "Data must be 3D (nfreqtime, nrec, nsrc); provided dimensions: " + f"{data.ndim}. You can achieve this by providing " + "``squeeze=False`` to the modelling routine." + ) + + # Ensure fname is absolute. + fname = os.path.abspath(fname) + + # Collect meta information. + shape = data.shape + meta = { + "date": f"{time.strftime('%a %b %d %H:%M:%S %Y %Z')}", + "version": f"empymod v{utils.__version__}", + "shape": str(shape), + "dtype": str(data.dtype), + "info": kwargs.pop("info", "") + } + + # Save txt with NumPy. + if fname.endswith(".txt"): + + # Define format (depends if complex). + crfmt = "%+.18e" + if np.iscomplexobj(data): + crfmt += "%+.18ej" + + # Formatting and setting. + fmt = (shape[2]*(f"{crfmt}, "))[:-2] + settings = {"delimiter": ", ", "fmt": fmt, "encoding": "utf-8"} + + with open(fname, "w", encoding="utf-8") as f: + + # Write meta data. + for k, v in meta.items(): + f.write(f"# {k}:{' '+v if v else ''}\n") + + # write data. + np.savetxt(f, X=data.reshape((-1, shape[2])), header="data", + **{**settings, **kwargs}) + + # Save JSON + elif fname.endswith(".json"): + + with open(fname, "w", encoding="utf-8") as f: + json.dump({**meta, 'data': data}, f, cls=_ComplexNumPyEncoder, + **{"indent": 2, **kwargs}) + + # Unknown, throw error + else: + raise ValueError(f"Unknown extension '.{fname.split('.')[-1]}'.") + + +def load_data(fname): + """Load results from empymod stored with ``save_data``. + + + Parameters + ---------- + fname : str + File name with absolute or relative path including suffix, which + defines the used data format. Implemented are currently: + + - ``.txt``: Plain text file, loaded with np.loadtxt; + - ``.json``: JSON plain text file. + + + Returns + ------- + EM : EMArray, (nfreqtime, nrec, nsrc) + EM data. + + """ + + # Ensure fname is absolute. + fname = os.path.abspath(fname) + + # Load txt with NumPy. + if fname.endswith(".txt"): + + # Read header for shape and dtype. + meta = {} + with open(fname, "r") as f: + for line in f: + if "data" in line: + break + (key, val) = line.split(':', maxsplit=1) + meta[key.lstrip('# ')] = val.lstrip(' ').rstrip() + strshape = re.split(r'\(|\)', meta['shape'])[1] + shape = tuple(map(int, strshape.split(","))) + + args = {"delimiter": ",", "dtype": meta['dtype'], "encoding": "utf-8"} + data = np.loadtxt(fname, **args).reshape(shape) + + # Load JSON + elif fname.endswith(".json"): + + # Load data. + with open(fname, "r", encoding="utf-8") as f: + inpdat = json.load(f) + + # If complex, re-create complex data. + data = np.array(inpdat['data']) + if 'complex' in inpdat['dtype']: + data = data[0, ...] + 1j*data[1, ...] + + # Unknown, throw error + else: + raise ValueError(f"Unknown extension '.{fname.split('.')[-1]}'.") + + return utils.EMArray(data) + + +class _ComplexNumPyEncoder(json.JSONEncoder): + """Custom json-encoder for NumPy, including complex data.""" + + def default(self, obj): + """Check if complex or NumPy, else pass on.""" + + # If complex, stack [real, imag]. + if np.iscomplexobj(obj): + obj = np.stack([np.asarray(obj).real, np.asarray(obj).imag]) + + # Convert NumPy integers + if isinstance(obj, np.integer): + return int(obj) + # Convert NumPy floats + if isinstance(obj, np.floating): + return float(obj) + # Convert NumPy booleans + if isinstance(obj, np.bool_): + return bool(obj) + # Convert NumPy arrays (includes complex) + if isinstance(obj, np.ndarray): + return obj.tolist() + + # Let the base class default method raise the TypeError. + return json.JSONEncoder.default(self, obj) diff --git a/requirements-dev.txt b/requirements-dev.txt index 85aeab68..4ce66114 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,3 +27,4 @@ coveralls pytest_cov pytest_mpl pytest_flake8 +pytest-console-scripts diff --git a/setup.py b/setup.py index 07e024a7..2ef0ee53 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,11 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", ], + entry_points={ + "console_scripts": [ + "empymod=empymod.__main__:main", + ], + }, python_requires=">=3.7", install_requires=[ "scipy>=1.4", diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..e8cddd16 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,165 @@ +import os +import pytest +import numpy as np +from os.path import join +from numpy.testing import assert_allclose +from contextlib import ContextDecorator + +import empymod +from empymod.__main__ import run + + +class disable_numba(ContextDecorator): + """Context decorator to disable-enable JIT and remove log file.""" + def __enter__(self): + os.environ["NUMBA_DISABLE_JIT"] = "1" + return self + + def __exit__(self, *exc): + os.environ["NUMBA_DISABLE_JIT"] = "0" + return False + + +@disable_numba() +@pytest.mark.script_launch_mode('subprocess') +def test_main(script_runner): + + # help + for inp in ['--help', '-h']: + ret = script_runner.run('empymod', inp) + assert ret.success + assert "3D electromagnetic modeller for 1D VTI media" in ret.stdout + + # info + ret = script_runner.run('empymod') + assert ret.success + assert "3D electromagnetic modeller for 1D VTI media." in ret.stdout + assert "empymod v" in ret.stdout + + # report + ret = script_runner.run('empymod', '--report') + assert ret.success + # Exclude time to avoid errors. + # Exclude empymod-version (after 300), because if run locally without + # having empymod installed it will be "unknown" for the __main__ one. + assert empymod.utils.Report().__repr__()[115:300] in ret.stdout + + # version -- VIA empymod/__main__.py by calling the folder empymod. + ret = script_runner.run('python', 'empymod', '--version') + assert ret.success + assert "empymod v" in ret.stdout + + # Wrong function -- VIA empymod/__main__.py by calling the file. + ret = script_runner.run( + 'python', join('empymod', '__main__.py'), 'wrong') + assert not ret.success + assert "error: argument routine: invalid choice: 'wrong'" in ret.stderr + + # try to run + ret = script_runner.run( + 'empymod', 'bipole', 'test.json', 'output.txt') + assert not ret.success + assert "No such file or directory" in ret.stderr + + +class TestRun: + + def test_bipole_txt(self, tmpdir): + + inp = { + 'src': [0, 0, 0, 0, 0], + 'rec': [100, 50, 10, 0, 0], + 'depth': [-20, 20], + 'res': [2e14, 1, 100], + 'freqtime': 0.01, + 'htarg': {'dlf': 'wer_201_2018', 'pts_per_dec': -1}, + 'msrc': True, + 'mrec': True, + 'signal': None, + 'strength': np.pi, + 'srcpts': 5, + } + empymod.io.save_input(join(tmpdir, 't.json'), inp) + + args_dict = { + 'routine': 'bipole', + 'input': join(tmpdir, 't.json'), + 'output': join(tmpdir, 'out.txt') + } + run(args_dict) + out = empymod.io.load_data(join(tmpdir, 'out.txt')) + assert_allclose(out, empymod.bipole(**inp)) + + def test_dipole_stdout(self, tmpdir, capsys): + + inp = { + 'src': [0, 0, 0], + 'rec': [100, 50, 10], + 'depth': [-20, 20], + 'res': [2e14, 1, 100], + 'ab': 12, + 'freqtime': 10, + 'verb': 1, + } + empymod.io.save_input(join(tmpdir, 't.json'), inp) + + args_dict = { + 'routine': 'dipole', + 'input': join(tmpdir, 't.json'), + 'output': None + } + _, _ = capsys.readouterr() + run(args_dict) + out, _ = capsys.readouterr() + out = complex(out.strip().strip("[").strip("]")) + assert_allclose(out, empymod.dipole(**inp)) + + def test_loop_txt(self, tmpdir): + + inp = { + 'src': [0, 0, 0, 0, 0], + 'rec': [100, 50, 10, 0, 0], + 'depth': [-20, 20], + 'res': [2e14, 1, 100], + 'freqtime': 0.01, + } + empymod.io.save_input(join(tmpdir, 't.json'), inp) + + args_dict = { + 'routine': 'loop', + 'input': join(tmpdir, 't.json'), + 'output': join(tmpdir, 'out.txt') + } + run(args_dict) + out = empymod.io.load_data(join(tmpdir, 'out.txt')) + assert_allclose(out, empymod.loop(**inp)) + + def test_analytical_json(self, tmpdir): + + inp = { + 'src': [0, 0, 0], + 'rec': [100, 50, 10], + 'res': np.pi, + 'freqtime': np.pi, + } + empymod.io.save_input(join(tmpdir, 't.json'), inp) + + args_dict = { + 'routine': 'analytical', + 'input': join(tmpdir, 't.json'), + 'output': join(tmpdir, 'out.json') + } + run(args_dict) + out = empymod.io.load_data(join(tmpdir, 'out.json')) + assert_allclose(out, empymod.analytical(**inp)) + + def test_failure(self, tmpdir): + + args_dict = { + 'routine': 'bipole', + 'input': join(tmpdir, 't.json'), + 'output': join(tmpdir, 'out.json') + } + + with pytest.raises(FileNotFoundError, match="t.json'"): + run(args_dict) diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..17967df1 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,127 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose + +import empymod +from empymod import io + + +class TestSaveLoadInput: + + def test_basic(self, tmpdir): + + inp = { + "src": [[0, 0], [0, 1000], -250, 45, 0], + "rec": [[1000, 2000, 3000], [0, 0, 0], -300, 0, 0], + "depth": [0, 300, 1000, 1200], + "res": [2e14, 0.3, 1, 50, 1], + "freqtime": [0.1, 1, 10, 100], + "signal": None, + "msrc": True, + "htarg": {"pts_per_dec": -1}, # <= in json forbidden. + } + + io.save_input(tmpdir+'/test.json', data=inp) + out = io.load_input(tmpdir+'/test.json') + + # Won't work with, e.g., np-arrays + # (the comparison; {save;load}_input does work). + assert inp == out + + # Dummy check by comparing the produced result from the two inputs. + assert_allclose(empymod.bipole(**inp), empymod.bipole(**out), 0, 0) + + def test_errors(self, tmpdir): + + with pytest.raises(ValueError, match="Unknown extension '.abc'"): + io.save_input(tmpdir+'/test.abc', data=1) + + with pytest.raises(ValueError, match="Unknown extension '.abc'"): + io.load_input(tmpdir+'/test.abc') + + +class TestSaveLoadData: + + inp = { + 'src': ([0, 111, 1111], [0, 0, 0], 250), + 'rec': [np.arange(1, 8)*1000, np.zeros(7), 300], + 'depth': [0, 300], + 'res': [2e14, 0.3, 1], + 'htarg': {'pts_per_dec': -1}, + 'verb': 1, + } + + @pytest.mark.parametrize( + "extra", [{'signal': None, 'freqtime': [0.1, 1.0]}, + {'signal': 0, 'freqtime': [1.0, 5.0]}, + {'signal': None, 'freqtime': [-0.1, -1.0]}] + ) + def test_basic(self, tmpdir, extra): + + # Compute + orig = empymod.dipole(**self.inp, freqtime=[0.1, 1, 10, 100]) + + # Save + io.save_data(tmpdir+'test.txt', orig, info='Additional info') + io.save_data(tmpdir+'test.json', orig, info='Additional info') + + # Load + orig_txt = io.load_data(tmpdir+'test.txt') + orig_json = io.load_data(tmpdir+'test.json') + + # Compare numbers + assert_allclose(orig, orig_txt) + assert_allclose(orig, orig_json) + + # Ensure some header things + + for ending in ['txt', 'json']: + with open(tmpdir+'test.'+ending, 'r') as f: + text = f.read() + + assert 'date' in text + assert 'empymod v' in text + assert 'shape' in text + assert '(4, 7, 3)' in text + assert str(orig.dtype) in text + assert 'Additional info' in text + + def test_text(self, tmpdir): + + # Compute + orig = empymod.dipole(**self.inp, freqtime=[0.1, 1, 10, 100]) + + # Save + io.save_data(tmpdir+'test.txt', orig) + io.save_data(tmpdir+'test.json', orig) + + def test_errors(self, tmpdir): + + with pytest.raises(ValueError, match="must be 3D"): + io.save_data(tmpdir+'/test.json', data=np.ones((1, 1))) + + with pytest.raises(ValueError, match="Unknown extension '.abc'"): + io.save_data(tmpdir+'/test.abc', data=np.ones((1, 1, 1))) + + with pytest.raises(ValueError, match="Unknown extension '.abc'"): + io.load_data(tmpdir+'/test.abc') + + +def test_ComplexNumPyEncoder(): + + test = io._ComplexNumPyEncoder() + + # NumPy types + assert type(test.default(np.int_(1))) is int + assert type(test.default(np.float_(1))) is float + assert type(test.default(np.bool_(1))) is bool + assert type(test.default(np.array([[1., 1.], [1., 1.]]))) is list + + # Complex values + cplx = test.default(np.array([[[1+1j]]])) + assert type(cplx) is list + assert type(cplx[0][0][0][0]) is float + + # Error + with pytest.raises(TypeError, match="Object of type module"): + test.default(io)