No longer import NativeGridDataset from anemoi-training #26

Open · wants to merge 1 commit into base branch main
196 changes: 42 additions & 154 deletions bris/data/datamodule.py
@@ -1,7 +1,6 @@
import logging
import os
from functools import cached_property
from typing import Any, Optional
from typing import Any, Optional, TYPE_CHECKING

import numpy as np
import pytorch_lightning as pl
@@ -10,14 +9,16 @@
from anemoi.utils.dates import frequency_to_seconds
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf, errors
from torch.utils.data import DataLoader, get_worker_info
from torch_geometric.data import HeteroData
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
import anemoi.datasets.data.subset
import anemoi.datasets.data.select

from bris.checkpoint import Checkpoint
from bris.data.dataset import Dataset
from bris.utils import check_anemoi_training, recursive_list_to_tuple
from bris.data.dataset import worker_init_func
from bris.data.grid_indices import FullGrid
from bris.utils import recursive_list_to_tuple

from bris.data.grid_indices import BaseGridIndices

LOGGER = logging.getLogger(__name__)

@@ -47,8 +48,7 @@ def __init__(
self.ckptObj = checkpoint_object
self.timestep = config.timestep
self.frequency = config.frequency
self.legacy = not check_anemoi_training(metadata=self.ckptObj._metadata)


def predict_dataloader(self) -> DataLoader:
"""
Creates a dataloader for prediction
@@ -57,24 +57,9 @@ def predict_dataloader(self) -> DataLoader:
None
return:

"""
return self._get_dataloader(self.ds_predict)

def _get_dataloader(self, ds):
"""
Creates torch dataloader object for
ds. Batch_size, num_workers, prefetch_factor
and pin_memory can be adjusted in the config
under dataloader.

args:
ds: anemoi.datasets.data.open_dataset object

return:
torch dataloader initialized on anemoi dataset object
"""
return DataLoader(
ds,
self.ds_predict,
batch_size=self.config.dataloader.batch_size,
# number of worker processes
num_workers=self.config.dataloader.num_workers,
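All of the knobs passed to `DataLoader` here come from the `dataloader` section of the inference config. A minimal sketch of that node, with illustrative values only (the exact schema is whatever the bris config ships):

```python
# Illustrative sketch of the config node read by predict_dataloader;
# keys mirror the attribute accesses above, values are examples only.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "dataloader": {
            "batch_size": 1,       # samples per prediction batch
            "num_workers": 2,      # worker processes for the DataLoader
            "prefetch_factor": 2,  # batches each worker prefetches
            "pin_memory": True,    # page-locked host memory for faster GPU copies
        }
    }
)
assert config.dataloader.num_workers == 2
```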
@@ -99,100 +84,22 @@ def ds_predict(self) -> Any:
Anemoi dataset open_dataset object
"""
return self._get_dataset(self.data_reader)

def _get_dataset(
self,
data_reader,
):
"""
Instantiates a given dataset class
from anemoi.training.data.dataset.
This assumes that the python path for
the class is defined, and anemoi-training
for a given branch is installed with pip
in order to access the class. This
method returns an instantiated instance of
a given data class. This supports
data distributed parallel (DDP) and model
sharding.

args:
data_reader: anemoi open_dataset object

return:
an dataset class object
"""
if self.legacy:
# TODO: fix imports and pip packages for legacy version
LOGGER.info(
"""Did not find anemoi.training version in checkpoint metadata, assuming
the model was trained with aifs-mono and using legacy functionality"""
)
LOGGER.warning("WARNING! Ensemble legacy mode has yet to be implemented!")
from .legacy.dataset import EnsNativeGridDataset, NativeGridDataset
from .legacy.utils import _legacy_slurm_proc_id

model_comm_group_rank, model_comm_group_id, model_comm_num_groups = (
_legacy_slurm_proc_id(self.config)
)

spatial_mask = {}
for mesh_name, mesh in self.graph.items():
if (
isinstance(mesh_name, str)
and mesh_name
!= self.ckptObj._metadata.config.graphs.hidden_mesh.name
):
spatial_mask[mesh_name] = mesh.get("dataset_idx", None)
spatial_index = spatial_mask[
self.ckptObj._metadata.config.graphs.encoders[0]["src_mesh"]
]

dataCls = NativeGridDataset(
data_reader=data_reader,
rollout=0, # we dont perform rollout during inference
multistep=self.ckptObj.multistep,
timeincrement=self.timeincrement,
model_comm_group_rank=model_comm_group_rank,
model_comm_group_id=model_comm_group_id,
model_comm_num_groups=model_comm_num_groups,
spatial_index=spatial_index,
shuffle=False,
label="predict",
)
return dataCls
else:
try:
dataCls = instantiate(
config=self.config.dataloader.datamodule,
data_reader=data_reader,
rollout=0, # we dont perform rollout during inference
multistep=self.ckptObj.multistep,
timeincrement=self.timeincrement,
grid_indices=self.grid_indices,
shuffle=False,
label="predict",
)
except:
dataCls = instantiate(
config=self.config.dataloader.datamodule,
data_reader=data_reader,
rollout=0, # we dont perform rollout during inference
multistep=self.ckptObj.multistep,
timeincrement=self.timeincrement,
shuffle=False,
label="predict",
)

return Dataset(dataCls)
) -> IterableDataset:
ds = instantiate(
config=self.config.dataloader.datamodule,
data_reader=data_reader,
rollout=0,
multistep=self.ckptObj.multistep,
timeincrement=self.timeincrement,
grid_indices=self.grid_indices,
label="predict",
)

@property
def name_to_index(self):
"""
Returns a tuple of dictionaries, where each dict is:
variable_name -> index
"""
return self.ckptObj.name_to_index
return ds
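`instantiate` here resolves whatever class `config.dataloader.datamodule._target_` names and forwards the call-site keyword arguments to its constructor. A self-contained sketch of that mechanism, using a toy class rather than the real datamodule target:

```python
# Minimal demonstration of hydra.utils.instantiate; ToyDataset is a
# stand-in for the configured datamodule class, not part of bris.
from hydra.utils import instantiate
from omegaconf import OmegaConf

class ToyDataset:
    def __init__(self, data_reader, rollout, label):
        self.data_reader, self.rollout, self.label = data_reader, rollout, label

cfg = OmegaConf.create({"_target_": "__main__.ToyDataset"})
ds = instantiate(cfg, data_reader="reader", rollout=0, label="predict")
assert ds.label == "predict"  # call-site kwargs reach the constructor
```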

@cached_property
def data_reader(self):
@@ -243,20 +150,27 @@ def timeincrement(self) -> int:
timestep,
)
return timestep // frequency
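As a concrete example of the division above: a model trained with a 6h timestep on data stored at 1h frequency advances six dataset frames per model step. A sketch, assuming `frequency_to_seconds` accepts frequency strings like "6h" as it is used elsewhere in this module:

```python
# Worked example of timestep // frequency, both converted to seconds first.
from anemoi.utils.dates import frequency_to_seconds

timestep = frequency_to_seconds("6h")   # 21600
frequency = frequency_to_seconds("1h")  # 3600
assert timestep % frequency == 0        # otherwise the config is inconsistent
timeincrement = timestep // frequency   # -> 6
```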

@cached_property
def grid_indices(self):
if check_anemoi_training(self.ckptObj.metadata):
try:
from anemoi.training.data.grid_indices import BaseGridIndices, FullGrid
except ImportError as e:
print("Warning! Could not import BaseGridIndices and FullGrid. Continuing without this module")

@property
def name_to_index(self):
"""
Returns a tuple of dictionaries, where each dict is:
variable_name -> index
"""
return self.ckptObj.name_to_index

@cached_property
def grid_indices(self) -> BaseGridIndices:
reader_group_size = self.config.dataloader.read_group_size
grid_indices = FullGrid(
nodes_name="data",
reader_group_size=reader_group_size
if hasattr(self.config.dataloader, "grid_indices"):
grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size)
LOGGER.info("Using grid indices from dataloader config")
else:
grid_indices = FullGrid(
nodes_name="data",
reader_group_size=reader_group_size
)
LOGGER.info("grid_indices not found in dataloader config, defaulting to FullGrid")
grid_indices.setup(self.graph)
return grid_indices
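The optional `dataloader.grid_indices` node is an ordinary Hydra target. A hedged sketch of an override that reproduces the `FullGrid` fallback explicitly; any other `BaseGridIndices` subclass would be configured the same way:

```python
# Illustrative grid_indices override; equivalent to the FullGrid fallback.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"_target_": "bris.data.grid_indices.FullGrid", "nodes_name": "data"}
)
grid_indices = instantiate(cfg, reader_group_size=1)
# The property then calls grid_indices.setup(self.graph) before returning it.
```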

@@ -311,10 +225,11 @@ def altitudes(self) -> tuple:

return altitudes


@cached_property
def field_shape(self) -> tuple:

"""
Retrieve the field_shape of each dataset, nested per decoder.
"""
field_shape = [None]*len(self.grids)
for decoder_index, grids in enumerate(self.grids):
field_shape[decoder_index] = [None]*len(grids)
@@ -344,30 +259,3 @@ def _get_field_shape(self, decoder_index, dataset_index):
assert (decoder_index == 0 and dataset_index == 0)
return data_reader.field_shape
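`field_shape` is read straight off the anemoi-datasets handle and is the spatial shape of a single field. A short sketch, with a hypothetical dataset path and shape:

```python
# field_shape on an anemoi-datasets handle; path and shape are illustrative.
from anemoi.datasets import open_dataset

ds = open_dataset("/path/to/dataset.zarr")  # hypothetical dataset path
print(ds.field_shape)  # e.g. (721, 1440) for a 0.25 degree regular grid
```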

def worker_init_func(worker_id: int) -> None:
"""Configures each dataset worker process.

Calls WeatherBenchDataset.per_worker_init() on each dataset object.

Parameters
----------
worker_id : int
Worker ID

Raises
------
RuntimeError
If worker_info is None

"""
worker_info = get_worker_info() # information specific to each worker process
if worker_info is None:
# LOGGER.error("worker_info is None! Set num_workers > 0 in your dataloader!")
raise RuntimeError
dataset_obj = (
worker_info.dataset
) # the copy of the dataset held by this worker process.
dataset_obj.per_worker_init(
n_workers=worker_info.num_workers,
worker_id=worker_id,
)
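With this local copy of `worker_init_func` deleted, the version imported from `bris.data.dataset` at the top of the file is what gets handed to the `DataLoader`. A hedged sketch of that wiring, using a toy iterable dataset that provides the `per_worker_init` hook the function calls:

```python
# Hedged sketch: worker_init_fn lets each worker call per_worker_init on
# its own copy of the dataset. ToyIterable is a stand-in, not part of bris.
from torch.utils.data import DataLoader, IterableDataset
from bris.data.dataset import worker_init_func

class ToyIterable(IterableDataset):
    def per_worker_init(self, n_workers: int, worker_id: int) -> None:
        self.n_workers, self.worker_id = n_workers, worker_id
    def __iter__(self):
        yield from range(4)

loader = DataLoader(
    ToyIterable(),
    batch_size=1,
    num_workers=2,              # must be > 0, or worker_init_func raises
    worker_init_fn=worker_init_func,
)
```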