Skip to content

Commit

Permalink
Add test to check CLI --plot argument generates plots. Other small ch…
Browse files Browse the repository at this point in the history
…anges such as adding docstrings
  • Loading branch information
richypitman committed Dec 20, 2023
1 parent c79ff40 commit 1c6a913
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 28 deletions.
42 changes: 31 additions & 11 deletions pyscal/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,14 @@ def get_satnum_from_tag(string: str) -> int:

def get_plot_config_options(curve_type: str, **kwargs) -> dict:
"""
Get config data from plot config dictionary based on the curve type
Get config data from plot config dictionary based on the curve (model) type.
Args:
curve_type (str): _description_
curve_type (str): Name of the curve type. Allowed types are given in
the PLOT_CONFIG_OPTIONS dictionary
Returns:
dict: _description_
dict: Config parameters for the chosen model type
"""

config = PLOT_CONFIG_OPTIONS[curve_type].copy()
Expand Down Expand Up @@ -306,13 +307,16 @@ def save_figure(
plot_type: str,
outdir: str,
) -> None:
"""_summary_
"""
Save the provided figure.
Args:
fig (plt.Figure): Figure to be saved
satnum (int): SATNUM number
config (dict): Plot config
plot_type (str): Figure type. Allowed types are 'relperm' and 'pc'
outdir (str): Directory where the figure will be saved
"""

# Get curve name
Expand All @@ -333,18 +337,23 @@ def save_figure(
bbox_inches="tight",
)

print(f"Figure saved to {fout}.png")

# Clear figure so that it is empty for the next SATNUM's plot
fig.clear()


def wog_plotter(model: WaterOilGas, **kwargs) -> None:
"""_summary_
"""
Plot a WaterOilGas (WaterOil and GasOil) model.
For a WaterOilGas instance, the WaterOil and GasOil instances can be
accessed, then the "table" instance variable.
Args:
model (WaterOilGas): _description_
model (WaterOilGas): WaterOilGas instance
"""

outdir = kwargs["outdir"]
Expand Down Expand Up @@ -390,11 +399,13 @@ def wog_plotter(model: WaterOilGas, **kwargs) -> None:
def wo_plotter(model: WaterOil, **kwargs) -> None:
"""
Plot a WaterOil model.
For a WaterOil instance, the saturation table can be accessed using the
"table" instance variable.
Args:
model (WaterOil): _description_
model (WaterOil): WaterOil instance
"""
config = get_plot_config_options("WaterOil", **kwargs)
satnum = get_satnum_from_tag(model.tag)
Expand All @@ -419,11 +430,13 @@ def wo_plotter(model: WaterOil, **kwargs) -> None:
def go_plotter(model: GasOil, **kwargs) -> None:
"""
Plot a GasOil model.
For a GasOil instance, the saturation table can be accessed using the
"table" instance variable.
Args:
model (GasOil): _description_
model (GasOil): GasOil instance
"""

config = get_plot_config_options("GasOil", **kwargs)
Expand All @@ -450,9 +463,16 @@ def go_plotter(model: GasOil, **kwargs) -> None:


def gw_plotter(model: GasWater, **kwargs) -> None:
# For GasWater, the format is different, and an additional formatting step is
# required. Use the formatted table as an argument to the plotter function,
# instead of the "table" instance variable
"""
For GasWater, the format is different, and an additional formatting step is
required. Use the formatted table as an argument to the plotter function,
instead of the "table" instance variable
Args:
model (GasWater): GasWater instance
"""

table = format_gaswater_table(model)
config = get_plot_config_options("GasWater", **kwargs)
satnum = get_satnum_from_tag(model.tag)
Expand Down
1 change: 0 additions & 1 deletion pyscal/pyscalcli.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,3 @@ def pyscal_main(

if plot:
plotting.plotter(wog_list, plot_pc, plot_semilog, plot_outdir)
print(f"Plots saved in {plot_outdir}")
39 changes: 23 additions & 16 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Test the plotting module"""

from pathlib import Path

import matplotlib.pyplot as plt
Expand All @@ -7,6 +9,7 @@


def test_get_satnum_from_tag():
"""Check that the SATNUM number can be retrieved from the model tag"""
# Several PyscalLists of different model types to be checked
pyscal_lists = [
PyscalList(
Expand Down Expand Up @@ -48,19 +51,23 @@ def test_get_satnum_from_tag():


def test_plotter():
# Check if Exception is raised if a model type is not included. This is done
# to check that all models have been implemented in the plotting module.
"""Check if an Exception is raised if a model type is not included. This is done
to check that all models have been implemented in the plotting module."""

class DummyPyscalList:
# Can't use the actual PyscalList, as this will raise its own exception
# (DummyModel is not a pyscal object), so a dummy PyscalList is used
"""
Can't use the actual PyscalList, as this will raise its own exception
(DummyModel is not a pyscal object), so a dummy PyscalList is used
#If the PyscalList.pyscal_list instance variable name changes, this
will still pass..."""

# If the PyscalList.pyscal_list instance variable name changes, this
# will still pass...
def __init__(self, models: list) -> None:
self.pyscal_list = models

class DummyModel:
"""Dummy model"""

def __init__(self, tag: str) -> None:
self.tag = tag

Expand All @@ -77,16 +84,16 @@ def __init__(self, tag: str) -> None:


def test_pyscal_list_attr():
# Check that the PyscalList class has an pyscal_list instance variable.
# This is access by the plotting module to loop through models to plot.
"""Check that the PyscalList class has an pyscal_list instance variable.
This is accessed by the plotting module to loop through models to plot."""
assert (
hasattr(PyscalList(), "pyscal_list") is True
), "The PyscalList object should have a pyscal_list instance variable.\
This is accessed by the plotting module."


def test_plot_relperm():
# Test that a matplotlib.pyplot Figure instance is returned
"""Test that a matplotlib.pyplot Figure instance is returned"""
wateroil = WaterOil(swl=0.1, h=0.1)
wateroil.add_corey_water()
wateroil.add_corey_oil()
Expand All @@ -101,7 +108,7 @@ def test_plot_relperm():


def test_plot_pc():
# Test that a matplotlib.pyplot Figure instance is returned
"""Test that a matplotlib.pyplot Figure instance is returned"""
wateroil = WaterOil(swl=0.1, h=0.1)
wateroil.add_corey_water()
wateroil.add_corey_oil()
Expand All @@ -117,7 +124,7 @@ def test_plot_pc():


def test_wog_plotter(tmpdir):
# Test if relative permeability figures are created by the plotter function
"""Test that relative permeability figures are created by the plotter function"""
wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1")
wateroil.add_corey_water()
wateroil.add_corey_oil()
Expand Down Expand Up @@ -145,7 +152,7 @@ def test_wog_plotter(tmpdir):


def test_wo_plotter(tmpdir):
# Test if relative permeability figures are created by the plotter function
"""Test that relative permeability figures are created by the plotter function"""
wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1")
wateroil.add_corey_water()
wateroil.add_corey_oil()
Expand All @@ -163,7 +170,7 @@ def test_wo_plotter(tmpdir):


def test_wo_plotter_relperm_only(tmpdir):
# Test if relative permeability figures are created by the plotter function
"""Test that relative permeability figures are created by the plotter function"""
wateroil = WaterOil(swl=0.1, h=0.1, tag="SATNUM 1")
wateroil.add_corey_water()
wateroil.add_corey_oil()
Expand All @@ -183,7 +190,7 @@ def test_wo_plotter_relperm_only(tmpdir):


def test_go_plotter(tmpdir):
# Test if relative permeability figures are created by the plotter function
"""Test that relative permeability figures are created by the plotter function"""
gasoil = GasOil(swl=0.1, h=0.1, tag="SATNUM 1")
gasoil.add_corey_gas()
gasoil.add_corey_oil()
Expand All @@ -204,7 +211,7 @@ def test_go_plotter(tmpdir):


def test_gw_plotter(tmpdir):
# Test if relative permeability figures are created by the plotter function
"""Test that relative permeability figures are created by the plotter function"""
gaswater = GasWater(swl=0.1, h=0.1, tag="SATNUM 1")
gaswater.add_corey_water()
gaswater.add_corey_gas()
Expand All @@ -222,7 +229,7 @@ def test_gw_plotter(tmpdir):


def test_save_figure(tmpdir):
# Test that figure is saved
"""Test that figure is saved"""
fig = plt.Figure()

config = {"curves": "dummy", "suffix": ""}
Expand Down
41 changes: 41 additions & 0 deletions tests/test_pyscalcli.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,44 @@ def test_pyscal_main():

with pytest.raises(ValueError, match="Interpolation parameter provided"):
pyscalcli.pyscal_main(relperm_file, int_param_wo=-1, output=os.devnull)


def test_pyscalcli_plot(capsys, mocker, tmpdir):
"""Test that plots are created through the CLI. This is done by testing to
see if the print statements in the save_figure function are present in stdout"""
scalrec_file = Path(__file__).absolute().parent / "data/scal-pc-input-example.xlsx"

mocker.patch(
"sys.argv",
[
"pyscal",
str(scalrec_file),
"--int_param_wo",
"0",
"--output",
"-",
"--plot",
"--plot_pc",
"--plot_outdir",
str(tmpdir),
],
)

pyscalcli.main()

expected_plots = [
"krw_krow_SATNUM_1.png",
"krg_krog_SATNUM_1.png",
"krw_krow_SATNUM_2.png",
"krg_krog_SATNUM_2.png",
"krw_krow_SATNUM_3.png",
"krg_krog_SATNUM_3.png",
"pcow_SATNUM_1.png",
"pcow_SATNUM_2.png",
"pcow_SATNUM_3.png",
]

captured = capsys.readouterr()

for plot in expected_plots:
assert plot in captured.out

0 comments on commit 1c6a913

Please sign in to comment.