Skip to content

Commit

Permalink
feat: add masking and other improvements (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
benkozi authored Jan 22, 2025
1 parent 95553c6 commit 8ba2a64
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 58 deletions.
11 changes: 8 additions & 3 deletions script/hera/run-smoke-dust-regrid.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
#!/usr/bin/env bash

set -e

RUNDIR=smoke-dust-fixed-files
CONDAENV=~/l/scratch/miniconda/envs/regrid-wrapper

cd ~/l/scratch/sandbox/regrid-wrapper || exit
git pull
rm -rf ${RUNDIR}
rm -rf ${RUNDIR} || echo "run directory does not exist"

export PATH=${CONDAENV}/bin:${PATH}
export PYTHONPATH=$(pwd -LP)/src
export REGRID_WRAPPER_LOG_DIR=.
export ESMFMKFILE=${CONDAENV}/lib/esmf.mk

~/l/scratch/miniconda/envs/regrid-wrapper/bin/python ./src/regrid_wrapper/hydra/task_prep.py || exit
python ./src/regrid_wrapper/hydra/task_prep.py || exit

cd ${RUNDIR}/logs || exit
sbatch ../main-job.sh || exit
squeue -u Benjamin.Koziol -i 5
squeue -u Benjamin.Koziol -i 30
8 changes: 6 additions & 2 deletions src/regrid_wrapper/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from typing import Any


def ncdump(path: Path) -> Any:
ret = subprocess.check_output(["ncdump", "-h", str(path)])
def ncdump(path: Path, header_only: bool = True) -> Any:
    """Run the ``ncdump`` command-line tool on *path* and echo its output.

    Args:
        path: File handed to ``ncdump`` as its final argument.
        header_only: When true (the default), ``-h`` is passed to ``ncdump``.

    Returns:
        The raw bytes captured from the ``ncdump`` process.

    Raises:
        subprocess.CalledProcessError: If ``ncdump`` exits with a non-zero
            status (raised by ``subprocess.check_output``).
    """
    flags = ["-h"] if header_only else []
    command = ["ncdump", *flags, str(path)]
    output = subprocess.check_output(command)
    # Echo the decoded dump immediately so it interleaves correctly with logs.
    print(output.decode(), flush=True)
    return output
5 changes: 5 additions & 0 deletions src/regrid_wrapper/concrete/rave_to_rrfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ def run(self) -> None:
unmapped_action=esmpy.UnmappedAction.IGNORE,
ignore_degenerate=True,
)

# Uncomment to test read back from file
# _ = esmpy.RegridFromFile(
# src_field, dst_field, str(self._spec.output_weight_filename)
# )
65 changes: 52 additions & 13 deletions src/regrid_wrapper/concrete/rrfs_dust_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple

import esmpy
import numpy as np

from pydantic import BaseModel, ConfigDict

Expand Down Expand Up @@ -34,6 +35,7 @@ def run(self) -> None:
assert isinstance(self._spec, GenerateWeightFileAndRegridFields)

src_gwrap = self._create_source_grid_wrapper_()
src_gwrap.value.add_item(esmpy.GridItem.MASK)
dst_gwrap = self._create_destination_grid_wrapper_()

archetype_field_name = RRFS_DUST_DATA_ENV.fields[0]
Expand Down Expand Up @@ -68,20 +70,7 @@ def run(self) -> None:
dst_dim.name = src_dim.name
dst_gwrap_output.fill_nc_variables(self._spec.output_filename)

dst_fwrap = self._create_field_wrapper_(
archetype_field_name, self._spec.output_filename, dst_gwrap_output
)

self._logger.info("starting weight file generation")
regrid_method = esmpy.RegridMethod.BILINEAR
regridder = esmpy.Regrid(
src_fwrap.value,
dst_fwrap.value,
regrid_method=regrid_method,
filename=str(self._spec.output_weight_filename),
unmapped_action=esmpy.UnmappedAction.ERROR,
)

for field_to_regrid in RRFS_DUST_DATA_ENV.fields:
self._logger.info(f"regridding field: {field_to_regrid}")
src_fwrap_regrid = self._create_field_wrapper_(
Expand All @@ -90,13 +79,63 @@ def run(self) -> None:
dst_fwrap_regrid = self._create_field_wrapper_(
field_to_regrid, self._spec.output_filename, dst_gwrap_output
)

self._logger.info("updating grid mask")
self._update_grid_mask_and_dst_field_(
src_gwrap, src_fwrap_regrid, dst_fwrap_regrid, field_to_regrid
)

self._logger.info("starting weight file generation")
regridder = esmpy.Regrid(
src_fwrap_regrid.value,
dst_fwrap_regrid.value,
regrid_method=regrid_method,
# filename=str(self._spec.output_weight_filename), # Disable since weight files differ per-variable
unmapped_action=esmpy.UnmappedAction.IGNORE,
src_mask_values=[0],
)

regridder(
src_fwrap_regrid.value,
dst_fwrap_regrid.value,
zero_region=esmpy.Region.SELECT,
)
dst_fwrap_regrid.fill_nc_variable(self._spec.output_filename)

def _update_grid_mask_and_dst_field_(
    self,
    gwrap: GridWrapper,
    src_fwrap: FieldWrapper,
    dst_fwrap: FieldWrapper,
    varname: str,
) -> None:
    """Reset the grid mask for *varname* and pre-fill the destination field.

    The mask item of ``gwrap`` is first set to 1 everywhere. For variables
    that have a sentinel value, mask cells where the source field equals the
    sentinel are set to 0 and the destination data is filled with the same
    sentinel. Variables without a sentinel leave the mask fully set to 1 and
    the destination data untouched.

    Args:
        gwrap: Grid wrapper whose ``esmpy.GridItem.MASK`` item is updated
            in place.
        src_fwrap: Source field wrapper; only index 0 of its last axis is
            inspected.
        dst_fwrap: Destination field wrapper whose data may be filled in place.
        varname: Name of the variable being regridded.

    Raises:
        NotImplementedError: If *varname* is not one of the known variables.
    """
    mask = gwrap.value.get_item(esmpy.GridItem.MASK)
    mask.fill(1)  # 1 = unmasked

    # Assume that the mask is constant through time
    src_field_data = src_fwrap.value.data[:, :, 0]
    dst_values = dst_fwrap.value.data

    # NOTE: the local names below are part of the emitted debug text.
    self._logger.debug(f"{mask.shape=}")
    self._logger.debug(f"{src_field_data.shape=}")

    # Per-variable sentinel value in the source data; None means the variable
    # needs no masking and no destination pre-fill.
    sentinel_by_variable = {
        "uthr": 999,
        "clay": -1,
        "ssm": None,
        "sand": -1,
        "rdrag": None,
    }
    if varname not in sentinel_by_variable:
        raise NotImplementedError

    sentinel = sentinel_by_variable[varname]
    if sentinel is None:
        return
    mask[np.where(src_field_data == sentinel)] = 0
    dst_values.fill(sentinel)

@staticmethod
def _create_field_wrapper_(
field_name: str, path: Path, gwrap: GridWrapper
Expand Down
71 changes: 42 additions & 29 deletions src/regrid_wrapper/esmpy/field_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from contextlib import contextmanager
from pathlib import Path
from typing import Tuple, Literal, Dict, Sequence, Any
from typing import Tuple, Literal, Dict, Sequence, Any, Union, List

import numpy as np
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
Expand Down Expand Up @@ -37,7 +37,10 @@ def open_nc(
ds.close()


def copy_nc_attrs(src: nc.Dataset | nc.Variable, dst: nc.Dataset | nc.Variable) -> None:
HasNcAttrsType = Union[nc.Dataset, nc.Variable]


def copy_nc_attrs(src: HasNcAttrsType, dst: HasNcAttrsType) -> None:
for attr in src.ncattrs():
if attr.startswith("_"):
continue
Expand Down Expand Up @@ -157,6 +160,8 @@ class AbstractWrapper(abc.ABC, BaseModel):


class GridSpec(BaseModel):
model_config = ConfigDict(frozen=True)

x_center: str
y_center: str
x_dim: NameListType
Expand Down Expand Up @@ -212,27 +217,29 @@ def create_grid_dims(
x_dim, y_dim = self.x_corner_dim, self.y_corner_dim
else:
raise NotImplementedError(staggerloc)
dims = DimensionCollection(
value=[
Dimension(
name=x_dim,
size=get_nc_dimension(ds, x_dim).size,
lower=grid.lower_bounds[staggerloc][self.x_index],
upper=grid.upper_bounds[staggerloc][self.x_index],
staggerloc=staggerloc,
coordinate_type="x",
),
Dimension(
name=y_dim,
size=get_nc_dimension(ds, y_dim).size,
lower=grid.lower_bounds[staggerloc][self.y_index],
upper=grid.upper_bounds[staggerloc][self.y_index],
staggerloc=staggerloc,
coordinate_type="y",
),
]
x_dimobj = Dimension(
name=x_dim,
size=get_nc_dimension(ds, x_dim).size,
lower=grid.lower_bounds[staggerloc][self.x_index],
upper=grid.upper_bounds[staggerloc][self.x_index],
staggerloc=staggerloc,
coordinate_type="x",
)
return dims
y_dimobj = Dimension(
name=y_dim,
size=get_nc_dimension(ds, y_dim).size,
lower=grid.lower_bounds[staggerloc][self.y_index],
upper=grid.upper_bounds[staggerloc][self.y_index],
staggerloc=staggerloc,
coordinate_type="y",
)
if self.x_index == 0:
value = [x_dimobj, y_dimobj]
elif self.x_index == 1:
value = [y_dimobj, x_dimobj]
else:
raise NotImplementedError(self.x_index, self.y_index)
return DimensionCollection(value=value)


class GridWrapper(AbstractWrapper):
Expand Down Expand Up @@ -261,12 +268,7 @@ class NcToGrid(BaseModel):

def create_grid_wrapper(self) -> GridWrapper:
with open_nc(self.path, "r") as ds:
grid_shape = np.array(
[
get_nc_dimension(ds, self.spec.x_dim).size,
get_nc_dimension(ds, self.spec.y_dim).size,
]
)
grid_shape = self._create_grid_shape_(ds)
staggerloc = esmpy.StaggerLoc.CENTER
grid = esmpy.Grid(
grid_shape,
Expand All @@ -293,6 +295,17 @@ def create_grid_wrapper(self) -> GridWrapper:
)
return gwrap

def _create_grid_shape_(self, ds: nc.Dataset) -> np.ndarray:
    """Build the two-element grid shape, ordered by the spec's x-axis index.

    Args:
        ds: Open dataset from which the x and y dimension sizes are looked up.

    Returns:
        ``[x_size, y_size]`` when ``spec.x_index`` is 0, or
        ``[y_size, x_size]`` when it is 1, as a NumPy array.

    Raises:
        NotImplementedError: If ``spec.x_index`` is neither 0 nor 1.
    """
    spec = self.spec
    x_size = get_nc_dimension(ds, spec.x_dim).size
    y_size = get_nc_dimension(ds, spec.y_dim).size

    x_index = spec.x_index
    if x_index != 0 and x_index != 1:
        raise NotImplementedError(self.spec.x_index, self.spec.y_index)

    ordered_sizes = (x_size, y_size) if x_index == 0 else (y_size, x_size)
    return np.array(ordered_sizes)

def _add_corner_coords_(
self, ds: nc.Dataset, grid: esmpy.Grid
) -> DimensionCollection:
Expand Down Expand Up @@ -334,7 +347,7 @@ def create_field_wrapper(self) -> FieldWrapper:
ndbounds = None
target_dims = self.gwrap.dims
else:
ndbounds = (get_nc_dimension(ds, self.dim_time).size,)
ndbounds = (len(get_nc_dimension(ds, self.dim_time)),)
time_dim = Dimension(
name=self.dim_time,
size=ndbounds[0],
Expand Down
1 change: 0 additions & 1 deletion src/regrid_wrapper/hydra/conf/smoke-dust-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@ source_definition:
grid: /scratch2/NAGAPE/epic/Ben.Koziol/output-data/RRFS_CONUS_25km.nc
nodes: 2
wall_time: "01:00:00"

5 changes: 4 additions & 1 deletion src/regrid_wrapper/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ class SmokeDustRegridConfig(BaseModel):
source_definition: SourceDefinition

def output_directory(self, target_grid: RrfsGridKey) -> PathType:
return self.root_output_directory / f"fix_smoke/{target_grid.value}"
return (
self.root_output_directory
/ f"fix_smoke/{target_grid.value.replace('KM', 'km')}"
)

@property
def log_directory(self) -> PathType:
Expand Down
4 changes: 2 additions & 2 deletions src/test/test_concrete/test_rave_to_rrfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_dev(bin_dir: Path, tmp_path_shared: Path) -> None:
src_path=bin_dir / "RAVE/grid_in.nc",
dst_path=bin_dir / "RRFS_CONUS_25km/ds_out_base.nc",
output_weight_filename=tmp_path_shared / "weights.nc",
esmpy_debug=True,
esmpy_debug=False,
name="tester",
)
op = RaveToRrfs(spec=spec)
Expand All @@ -44,7 +44,7 @@ def test(tmp_path_shared: Path) -> None:
src_path=src_grid,
dst_path=dst_grid,
output_weight_filename=weights,
esmpy_debug=True,
esmpy_debug=False,
name="tester",
)
op = RaveToRrfs(spec=spec)
Expand Down
4 changes: 2 additions & 2 deletions src/test/test_concrete/test_rrfs_dust_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ def test(tmp_path_shared: Path) -> None:
dst_path=dst_grid,
output_weight_filename=weights,
output_filename=dust_data,
esmpy_debug=True,
esmpy_debug=False,
name="dust-data",
fields=RRFS_DUST_DATA_ENV.fields,
)
op = RrfsDustData(spec=spec)
processor = RegridProcessor(operation=op)
processor.execute()

assert weights.exists()
assert not weights.exists()
assert dust_data.exists()

if COMM.rank == 0:
Expand Down
6 changes: 3 additions & 3 deletions src/test/test_concrete/test_rrfs_smoke_dust_veg_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_dev(bin_dir: Path, tmp_path_shared: Path) -> None:
output_weight_filename=tmp_path_shared / weight_filename,
output_filename=tmp_path_shared / "veg_map.nc",
fields=("emiss_factor",),
esmpy_debug=True,
esmpy_debug=False,
name=weight_filename,
)
op = RrfsSmokeDustVegetationMap(spec=spec)
Expand All @@ -52,7 +52,7 @@ def test(tmp_path_shared: Path) -> None:
output_weight_filename=weights,
output_filename=veg_map,
fields=("emiss_factor",),
esmpy_debug=True,
esmpy_debug=False,
name="veg_map-3km-to-25km",
)
op = RrfsSmokeDustVegetationMap(spec=spec)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_different_shapes(tmp_path_shared: Path) -> None:
output_weight_filename=weights,
output_filename=veg_map,
fields=("emiss_factor",),
esmpy_debug=True,
esmpy_debug=False,
name="veg_map-3km-to-25km",
esmpy_unmapped_action=esmpy.UnmappedAction.IGNORE,
)
Expand Down
Loading

0 comments on commit 8ba2a64

Please sign in to comment.