Skip to content

Commit

Permalink
Use tmp_path for storing plots when testing gaussian model example
Browse files Browse the repository at this point in the history
  • Loading branch information
gutzbenj committed Nov 9, 2023
1 parent 9574e13 commit a9ce842
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 10 additions & 7 deletions example/observations_station_gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
""" # Noqa:D205,D400
import logging
import os
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
Expand All @@ -30,6 +31,8 @@
DwdObservationResolution,
)

HERE = Path(__file__).parent

log = logging.getLogger()

try:
Expand Down Expand Up @@ -63,7 +66,7 @@ class ModelYearlyGaussians:
"""

def __init__(self, station_data: StationsResult):
def __init__(self, station_data: StationsResult, plot_path: Path):
self._station_data = station_data

result_values = station_data.values.all().df.drop_nulls()
Expand All @@ -81,7 +84,7 @@ def __init__(self, station_data: StationsResult):

log.info(f"Fit Result message: {out.result.message}")

self.plot_data_and_model(valid_data, out, savefig_to_file=True)
self.plot_data_and_model(valid_data, out, savefig_to_file=True, plot_path=plot_path)

def get_valid_data(self, result_values: pl.DataFrame) -> pl.DataFrame:
valid_data_lst = []
Expand Down Expand Up @@ -137,7 +140,7 @@ def model_pars_update(

return pars

def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file=True) -> None:
def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefig_to_file, plot_path: Path) -> None:
"""plots the data and the model"""
if savefig_to_file:
_ = plt.subplots(figsize=(12, 12))
Expand All @@ -153,21 +156,21 @@ def plot_data_and_model(self, valid_data: pl.DataFrame, out: ModelResult, savefi
if savefig_to_file:
number_of_years = valid_data.get_column("date").dt.year().n_unique()
filename = f"{self.__class__.__qualname__}_wetter_model_{number_of_years}"
plt.savefig(filename, dpi=300, bbox_inches="tight")
plt.savefig(plot_path / filename, dpi=300, bbox_inches="tight")
log.info("saved fig to file: " + filename)
if "PYTEST_CURRENT_TEST" not in os.environ:
plt.show()


def main():
def main(plot_path=HERE):
"""Run example."""
logging.basicConfig(level=logging.INFO)

station_data_one_year = station_example(start_date="2020-12-25", end_date="2022-01-01")
_ = ModelYearlyGaussians(station_data_one_year)
_ = ModelYearlyGaussians(station_data_one_year, plot_path=plot_path)

station_data_many_years = station_example(start_date="1995-12-25", end_date="2022-12-31")
_ = ModelYearlyGaussians(station_data_many_years)
_ = ModelYearlyGaussians(station_data_many_years, plot_path=plot_path)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/example/test_regular_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_pdbufr_examples():

@pytest.mark.skipif(IS_CI and IS_LINUX, reason="stalls on Mac/Windows in CI")
@pytest.mark.cflake
def test_gaussian_example():
def test_gaussian_example(tmp_path):
from example import observations_station_gaussian_model

assert observations_station_gaussian_model.main() is None
assert observations_station_gaussian_model.main(tmp_path) is None


# @pytest.mark.skipif(IS_CI, reason="radar examples not working in CI")
Expand Down

0 comments on commit a9ce842

Please sign in to comment.