Skip to content

Commit

Permalink
Cleaned up save_to_vtk and wrote basic safety tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mdmeeker committed Mar 1, 2024
1 parent e426988 commit 3f273ef
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 29 deletions.
4 changes: 2 additions & 2 deletions .github/new_windgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from drdmannturb.fluctuation_generation import (
plot_velocity_components, # utility function for plotting each velocity component in the field, not used in this example
format_wind_field
)
from drdmannturb.fluctuation_generation import (
GenerateFluctuationField,
Expand Down Expand Up @@ -73,9 +74,8 @@
fluctuation_field_drd += mean_profile
fluctuation_field_drd *= 40/63

wind_field_vtk = tuple([np.copy(fluctuation_field_drd[...,i], order='C') for i in range(3)])
wind_field_vtk = format_wind_field(fluctuation_field_drd)
cellData = {'grid': np.zeros_like(fluctuation_field_drd[...,0]), 'wind': wind_field_vtk}

FileName = f"dat/block_{nBlocks}"
imageToVTK(FileName, cellData = cellData, spacing=spacing)

Expand Down
102 changes: 77 additions & 25 deletions drdmannturb/fluctuation_generation/fluctuation_field_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def generate(
import warnings

warnings.warn(
"Fluctuation field has already been generated, additional blocks will be appended to existing field. If this is undesirable behavior, instantiate a new object."
"Fluctuation field has already been generated, additional blocks will be appended to existing field. If this is undesired behavior, you should instead create a new instance."
)

for _ in range(num_blocks):
Expand Down Expand Up @@ -400,7 +400,7 @@ def normalize(
"""
if not np.any(self.total_fluctuation):
raise ValueError(
"No fluctuation field has been generated, call the .generate() method first."
"No fluctuation field has been generated. Call the .generate() method first."
)

sd = np.sqrt(np.mean(self.total_fluctuation**2))
Expand All @@ -418,29 +418,6 @@ def normalize(

return self.total_fluctuation + mean_profile

def save_to_vtk(self, filepath: Union[str, Path] = "./"):
"""Saves generated fluctuation field in VTK format to specified filepath.
Parameters
----------
filepath : Union[str, Path]
Filepath to which to save generated fluctuation field.
"""
from pyevtk.hl import imageToVTK

spacing = tuple(self.grid_dimensions / (2.0**self.grid_levels + 1))

wind_field_vtk = tuple(
[np.copy(self.total_fluctuation[..., i], order="C") for i in range(3)]
)

cellData = {
"grid": np.zeros_like(self.total_fluctuation[..., 0]),
"wind": wind_field_vtk,
}

imageToVTK(filepath, cellData=cellData, spacing=spacing)

def evaluate_divergence(
self, spacing: Union[tuple, np.ndarray], field: Optional[np.ndarray] = None
) -> np.ndarray:
Expand Down Expand Up @@ -492,3 +469,78 @@ def evaluate_divergence(
return np.ufunc.reduce(
np.add, [np.gradient(field[..., i], spacing[i], axis=i) for i in range(3)]
)

def save_to_vtk(self, filename: str, filepath: Union[str, Path] = "./DRDMT_out") -> None:
"""Saves generated fluctuation field in VTK format to specified filepath.
Parameters
----------
filename : str
File name to write out with
filepath : Union[str, Path], optional
File path to which the generated fluctuation field VTK should be written, by default "./DRDMT_out"
Raises
------
ValueError
Thrown if provided filepath does not lead to a directory
ValueError
Thrown if provided filename is empty/invalid
"""
from pyevtk.hl import imageToVTK
from .wind_plot import format_wind_field

path : Path
if isinstance(filepath, str):
path = Path(filepath)
else:
path = filepath

if not path.is_dir():
raise ValueError("Provided value for filepath does not lead to a directory")

if not (len(filename) > 0):
raise ValueError("Was passed an empty string as a file name")

spacing = tuple(self.grid_dimensions / (2.0**self.grid_levels + 1))
wind_field_vtk = format_wind_field(self.total_fluctuation)
cellData = {
"grid": np.zeros_like(self.total_fluctuation[..., 0]),
"wind": wind_field_vtk,
}

imageToVTK(filepath, cellData=cellData, spacing=spacing)


def save_to_netcdf(self, filename: str, filepath: Union[str, Path] = "./DRDMT_out") -> None:
"""Saves generated fluctuation field in NetCDF format to specified filepath.
Parameters
----------
filename : str
File name to write out with
filepath : Union[str, Path], optional
File path to which the generated fluctuation field NetCDF should be written, by default "./DRDMT_out"
Raises
------
ValueError
Thrown if provided filepath does not lead to a directory
ValueError
Thrown if provided filename is nonsensical/invalid
"""
from .wind_plot import format_wind_field

path : Path
if isinstance(filepath, str):
path = Path(filepath)
else:
path = filepath

if not path.is_dir():
raise ValueError("Provided value for filepath does not lead to a directory")

if not (len(filename) > 0):
raise ValueError("Was passed an empty string as a file name")

pass
2 changes: 1 addition & 1 deletion test/eddy_lifetime/test_symmetries.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_synth_basic():
k4_n2 = torch.stack([k_gd, -k_gd, k_gd], dim=-1) / 3 ** (1 / 2)
tau_model4_neg2 = pb.OPS.EddyLifetime(k4_n2).cpu().detach().numpy()

assert np.array_equal(tau_model4, tau_model4_neg2), "tau function is even wrt k2"
assert np.array_equal(tau_model4, tau_model4_neg2), "tau function is even wrt k4"


if __name__ == "__main__":
Expand Down
76 changes: 75 additions & 1 deletion test/io/test_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from drdmannturb.enums import DataType
from drdmannturb.spectra_fitting import OnePointSpectraDataGenerator
from drdmannturb.fluctuation_generation import GenerateFluctuationField



path = Path().resolve()

Expand All @@ -18,7 +21,6 @@
if torch.cuda.is_available():
torch.set_default_tensor_type("torch.cuda.FloatTensor")


spectra_file = path / "../../docs/source/data/Spectra.dat"

domain = torch.logspace(-1, 3, 40)
Expand All @@ -32,6 +34,10 @@


def test_custom_spectra_load():
"""
These tests ensure that the data loading features match freshly generated
data given the same parameters
"""

CustomData = torch.tensor(
np.genfromtxt(spectra_file, skip_header=1, delimiter=","), dtype=torch.float
Expand All @@ -50,3 +56,71 @@ def test_custom_spectra_load():
assert torch.equal(CustomData[:, 2], Data[1][:, 1, 1])
assert torch.equal(CustomData[:, 3], Data[1][:, 2, 2])
assert torch.equal(CustomData[:, 4], -Data[1][:, 0, 2])


def test_vtk_netcdf_io_errors():
"""
These tests ensure that the VTK file I/O routines cleanly succeed
and fail when they should
"""
friction_velocity = 2.683479938442173
reference_height = 180.0
grid_dimensions = np.array([300.0, 864.0, 576.0]) #* 1/20#* 1/10
grid_levels = np.array([6, 6, 8])
seed = None # 9000
Type_Model = "NN"
path_to_parameters = (
path / "../docs/source/results/EddyLifetimeType.CUSTOMMLP_DataType.KAIMAL.pkl"
if path.name == "examples"
else path / "../results/EddyLifetimeType.CUSTOMMLP_DataType.KAIMAL.pkl"
)

gen_drd = GenerateFluctuationField(
friction_velocity,
reference_height,
grid_dimensions,
grid_levels,
model=Type_Model,
path_to_parameters=path_to_parameters,
seed=seed
)
gen_drd.generate(1)

"""
Test empty string filename fail
"""
with pytest.raises(ValueError) as e_info:
gen_drd.save_to_vtk("")

with pytest.raises(ValueError) as e_info:
gen_drd.save_to_netcdf("")

"""
Test ill-formed file path fail
- leads to a non-directory
"""
import os
f = open("DUMMY.txt", "x")
assert os.path.exists("DUMMY.txt"), "DUMMY.txt is not where it was expected"

with pytest.raises(ValueError) as e_info:
gen_drd.save_to_vtk(filename="TEST", filepath="DUMMY.txt")

with pytest.raises(ValueError) as e_info:
gen_drd.save_to_netcdf(filename="TEST", filepath="DUMMY.txt")

os.remove("DUMMY.txt")

"""
- leads to a non-existent directory
"""
with pytest.raises(ValueError) as e_info:
gen_drd.save_to_vtk(filename="TEST", filepath=Path("./this/does/not/exist"))

with pytest.raises(ValueError) as e_info:
gen_drd.save_to_netcdf(filename="TEST", filepath=Path("./this/does/not/exist"))


if __name__ == "__main__":
test_custom_spectra_load()
test_vtk_netcdf_io_errors()

0 comments on commit 3f273ef

Please sign in to comment.