Add scaling for HPX dataloaders, code cleanup #721

Open · wants to merge 7 commits into base: main
45 changes: 4 additions & 41 deletions modulus/datapipes/healpix/coupledtimeseries_dataset.py
@@ -157,43 +157,10 @@ def _get_scaling_da(self):
scaling_df.loc["zeros"] = {"mean": 0.0, "std": 1.0}
scaling_da = scaling_df.to_xarray().astype("float32")

# only thing we do different here is get the scaling for the coupled values
for c in self.couplings:
c.set_scaling(scaling_da)
# REMARK: we remove the xarray overhead from these
try:
self.input_scaling = scaling_da.sel(index=self.input_variables).rename(
{"index": "channel_in"}
)
self.input_scaling = {
"mean": np.expand_dims(
self.input_scaling["mean"].to_numpy(), (0, 2, 3, 4)
),
"std": np.expand_dims(
self.input_scaling["std"].to_numpy(), (0, 2, 3, 4)
),
}
except (ValueError, KeyError):
raise KeyError(
f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
)
try:
self.target_scaling = scaling_da.sel(index=self.input_variables).rename(
{"index": "channel_out"}
)
self.target_scaling = {
"mean": np.expand_dims(
self.target_scaling["mean"].to_numpy(), (0, 2, 3, 4)
),
"std": np.expand_dims(
self.target_scaling["std"].to_numpy(), (0, 2, 3, 4)
),
}
except (ValueError, KeyError):
raise KeyError(
f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
)
super()._get_scaling_da()

def __getitem__(self, item):
# start range
@@ -251,7 +218,6 @@ def __getitem__(self, item):
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("CoupledTimeSeriesDataset:__getitem__:process_batch")
compute_time = time.time()
# Insolation
if self.add_insolation:
sol = insolation(
@@ -305,11 +271,9 @@ def __getitem__(self, item):
np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
]

if "constants" in self.ds.data_vars:
if self.constants is not None:
# Add the constants as [F, C, H, W]
inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
# inputs_result.append(self.ds.constants.values)
logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
inputs_result.append(self.constants)

# append integrated couplings
inputs_result.append(integrated_couplings)
@@ -328,7 +292,6 @@ def __getitem__(self, item):
return inputs_result, targets

def next_integration(self, model_outputs, constants):

inputs_result = []

# grab last few model outputs for re-initialization
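
The net effect of this file's diff is that `CoupledTimeSeriesDataset._get_scaling_da` no longer re-implements the input/target scaling; it only wires the per-channel statistics into its couplings and then defers to the parent `TimeSeriesDataset`. The following is a minimal, self-contained sketch of that delegation pattern; the class and method bodies are illustrative stand-ins, not the actual Modulus implementation.

```python
import pandas as pd


class TimeSeriesDatasetSketch:
    """Stand-in for the parent class that now owns the shared scaling logic."""

    def __init__(self, scaling):
        self.scaling = scaling
        self.input_scaling = None

    def _get_scaling_da(self):
        # Shared logic lives here once: build the per-channel mean/std table.
        scaling_df = pd.DataFrame.from_dict(self.scaling).T
        scaling_df.loc["zeros"] = {"mean": 0.0, "std": 1.0}
        self.input_scaling = scaling_df.to_xarray().astype("float32")


class CoupledTimeSeriesDatasetSketch(TimeSeriesDatasetSketch):
    """Stand-in for the coupled subclass after the cleanup."""

    def __init__(self, scaling, couplings):
        super().__init__(scaling)
        self.couplings = couplings

    def _get_scaling_da(self):
        # Subclass-specific piece: hand the statistics to each coupling.
        scaling_df = pd.DataFrame.from_dict(self.scaling).T
        scaling_df.loc["zeros"] = {"mean": 0.0, "std": 1.0}
        scaling_da = scaling_df.to_xarray().astype("float32")
        for c in self.couplings:
            c.set_scaling(scaling_da)
        # Everything else (input/target/constant scaling) is delegated
        # to the parent instead of being duplicated here.
        super()._get_scaling_da()
```

This is the usual way to remove copy-pasted logic: keep only the subclass-specific step in the override and let `super()` handle the shared work.
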
12 changes: 5 additions & 7 deletions modulus/datapipes/healpix/data_modules.py
@@ -119,7 +119,7 @@ def open_time_series_dataset_classic_on_the_fly(
file_name = _get_file_name(directory, prefix, variable, suffix)
logger.debug("open nc dataset %s", file_name)

ds = xr.open_dataset(file_name, chunks={"sample": batch_size}, autoclose=True)
ds = xr.open_dataset(file_name, autoclose=True)

if "LL" in prefix:
ds = ds.rename({"lat": "height", "lon": "width"})
@@ -212,7 +212,7 @@ def open_time_series_dataset_classic_prebuilt(
if not ds_path.exists():
raise FileNotFoundError(f"Dataset doesn't appear to exist at {ds_path}")

result = xr.open_zarr(ds_path, chunks={"time": batch_size})
result = xr.open_zarr(ds_path)
return result


@@ -285,12 +285,10 @@ def create_time_series_dataset_classic(
for variable in all_variables:
file_name = _get_file_name(src_directory, prefix, variable, suffix)
logger.debug("open nc dataset %s", file_name)
if "sample" in list(xr.open_dataset(file_name).dims.keys()):
ds = xr.open_dataset(file_name, chunks={"sample": batch_size}).rename(
{"sample": "time"}
)
if "sample" in list(xr.open_dataset(file_name).sizes.keys()):
ds = xr.open_dataset(file_name).rename({"sample": "time"})
else:
ds = xr.open_dataset(file_name, chunks={"time": batch_size})
ds = xr.open_dataset(file_name)
if "varlev" in ds.dims:
ds = ds.isel(varlev=0)

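
Two changes recur through this file: the explicit dask `chunks=` arguments are dropped from `xr.open_dataset`/`xr.open_zarr`, and the dimension check switches from `ds.dims` to `ds.sizes` (recent xarray versions deprecate the mapping-style use of `Dataset.dims` in favor of `Dataset.sizes`). A small sketch of the `sizes`-based check against an in-memory dataset (the dataset below is made up for illustration):

```python
import numpy as np
import xarray as xr

# Stand-in for one of the per-variable files opened in this module.
ds = xr.Dataset({"z500": (("sample", "height", "width"), np.zeros((4, 2, 3)))})

# ds.sizes maps dimension names to lengths and is the supported way to do
# this lookup; treating Dataset.dims as a mapping warns in newer xarray.
if "sample" in ds.sizes:
    ds = ds.rename({"sample": "time"})

assert ds.sizes["time"] == 4
```
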
75 changes: 59 additions & 16 deletions modulus/datapipes/healpix/timeseries_dataset.py
@@ -189,24 +189,35 @@ def __init__(

self.input_scaling = None
self.target_scaling = None
self.constant_scaling = None
self.constants = None
if self.scaling:
self._get_scaling_da()

# setup constants
if "constants" in self.ds.data_vars:
# extract from ds:
const = self.ds.constants.values

if self.scaling:
const = (const - self.constant_scaling["mean"]) / self.constant_scaling[
"std"
]

# transpose to match new format:
# [C, F, H, W] -> [F, C, H, W]
self.constants = np.transpose(const, axes=(1, 0, 2, 3))

self.get_constants()

def get_constants(self):
"""Returns the constants used in this dataset

Returns
-------
np.ndarray: The list of constants, None if there are no constants
"""
# extract from ds:
const = self.ds.constants.values

# transpose to match new format:
# [C, F, H, W] -> [F, C, H, W]
const = np.transpose(const, axes=(1, 0, 2, 3))

return const
return self.constants

@staticmethod
def _convert_time_step(dt): # pylint: disable=invalid-name
@@ -244,9 +255,13 @@ def _get_scaling_da(self):
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_in.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"one or more of the input data variables f{list(self.ds.channel_in)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
f"Input channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)
try:
self.target_scaling = scaling_da.sel(
@@ -261,9 +276,37 @@
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_out.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"one or more of the target data variables f{list(self.ds.channel_out)} not found in the "
f"scaling config dict data.scaling ({list(self.scaling.keys())})"
f"Target channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)

try:
# not all datasets will have constants
if "constants" in self.ds.data_vars:
self.constant_scaling = scaling_da.sel(
index=self.ds.channel_c.values
).rename({"index": "channel_out"})
self.constant_scaling = {
"mean": np.expand_dims(
self.constant_scaling["mean"].to_numpy(), (1, 2, 3)
),
"std": np.expand_dims(
self.constant_scaling["std"].to_numpy(), (1, 2, 3)
),
}
except (ValueError, KeyError):
missing = [
m
for m in self.ds.channel_c.values
if m not in list(self.scaling.keys())
]
raise KeyError(
f"Constant channels {missing} not found in the scaling config dict data.scaling ({list(self.scaling.keys())})"
)

def __len__(self):
@@ -303,8 +346,8 @@ def _get_time_index(self, item):
if self.forecast_mode
else (item + 1) * self.batch_size + self._window_length
)
if not self.drop_last and max_index > self.ds.dims["time"]:
batch_size = self.batch_size - (max_index - self.ds.dims["time"])
if not self.drop_last and max_index > self.ds.sizes["time"]:
batch_size = self.batch_size - (max_index - self.ds.sizes["time"])
else:
batch_size = self.batch_size
return (start_index, max_index), batch_size
@@ -428,9 +471,9 @@ def __getitem__(self, item):
np.transpose(x, axes=(0, 3, 1, 2, 4, 5)) for x in inputs_result
]

if "constants" in self.ds.data_vars:
if self.constants is not None:
# Add the constants as [F, C, H, W]
inputs_result.append(np.swapaxes(self.ds.constants.values, 0, 1))
inputs_result.append(self.constants)

logger.log(5, "computed batch in %0.2f s", time.time() - compute_time)
torch.cuda.nvtx.range_pop()
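
The main addition in this file is that constants now get their own `constant_scaling` statistics and are normalized once in `__init__`, cached on `self.constants`, and reused by both `get_constants` and `__getitem__` instead of being re-read from the dataset each time. A small numpy-only sketch of the normalization and layout change (shapes and values below are made up for illustration):

```python
import numpy as np

# Constants as stored in the dataset: [C, F, H, W] (channel, face, height, width).
C, F, H, W = 2, 12, 8, 8
const = np.random.rand(C, F, H, W).astype("float32")

# Per-channel statistics expanded to [C, 1, 1, 1] so they broadcast over the
# face/height/width axes, mirroring np.expand_dims(..., (1, 2, 3)) in the diff.
mean = np.expand_dims(np.array([0.5, 0.0], dtype="float32"), (1, 2, 3))
std = np.expand_dims(np.array([2.0, 1.0], dtype="float32"), (1, 2, 3))

const = (const - mean) / std

# Transpose once at init time to the layout the model consumes: [F, C, H, W].
constants = np.transpose(const, axes=(1, 0, 2, 3))
assert constants.shape == (F, C, H, W)
```
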
13 changes: 10 additions & 3 deletions test/datapipes/test_healpix.py
@@ -72,7 +72,10 @@ def scaling_dict():
"z1000": {"mean": 952.1435546875, "std": 895.7516479492188},
"z250": {"mean": 101186.28125, "std": 5551.77978515625},
"z500": {"mean": 55625.9609375, "std": 2681.712890625},
"lsm": {"mean": 0, "std": 1},
"z": {"mean": 0, "std": 1},
"tp6": {"mean": 1, "std": 0, "log_epsilon": 1e-6},
"extra": {"mean": 1, "std": 0},
}
return DictConfig(scaling)

@@ -88,6 +91,9 @@ def scaling_double_dict():
"z250": {"mean": 0, "std": 2},
"z500": {"mean": 0, "std": 2},
"tp6": {"mean": 0, "std": 2, "log_epsilon": 1e-6},
"lsm": {"mean": 0, "std": 2},
"z": {"mean": 0, "std": 2},
"extra": {"mean": 0, "std": 2},
}
return DictConfig(scaling)

@@ -212,7 +218,7 @@ def test_TimeSeriesDataset_initialization(
"bogosity": {"mean": 0, "std": 42},
}
)
with pytest.raises(KeyError, match=("one or more of the input data variables")):
with pytest.raises(KeyError, match=("Input channels ")):
timeseries_ds = TimeSeriesDataset(
dataset=zarr_ds,
data_time_step="3h",
@@ -550,7 +556,8 @@ def test_TimeSeriesDataModule_get_constants(
# open our test dataset
ds_path = Path(data_dir, dataset_name + ".zarr")
zarr_ds = xr.open_zarr(ds_path)
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3))
# dividing by 2 due to scaling
expected = np.transpose(zarr_ds.constants.values, axes=(1, 0, 2, 3)) / 2

assert np.array_equal(
timeseries_dm.get_constants(),
@@ -617,7 +624,7 @@ def test_TimeSeriesDataModule_get_dataloaders(
test_dataloader, test_sampler = timeseries_dm.test_dataloader(num_shards=1)
assert test_sampler is None
assert isinstance(test_dataloader, DataLoader)
print(f"dataset lenght {len}")

# with >1 shard should be distributed sampler
train_dataloader, train_sampler = timeseries_dm.train_dataloader(num_shards=2)
assert isinstance(train_sampler, DistributedSampler)
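
The `/ 2` in the updated `test_TimeSeriesDataModule_get_constants` expectation follows directly from the `scaling_double_dict` fixture: every constant channel there has `mean=0` and `std=2`, so the dataset's `(x - mean) / std` normalization reduces to dividing the raw constants by 2. A tiny check of that arithmetic (array values are arbitrary):

```python
import numpy as np

raw = np.array([[1.0, 4.0], [2.0, 8.0]])
mean, std = 0.0, 2.0

# Standardization with mean 0 and std 2 is just division by 2.
normalized = (raw - mean) / std
assert np.array_equal(normalized, raw / 2)
```
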