Skip to content

Commit

Permalink
Merge pull request #67 from oceanmodeling/feature/enforce_input_version
Browse files Browse the repository at this point in the history
Check and enforce input version
  • Loading branch information
SorooshMani-NOAA authored Aug 21, 2024
2 parents 1efecce + edb27e5 commit 9edaa49
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 4 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: tests

on:
push:
branches:
- main
paths:
- '**.py'
- '.github/workflows/tests.yml'
- 'pyproject.toml'
pull_request:
branches:
- main

jobs:
test:
name: test
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ ubuntu-latest ]
python-version: [ '3.9', '3.10', '3.11' ]
steps:
- name: clone repository
uses: actions/checkout@v4
- name: conda virtual environment
uses: mamba-org/setup-micromamba@v1
with:
init-shell: bash
environment-file: environment.yml
- name: install the package
run: pip install ".[dev]"
shell: micromamba-shell {0}
- name: run tests
run: pytest
shell: micromamba-shell {0}
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ description = "A set of scripts to generate probabilistic storm surge results!"

license = {file = "LICENSE"}

requires-python = ">= 3.8, < 3.12"
requires-python = ">= 3.9, < 3.12"

dependencies = [
"cartopy",
Expand All @@ -45,6 +45,7 @@ dependencies = [
"numpy",
"numba",
"ocsmesh==1.5.3",
"packaging",
"pandas",
"pyarrow",
"pygeos",
Expand All @@ -65,6 +66,11 @@ dependencies = [
"xarray",
]

[project.optional-dependencies]
dev = [
"pytest"
]

[tool.setuptools_scm]
version_file = "stormworkflow/_version.py"

Expand Down
63 changes: 61 additions & 2 deletions stormworkflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,74 @@
import logging
import os
import shlex
import warnings
from importlib.resources import files
from argparse import ArgumentParser
from pathlib import Path

import stormworkflow
import yaml
from packaging.version import Version
try:
from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
from yaml import Loader, Dumper

import stormworkflow

_logger = logging.getLogger(__file__)

CUR_INPUT_VER = Version('0.0.2')


def _handle_input_v0_0_1_to_v0_0_2(inout_conf):

ver = Version(inout_conf['input_version'])

# Only update config if specified version matches the assumed one
if ver != Version('0.0.1'):
return ver


_logger.info(
"Adding perturbation variables for persistent RMW perturbation"
)
inout_conf['perturb_vars'] = [
'cross_track',
'along_track',
'radius_of_maximum_winds_persistent',
'max_sustained_wind_speed',
]

return Version('0.0.2')


def handle_input_version(inout_conf):

if 'input_version' not in inout_conf:
ver = CUR_INPUT_VER
warnings.warn(
f"`input_version` is NOT specified in `input.yaml`; assuming {ver}"
)
inout_conf['input_version'] = str(ver)
return

ver = Version(inout_conf['input_version'])

if ver > CUR_INPUT_VER:
raise ValueError(
f"Input version not supported! Max version supported is {CUR_INPUT_VER}"
)

ver = _handle_input_v0_0_1_to_v0_0_2(inout_conf)

if ver != CUR_INPUT_VER:
raise ValueError(
f"Could NOT update input to the latest version! Updated to {ver}"
)

inout_conf['input_version'] = str(ver)


def main():

parser = ArgumentParser()
Expand All @@ -28,12 +82,17 @@ def main():

infile = args.configuration
if infile is None:
_logger.warn('No input configuration provided, using reference file!')
warnings.warn(
'No input configuration provided, using reference file!'
)
infile = refs.joinpath('input.yaml')

with open(infile, 'r') as yfile:
conf = yaml.load(yfile, Loader=Loader)

handle_input_version(conf)
# TODO: Write out the updated config as a yaml file

wf = scripts.joinpath('workflow.sh')

run_env = os.environ.copy()
Expand Down
5 changes: 4 additions & 1 deletion stormworkflow/scripts/workflow.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ function init {
done

logfile=$run_dir/versions.info
version $logfile stormworkflow
version $logfile stormevents
version $logfile ensembleperturbation
version $logfile coupledmodeldriver
version $logfile pyschism
version $logfile ocsmesh
echo "SCHISM: see solver.version each outputs dir" >> $logfile

cp $input_file $run_dir/input.yaml
cp $input_file $run_dir/input_asis.yaml

echo $run_dir
}
Expand Down
40 changes: 40 additions & 0 deletions tests/data/refs/input_v0.0.1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
---
input_version: 0.0.1

storm: "florence"
year: 2018
suffix: ""
subset_mesh: 1
hr_prelandfall: -1
past_forecast: 1
hydrology: 0
use_wwm: 0
pahm_model: "gahm"
num_perturb: 2
sample_rule: "korobov"
spinup_exec: "pschism_PAHM_TVD-VL"
hotstart_exec: "pschism_PAHM_TVD-VL"

hpc_solver_nnodes: 3
hpc_solver_ntasks: 108
hpc_account: ""
hpc_partition: ""

RUN_OUT: ""
L_NWM_DATASET: ""
L_TPXO_DATASET: ""
L_LEADTIMES_DATASET: ""
L_TRACK_DIR: ""
L_DEM_HI: ""
L_DEM_LO: ""
L_MESH_HI: ""
L_MESH_LO: ""
L_SHP_DIR: ""

TMPDIR: "/tmp"
PATH_APPEND: ""

L_SOLVE_MODULES:
- "intel/2022.1.2"
- "impi/2022.1.2"
- "netcdf"
46 changes: 46 additions & 0 deletions tests/data/refs/input_v0.0.2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
---
input_version: 0.0.2

storm: "florence"
year: 2018
suffix: ""
subset_mesh: 1
hr_prelandfall: -1
past_forecast: 1
hydrology: 0
use_wwm: 0
pahm_model: "gahm"
num_perturb: 2
sample_rule: "korobov"
spinup_exec: "pschism_PAHM_TVD-VL"
hotstart_exec: "pschism_PAHM_TVD-VL"
perturb_vars:
- 'cross_track'
- 'along_track'
# - 'radius_of_maximum_winds'
- 'radius_of_maximum_winds_persistent'
- 'max_sustained_wind_speed'

hpc_solver_nnodes: 3
hpc_solver_ntasks: 108
hpc_account: ""
hpc_partition: ""

RUN_OUT: ""
L_NWM_DATASET: ""
L_TPXO_DATASET: ""
L_LEADTIMES_DATASET: ""
L_TRACK_DIR: ""
L_DEM_HI: ""
L_DEM_LO: ""
L_MESH_HI: ""
L_MESH_LO: ""
L_SHP_DIR: ""

TMPDIR: "/tmp"
PATH_APPEND: ""

L_SOLVE_MODULES:
- "intel/2022.1.2"
- "impi/2022.1.2"
- "netcdf"
67 changes: 67 additions & 0 deletions tests/test_input_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from copy import deepcopy
from importlib.resources import files

import pytest
import yaml
from packaging.version import Version
from yaml import Loader, Dumper

from stormworkflow.main import handle_input_version, CUR_INPUT_VER


refs = files('tests.data.refs')
input_v0_0_1 = refs.joinpath('input_v0.0.1.yaml')
input_v0_0_2 = refs.joinpath('input_v0.0.2.yaml')


def read_conf(infile):
with open(infile, 'r') as yfile:
conf = yaml.load(yfile, Loader=Loader)
return conf


@pytest.fixture
def conf_v0_0_1():
return read_conf(input_v0_0_1)


@pytest.fixture
def conf_v0_0_2():
return read_conf(input_v0_0_2)


@pytest.fixture
def conf_latest(conf_v0_0_2):
return conf_v0_0_2


def test_no_version_specified(conf_latest):
conf_latest.pop('input_version')
with pytest.warns(UserWarning):
handle_input_version(conf_latest)

assert conf_latest['input_version'] == str(CUR_INPUT_VER)


def test_invalid_version_specified(conf_latest):

invalid_1 = deepcopy(conf_latest)
invalid_1['input_version'] = (
f'{CUR_INPUT_VER.major}.{CUR_INPUT_VER.minor}.{CUR_INPUT_VER.micro + 1}'
)
with pytest.raises(ValueError) as e:
handle_input_version(invalid_1)

assert "max" in str(e.value).lower()


invalid_2 = deepcopy(conf_latest)
invalid_2['input_version'] = 'a.b.c'
with pytest.raises(ValueError) as e:
handle_input_version(invalid_2)
assert "invalid version" in str(e.value).lower()


def test_v0_0_1_to_v0_0_2(conf_v0_0_1, conf_v0_0_2):
handle_input_version(conf_v0_0_1)
assert conf_v0_0_2 == conf_v0_0_1

0 comments on commit 9edaa49

Please sign in to comment.