diff --git a/pyaerocom/colocation/colocated_data.py b/pyaerocom/colocation/colocated_data.py index c9ec07c6e..e2d983bd1 100644 --- a/pyaerocom/colocation/colocated_data.py +++ b/pyaerocom/colocation/colocated_data.py @@ -34,19 +34,18 @@ logger = logging.getLogger(__name__) -def ensure_correct_dimensions(data: np.ndarray | xr.DataArray): +def ensure_correct_dimensions(data: xr.DataArray): """ - Ensure the dimensions on either a numpy aray or xarray passed to ColocatedData. + Ensure the dimensions on an xarray.DataArray passed to ColocatedData. If a ColocatedData object is created outside of pyaerocom, this checking is needed. This function is used as part of the model validator. """ - shape = data.shape[0] - if isinstance(data, np.ndarray): - num_dims = data.ndim - elif isinstance(data, xr.DataArray): - num_dims = len(data.dims) - else: + if not isinstance(data, xr.DataArray): raise ValueError("Could not interpret data") + + shape = data.shape[0] + num_dims = len(data.dims) + if num_dims not in (2, 3, 4): raise DataDimensionError("invalid input, need 2D, 3D or 4D numpy array") elif not shape == 2: @@ -124,19 +123,19 @@ class ColocatedData(BaseModel): @model_validator(mode="after") def validate_data(self): + if self.data is None: + return self if isinstance(self.data, Path): # make sure path is str instance self.data = str(self.data) if isinstance(self.data, str): - assert self.data.endswith("nc"), ValueError( - "Invalid data filepath str, must point to a .nc file" - ) + if not self.data.endswith("nc"): + raise ValueError( + f"Invalid data filepath str, must point to a .nc file. Got {self.data}" + ) self.open(self.data) - elif isinstance(self.data, xr.DataArray): - ensure_correct_dimensions(self.data) - return self.data - elif isinstance(self.data, np.ndarray): - ensure_correct_dimensions(self.data) + return self + if isinstance(self.data, np.ndarray): if hasattr(self, "model_extra"): da_keys = dir(xr.DataArray) extra_args_from_class_initialization = { @@ -146,6 +145,9 @@ def validate_data(self): extra_args_from_class_initialization = {} data = xr.DataArray(self.data, **extra_args_from_class_initialization) self.data = data + # self.data should be xr.DataArray at this stage + ensure_correct_dimensions(self.data) + return self # Override __init__ to allow for positional arguments def __init__( diff --git a/pyaerocom/colocation/colocation_setup.py b/pyaerocom/colocation/colocation_setup.py index ed74838cc..d0703f18d 100644 --- a/pyaerocom/colocation/colocation_setup.py +++ b/pyaerocom/colocation/colocation_setup.py @@ -479,6 +479,7 @@ def validate_no_forbidden_keys(self): for key in self.FORBIDDEN_KEYS: if key in self.model_fields: raise ValidationError + return self @cached_property def basedir_logfiles(self): @@ -488,25 +489,26 @@ def basedir_logfiles(self): return str(p) @model_validator(mode="after") - @classmethod - def validate_obs_config(cls, v: PyaroConfig): - if v is not None and cls.obs.config.name != cls.obs_id: + def validate_obs_config(self): + if self.obs_config is None: + return self + if self.obs_config.name != self.obs_id: logger.info( - f"Data ID in Pyaro config {v.name} does not match obs_id {cls.obs_id}. Setting Pyaro config to None!" + f"Data ID in Pyaro config {self.obs_config.name} does not match obs_id {self.obs_id}. Setting Pyaro config to None!" ) - v = None - if v is not None: - if isinstance(v, dict): + self.obs_config = None + if self.obs_config is not None: + if isinstance(self.obs_config, dict): logger.info("Obs config was given as dict. Will try to convert to PyaroConfig") - v = PyaroConfig(**v) - if v.name != cls.obs_id: + self.obs_config = PyaroConfig(**self.obs_config) + if self.obs_config.name != self.obs_id: logger.info( - f"Data ID in Pyaro config {v.name} does not match obs_id {cls.obs_id}. Setting Obs ID to match Pyaro Config!" + f"Data ID in Pyaro config {self.obs_config.name} does not match obs_id {self.obs_id}. Setting Obs ID to match Pyaro Config!" ) - cls.obs_id = v.name - if cls.obs_id is None: - cls.obs_id = v.name - return v + self.obs_id = self.obs_config.name + if self.obs_id is None: + self.obs_id = self.obs_config.name + return self def add_glob_meta(self, **kwargs): """ diff --git a/tests/io/test_readungridded.py b/tests/io/test_readungridded.py index 7ffb7014e..8909d94af 100644 --- a/tests/io/test_readungridded.py +++ b/tests/io/test_readungridded.py @@ -58,9 +58,17 @@ def test_ReadUngridded___init__(data_ids, ignore_cache): (dict(station_name="La_Paz"), 1, 1), (dict(station_name=["La_Paz", "AAO*"]), 2, 2), (dict(altitude=[1000, 10000]), 3, 3), - (dict(altitude=[1000, 10000], ignore_station_names=dict(od550aer="La_Paz")), 2, 2), + ( + dict(altitude=[1000, 10000], ignore_station_names=dict(od550aer="La_Paz")), + 2, + 2, + ), (dict(altitude=[1000, 10000], ignore_station_names="La_*"), 2, 2), - (dict(altitude=[1000, 10000], ignore_station_names=["La_*", "Mauna_Loa"]), 1, 1), + ( + dict(altitude=[1000, 10000], ignore_station_names=["La_*", "Mauna_Loa"]), + 1, + 1, + ), ], ) @pytest.mark.parametrize(