From a95cd2f51cd869b46dc0519669f6a8f69fa5b3e0 Mon Sep 17 00:00:00 2001 From: Timothy Smith Date: Tue, 22 Mar 2022 18:25:30 -0500 Subject: [PATCH 1/7] like extra_variables, but no time stamp... --- xmitgcm/mds_store.py | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/xmitgcm/mds_store.py b/xmitgcm/mds_store.py index ea3730f..697de54 100644 --- a/xmitgcm/mds_store.py +++ b/xmitgcm/mds_store.py @@ -59,7 +59,8 @@ def open_mdsdataset(data_dir, grid_dir=None, ignore_unknown_vars=False, default_dtype=None, nx=None, ny=None, nz=None, llc_method="smallchunks", extra_metadata=None, - extra_variables=None): + extra_variables=None, + custom_grid_variables={}): """Open MITgcm-style mds (.data / .meta) file output as xarray datset. Parameters @@ -148,6 +149,8 @@ def open_mdsdataset(data_dir, grid_dir=None, standard_name='Sensitivity_to_theta', long_name='Sensitivity of cost function to theta', units='[J]/degC')) ) + custom_grid_variables : dict, optional + Similar to extra_variables, but these files don't have a time stamp. Returns @@ -235,7 +238,8 @@ def open_mdsdataset(data_dir, grid_dir=None, default_dtype=default_dtype, nx=nx, ny=ny, nz=nz, llc_method=llc_method, levels=levels, extra_metadata=extra_metadata, - extra_variables=extra_variables) + extra_variables=extra_variables, + custom_grid_variables=custom_grid_variables) datasets = [open_mdsdataset( data_dir, iters=iternum, read_grid=False, **kwargs) for iternum in iters] @@ -291,7 +295,8 @@ def open_mdsdataset(data_dir, grid_dir=None, default_dtype=default_dtype, nx=nx, ny=ny, nz=nz, llc_method=llc_method, levels=levels, extra_metadata=extra_metadata, - extra_variables=extra_variables) + extra_variables=extra_variables, + custom_grid_variables=custom_grid_variables) ds = xr.Dataset.load_store(store) if swap_dims: @@ -376,7 +381,8 @@ def __init__(self, data_dir, grid_dir=None, default_dtype=np.dtype('f4'), nx=None, ny=None, nz=None, llc_method="smallchunks", levels=None, extra_metadata=None, - extra_variables=None): + extra_variables=None, + custom_grid_variables={}): """ This is not a user-facing class. See open_mdsdataset for argument documentation. The only ones which are distinct are. 
@@ -401,6 +407,7 @@ def __init__(self, data_dir, grid_dir=None, self.data_dir = data_dir self.grid_dir = grid_dir if (grid_dir is not None) else data_dir self.extra_variables = extra_variables + self.custom_grid_variables = custom_grid_variables self._ignore_unknown_vars = ignore_unknown_vars # The endianness of the files @@ -573,7 +580,8 @@ def __init__(self, data_dir, grid_dir=None, # build lookup tables for variable metadata self._all_grid_variables = _get_all_grid_variables(self.geometry, self.grid_dir, - self.layers) + self.layers, + self.custom_grid_variables) self._all_data_variables = _get_all_data_variables(self.data_dir, self.grid_dir, self.layers, @@ -831,7 +839,7 @@ def _guess_layers(data_dir): return all_layers -def _get_all_grid_variables(geometry, grid_dir=None, layers={}): +def _get_all_grid_variables(geometry, grid_dir=None, layers={}, custom_grid_variables={}): """"Put all the relevant grid metadata into one big dictionary.""" possible_hcoords = {'cartesian': horizontal_coordinates_cartesian, 'llc': horizontal_coordinates_llc, @@ -841,7 +849,7 @@ def _get_all_grid_variables(geometry, grid_dir=None, layers={}): hcoords = possible_hcoords[geometry] # look for extra variables, if they exist in grid_dir - extravars = _get_extra_grid_variables(grid_dir) if grid_dir is not None else {} + extravars = _get_extra_grid_variables(grid_dir, custom_grid_variables=custom_grid_variables) if grid_dir is not None else {} allvars = [hcoords, vertical_coordinates, horizontal_grid_variables, vertical_grid_variables, volume_grid_variables, mask_variables, @@ -856,21 +864,22 @@ def _get_all_grid_variables(geometry, grid_dir=None, layers={}): return metadata -def _get_extra_grid_variables(grid_dir): +def _get_extra_grid_variables(grid_dir, custom_grid_variables): """Scan a directory and return all file prefixes for extra grid files. Then return the variable information for each of these""" extra_grid = {} - fnames = dict([[val['filename'],key] for key,val in extra_grid_variables.items() if 'filename' in val]) + all_extras = {**extra_grid_variables, **custom_grid_variables} + fnames = dict([[val['filename'],key] for key,val in all_extras.items() if 'filename' in val]) all_datafiles = listdir_endswith(grid_dir, '.data') for f in all_datafiles: prefix = os.path.split(f[:-5])[-1] - # Only consider what we find that matches extra_grid_vars - if prefix in extra_grid_variables: - extra_grid[prefix] = extra_grid_variables[prefix] + # Only consider what we find that matches extra/custom_grid_vars + if prefix in all_extras: + extra_grid[prefix] = all_extras[prefix] elif prefix in fnames: - extra_grid[fnames[prefix]] = extra_grid_variables[fnames[prefix]] + extra_grid[fnames[prefix]] = all_extras[fnames[prefix]] return extra_grid From 9b0bfdd7a1d784f1c4d76256089810323d38a40d Mon Sep 17 00:00:00 2001 From: Timothy Smith Date: Fri, 8 Apr 2022 14:41:46 -0500 Subject: [PATCH 2/7] stash commit ... 
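Work in progress: hard-code metadata for the smoothing-package field smooth3Dfld001 so that llcreader can read it, and stub out a (commented) SverdrupSmoothLLC90Model pointing at the local ASTE release directories. The hard-coded entry follows the same dims/attrs layout as the rest of xmitgcm's variable metadata; a minimal sketch of its shape, with the names taken from the diff below:

    # Shape of the entry this commit merges into the metadata lookup
    # built by _get_var_metadata() (see the diff below).
    extra_variables = {
        'smooth3Dfld001': {                  # smoothing-package output file prefix
            'dims': ['k', 'j', 'i'],         # 3D, cell-centered field
            'attrs': {
                'standard_name': 'smooth_fld',
                'long_name': r'$C\mathbf{z}$',
            },
        },
    }

Hard-coding the entry here is a stop-gap; the intended end state is a user-facing extra_variables argument, added later in this series.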
--- xmitgcm/llcreader/known_models.py | 14 ++++++++++++++ xmitgcm/llcreader/llcmodel.py | 22 ++++++++++++++++------ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/xmitgcm/llcreader/known_models.py b/xmitgcm/llcreader/known_models.py index 7b46da2..d725898 100644 --- a/xmitgcm/llcreader/known_models.py +++ b/xmitgcm/llcreader/known_models.py @@ -267,3 +267,17 @@ def __init__(self): shrunk=True, join_char='/') super(SverdrupASTE270Model, self).__init__(store) + +#class SverdrupSmoothLLC90Model(SmoothLLC90Model): +# @_requires_sverdrup +# def __init__(self): +# from fsspec.implementations.local import LocalFileSystem +# fs = LocalFileSystem() +# base_path = '/scratch2/shared/aste-release1/diags' +# grid_path = '/scratch2/shared/aste-release1/grid' +# mask_path = '/scratch2/shared/aste-release1/masks.zarr' +# store = stores.NestedStore(fs, base_path=base_path, grid_path=grid_path, +# mask_path=mask_path, +# shrunk=True, join_char='/') +# +# super(SverdrupASTE270Model, self).__init__(store) diff --git a/xmitgcm/llcreader/llcmodel.py b/xmitgcm/llcreader/llcmodel.py index edce045..956a570 100644 --- a/xmitgcm/llcreader/llcmodel.py +++ b/xmitgcm/llcreader/llcmodel.py @@ -50,6 +50,16 @@ def _get_var_metadata(): var_metadata = state_variables.copy() var_metadata.update(package_state_variables) var_metadata.update(available_diags) + extra_variables = { + f'smooth3Dfld001': { + 'dims': ['k', 'j', 'i'], + 'attrs': { + 'standard_name': 'smooth_fld', + 'long_name': r'$C\mathbf{z}$', + } + }, + } + var_metadata.update(extra_variables) # even the file names from the LLC data differ from standard MITgcm output aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT', @@ -776,14 +786,14 @@ def _check_iters(self, iters): if not set(iters) <= set(self.iters): msg = "Some requested iterations may not exist, you may need to change 'iters'" warnings.warn(msg, RuntimeWarning) - + elif self.iter_start is not None and self.iter_step is not None: for iter in iters: if (iter - self.iter_start) % self.iter_step: msg = "Some requested iterations may not exist, you may need to change 'iters'" warnings.warn(msg, RuntimeWarning) break - + def get_dataset(self, varnames=None, iter_start=None, iter_stop=None, iter_step=None, iters=None, k_levels=None, k_chunksize=1, @@ -838,7 +848,7 @@ def _if_not_none(a, b): if iters is not None: raise ValueError("Only `iters` or the parameters `iter_start`, `iters_stop`, " "and `iter_step` can be provided. Both were provided") - + # Otherwise we can override any missing values iter_start = _if_not_none(iter_start, self.iter_start) iter_stop = _if_not_none(iter_stop, self.iter_stop) @@ -849,12 +859,12 @@ def _if_not_none(a, b): "and `iter_step` must be defined either by the " "model class or as argument. 
Instead got %r " % iter_params) - + # Otherwise try loading from the user set iters elif iters is not None: pass - # Now have a go at using the attribute derived iteration parameters + # Now have a go at using the attribute derived iteration parameters elif all([a is not None for a in attribute_iter_params]): iter_params = attribute_iter_params @@ -867,7 +877,7 @@ def _if_not_none(a, b): raise ValueError("The parameters `iter_start`, `iter_stop`, " "and `iter_step`, or `iters` must be defined either by the " "model class or as argument") - + # Check the iter_start and iter_step if iters is None: self._check_iter_start(iter_params[0]) From e04edffc06e081032b90c452451d5ff3992ba981 Mon Sep 17 00:00:00 2001 From: Timothy Smith Date: Wed, 13 Apr 2022 10:04:45 -0500 Subject: [PATCH 3/7] small cleanup --- xmitgcm/llcreader/known_models.py | 14 --- xmitgcm/llcreader/llcmodel.py | 151 ++++++++++++++++-------------- xmitgcm/mds_store.py | 15 +-- 3 files changed, 89 insertions(+), 91 deletions(-) diff --git a/xmitgcm/llcreader/known_models.py b/xmitgcm/llcreader/known_models.py index d725898..7b46da2 100644 --- a/xmitgcm/llcreader/known_models.py +++ b/xmitgcm/llcreader/known_models.py @@ -267,17 +267,3 @@ def __init__(self): shrunk=True, join_char='/') super(SverdrupASTE270Model, self).__init__(store) - -#class SverdrupSmoothLLC90Model(SmoothLLC90Model): -# @_requires_sverdrup -# def __init__(self): -# from fsspec.implementations.local import LocalFileSystem -# fs = LocalFileSystem() -# base_path = '/scratch2/shared/aste-release1/diags' -# grid_path = '/scratch2/shared/aste-release1/grid' -# mask_path = '/scratch2/shared/aste-release1/masks.zarr' -# store = stores.NestedStore(fs, base_path=base_path, grid_path=grid_path, -# mask_path=mask_path, -# shrunk=True, join_char='/') -# -# super(SverdrupASTE270Model, self).__init__(store) diff --git a/xmitgcm/llcreader/llcmodel.py b/xmitgcm/llcreader/llcmodel.py index 956a570..81f6349 100644 --- a/xmitgcm/llcreader/llcmodel.py +++ b/xmitgcm/llcreader/llcmodel.py @@ -37,7 +37,7 @@ def _get_grid_metadata(): return grid_metadata -def _get_var_metadata(): +def _get_var_metadata(extra_variables=None): # The LLC run data comes with zero metadata. So we import metadata from # the xmitgcm package. 
from ..variables import state_variables, package_state_variables @@ -50,16 +50,8 @@ def _get_var_metadata(): var_metadata = state_variables.copy() var_metadata.update(package_state_variables) var_metadata.update(available_diags) - extra_variables = { - f'smooth3Dfld001': { - 'dims': ['k', 'j', 'i'], - 'attrs': { - 'standard_name': 'smooth_fld', - 'long_name': r'$C\mathbf{z}$', - } - }, - } - var_metadata.update(extra_variables) + if extra_variables is not None: + var_metadata.update(extra_variables) # even the file names from the LLC data differ from standard MITgcm output aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT', @@ -72,54 +64,7 @@ def _get_var_metadata(): return var_metadata -_VAR_METADATA = _get_var_metadata() - -def _is_vgrid(vname): - # check for 1d, vertical grid variables - dims = _VAR_METADATA[vname]['dims'] - return len(dims)==1 and dims[0][0]=='k' - -def _get_variable_point(vname, mask_override): - # fix for https://github.com/MITgcm/xmitgcm/issues/191 - if vname in mask_override: - return mask_override[vname] - dims = _VAR_METADATA[vname]['dims'] - if 'i' in dims and 'j' in dims: - point = 'c' - elif 'i_g' in dims and 'j' in dims: - point = 'w' - elif 'i' in dims and 'j_g' in dims: - point = 's' - elif 'i_g' in dims and 'j_g' in dims: - raise ValueError("Don't have masks for corner points!") - else: - raise ValueError("Variable `%s` is not a horizontal variable." % vname) - return point - -def _get_scalars_and_vectors(varnames, type): - - for vname in varnames: - if vname not in _VAR_METADATA: - raise ValueError("Varname `%s` not found in metadata." % vname) - - if type != 'latlon': - return varnames, [] - scalars = [] - vector_pairs = [] - for vname in varnames: - meta = _VAR_METADATA[vname] - try: - mate = meta['attrs']['mate'] - if mate not in varnames: - raise ValueError("Vector pairs are required to create " - "latlon type datasets. Varname `%s` is " - "missing its vector mate `%s`" - % vname, mate) - vector_pairs.append((vname, mate)) - varnames.remove(mate) - except KeyError: - scalars.append(vname) def _decompress(data, mask, dtype): data_blank = np.full_like(mask, np.nan, dtype=dtype) @@ -604,6 +549,7 @@ class BaseLLCModel: varnames = [] grid_varnames = [] mask_override = {} + var_metadata = None domain = 'global' pad_before = [0]*_nfacets pad_after = [0]*_nfacets @@ -642,6 +588,53 @@ def _dtype(self,varname=None): elif isinstance(self.dtype,dict): return np.dtype(self.dtype[varname]) + def _is_vgrid(self, vname): + # check for 1d, vertical grid variables + dims = self.var_metadata[vname]['dims'] + return len(dims)==1 and dims[0][0]=='k' + + def _get_variable_point(self, vname, mask_override): + # fix for https://github.com/MITgcm/xmitgcm/issues/191 + if vname in mask_override: + return mask_override[vname] + dims = self.var_metadata[vname]['dims'] + if 'i' in dims and 'j' in dims: + point = 'c' + elif 'i_g' in dims and 'j' in dims: + point = 'w' + elif 'i' in dims and 'j_g' in dims: + point = 's' + elif 'i_g' in dims and 'j_g' in dims: + raise ValueError("Don't have masks for corner points!") + else: + raise ValueError("Variable `%s` is not a horizontal variable." % vname) + return point + + def _get_scalars_and_vectors(self, varnames, type): + + for vname in varnames: + if vname not in self.var_metadata: + raise ValueError("Varname `%s` not found in metadata." 
% vname) + + if type != 'latlon': + return varnames, [] + + scalars = [] + vector_pairs = [] + for vname in varnames: + meta = self.var_metadata[vname] + try: + mate = meta['attrs']['mate'] + if mate not in varnames: + raise ValueError("Vector pairs are required to create " + "latlon type datasets. Varname `%s` is " + "missing its vector mate `%s`" + % (vname, mate)) + vector_pairs.append((vname, mate)) + varnames.remove(mate) + except KeyError: + scalars.append(vname) + def _get_kp1_levels(self,k_levels): # determine kp1 levels # get borders to all k (center) levels @@ -739,7 +732,7 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize): name = '-'.join([varname, token]) dtype = self._dtype(varname) - nz = self.nz if _VAR_METADATA[varname]['dims'] != ['k_p1'] else self.nz+1 + nz = self.nz if self.var_metadata[varname]['dims'] != ['k_p1'] else self.nz+1 task = (_get_1d_chunk, self.store, varname, list(klevels), nz, dtype) @@ -750,12 +743,12 @@ def _get_facet_data(self, varname, iters, klevels, k_chunksize): # needs facets to be outer index of nested lists - dims = _VAR_METADATA[varname]['dims'] + dims = self.var_metadata[varname]['dims'] if len(dims)==2: klevels = [0,] - if _is_vgrid(varname): + if self._is_vgrid(varname): data_facets = self._dask_array_vgrid(varname,klevels,k_chunksize) else: data_facets = [self._dask_array(nfacet, varname, iters, klevels, k_chunksize) @@ -797,7 +790,8 @@ def _check_iters(self, iters): def get_dataset(self, varnames=None, iter_start=None, iter_stop=None, iter_step=None, iters=None, k_levels=None, k_chunksize=1, - type='faces', read_grid=True, grid_vars_to_coords=True): + type='faces', read_grid=True, grid_vars_to_coords=True, + extra_variables=None): """ Create an xarray Dataset object for this model. @@ -827,6 +821,22 @@ def get_dataset(self, varnames=None, iter_start=None, iter_stop=None, Whether to read the grid info grid_vars_to_coords : bool, optional Whether to promote grid variables to coordinate status + extra_variables : dict, optional + Allows passing variables not listed in variables.py + or in available_diagnostics.log. + extra_variables must be a dict containing the variable names as keys with + the corresponding values being a dict with the keys being dims and attrs.
+ + Syntax: + extra_variables = dict(varname = dict(dims=list_of_dims, attrs=dict(optional_attrs))) + where optional_attrs can contain standard_name, long_name, units as keys + + Example: + extra_variables = dict( + ADJtheta = dict(dims=['k','j','i'], attrs=dict( + standard_name='Sensitivity_to_theta', + long_name='Sensitivity of cost function to theta', units='[J]/degC')) + ) Returns ------- @@ -839,6 +849,7 @@ def _if_not_none(a, b): else: return a + self.var_metadata = _get_var_metadata(extra_variables=extra_variables) user_iter_params = [iter_start, iter_stop, iter_step] attribute_iter_params = [self.iter_start, self.iter_stop, self.iter_step] @@ -916,7 +927,7 @@ def _if_not_none(a, b): # do separately for vertical coords on kp1_levels grid_facets = {} for vname in grid_varnames: - my_k_levels = k_levels if _VAR_METADATA[vname]['dims'] !=['k_p1'] else kp1_levels + my_k_levels = k_levels if self.var_metadata[vname]['dims'] !=['k_p1'] else kp1_levels grid_facets[vname] = self._get_facet_data(vname, None, my_k_levels, k_chunksize) # transform it into faces or latlon @@ -924,22 +935,22 @@ def _if_not_none(a, b): 'latlon': _all_facets_to_latlon} transformer = data_transformers[type] - data = transformer(data_facets, _VAR_METADATA, self.nface) + data = transformer(data_facets, self.var_metadata, self.nface) # separate horizontal and vertical grid variables hgrid_facets = {key: grid_facets[key] - for key in grid_varnames if not _is_vgrid(key)} + for key in grid_varnames if not self._is_vgrid(key)} vgrid_facets = {key: grid_facets[key] - for key in grid_varnames if _is_vgrid(key)} + for key in grid_varnames if self._is_vgrid(key)} # do not transform vertical grid variables - data.update(transformer(hgrid_facets, _VAR_METADATA, self.nface)) + data.update(transformer(hgrid_facets, self.var_metadata, self.nface)) data.update(vgrid_facets) variables = {} gridlist = ['Zl','Zu'] if read_grid else [] for vname in varnames+grid_varnames: - meta = _VAR_METADATA[vname] + meta = self.var_metadata[vname] dims = meta['dims'] if type=='faces': dims = _add_face_to_dims(dims) @@ -958,9 +969,9 @@ def _if_not_none(a, b): if read_grid and 'RF' in grid_varnames: ki = np.array([list(kp1_levels).index(x) for x in k_levels]) for zv,sl in zip(['Zl','Zu'],[ki,ki+1]): - variables[zv] = xr.Variable(_VAR_METADATA[zv]['dims'], + variables[zv] = xr.Variable(self.var_metadata[zv]['dims'], data['RF'][sl], - _VAR_METADATA[zv]['attrs']) + self.var_metadata[zv]['attrs']) ds = ds.update(variables) diff --git a/xmitgcm/mds_store.py b/xmitgcm/mds_store.py index 697de54..7678006 100644 --- a/xmitgcm/mds_store.py +++ b/xmitgcm/mds_store.py @@ -60,7 +60,7 @@ def open_mdsdataset(data_dir, grid_dir=None, nx=None, ny=None, nz=None, llc_method="smallchunks", extra_metadata=None, extra_variables=None, - custom_grid_variables={}): + custom_grid_variables=None): """Open MITgcm-style mds (.data / .meta) file output as xarray datset. Parameters @@ -382,7 +382,7 @@ def __init__(self, data_dir, grid_dir=None, nx=None, ny=None, nz=None, llc_method="smallchunks", levels=None, extra_metadata=None, extra_variables=None, - custom_grid_variables={}): + custom_grid_variables=None): """ This is not a user-facing class. See open_mdsdataset for argument documentation. The only ones which are distinct are. 
@@ -869,17 +869,18 @@ def _get_extra_grid_variables(grid_dir, custom_grid_variables): Then return the variable information for each of these""" extra_grid = {} - all_extras = {**extra_grid_variables, **custom_grid_variables} - fnames = dict([[val['filename'],key] for key,val in all_extras.items() if 'filename' in val]) + if custom_grid_variables is not None: + extra_grid_variables = extra_grid_variables.update(custom_grid_variables) + fnames = dict([[val['filename'],key] for key,val in extra_grid_variables.items() if 'filename' in val]) all_datafiles = listdir_endswith(grid_dir, '.data') for f in all_datafiles: prefix = os.path.split(f[:-5])[-1] # Only consider what we find that matches extra/custom_grid_vars - if prefix in all_extras: - extra_grid[prefix] = all_extras[prefix] + if prefix in extra_grid_variables: + extra_grid[prefix] = extra_grid_variables[prefix] elif prefix in fnames: - extra_grid[fnames[prefix]] = all_extras[fnames[prefix]] + extra_grid[fnames[prefix]] = extra_grid_variables[fnames[prefix]] return extra_grid From 0b550e711ebeebe792fe7219cc738183610cf649 Mon Sep 17 00:00:00 2001 From: Timothy Smith Date: Mon, 27 Jun 2022 10:14:36 -0500 Subject: [PATCH 4/7] none is a better default --- xmitgcm/mds_store.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xmitgcm/mds_store.py b/xmitgcm/mds_store.py index 7678006..665f9b7 100644 --- a/xmitgcm/mds_store.py +++ b/xmitgcm/mds_store.py @@ -839,7 +839,7 @@ def _guess_layers(data_dir): return all_layers -def _get_all_grid_variables(geometry, grid_dir=None, layers={}, custom_grid_variables={}): +def _get_all_grid_variables(geometry, grid_dir=None, layers={}, custom_grid_variables=None): """"Put all the relevant grid metadata into one big dictionary.""" possible_hcoords = {'cartesian': horizontal_coordinates_cartesian, 'llc': horizontal_coordinates_llc, @@ -870,7 +870,8 @@ def _get_extra_grid_variables(grid_dir, custom_grid_variables): extra_grid = {} if custom_grid_variables is not None: - extra_grid_variables = extra_grid_variables.update(custom_grid_variables) + extra_grid_variables.update(custom_grid_variables) + fnames = dict([[val['filename'],key] for key,val in extra_grid_variables.items() if 'filename' in val]) all_datafiles = listdir_endswith(grid_dir, '.data') From d03b3c6ed45b0152d80749727e7ab204ec9af1e0 Mon Sep 17 00:00:00 2001 From: Timothy Smith Date: Mon, 27 Jun 2022 10:40:54 -0600 Subject: [PATCH 5/7] pull _get_variable_point back outside llcmodel class --- xmitgcm/llcreader/llcmodel.py | 39 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/xmitgcm/llcreader/llcmodel.py b/xmitgcm/llcreader/llcmodel.py index 81f6349..a33dbaf 100644 --- a/xmitgcm/llcreader/llcmodel.py +++ b/xmitgcm/llcreader/llcmodel.py @@ -65,6 +65,22 @@ def _get_var_metadata(extra_variables=None): return var_metadata +def _get_variable_point(vname, dims, mask_override): + # fix for https://github.com/MITgcm/xmitgcm/issues/191 + if vname in mask_override: + return mask_override[vname] + if 'i' in dims and 'j' in dims: + point = 'c' + elif 'i_g' in dims and 'j' in dims: + point = 'w' + elif 'i' in dims and 'j_g' in dims: + point = 's' + elif 'i_g' in dims and 'j_g' in dims: + raise ValueError("Don't have masks for corner points!") + else: + raise ValueError("Variable `%s` is not a horizontal variable." 
% vname) + return point + def _decompress(data, mask, dtype): data_blank = np.full_like(mask, np.nan, dtype=dtype) @@ -405,7 +421,7 @@ def _chunks(l, n): def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces, - dtype, mask_override, domain, pad_before, pad_after): + dtype, mask_override, domain, pad_before, pad_after, dims): fs, path = store.get_fs_and_full_path(varname, iternum) @@ -423,7 +439,7 @@ def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces, if (store.shrunk and iternum is not None) or \ (store.shrunk_grid and iternum is None): # the store tells us whether we need a mask or not - point = _get_variable_point(varname, mask_override) + point = _get_variable_point(varname, dims, mask_override) mykey = nx if domain == 'global' else f'{domain}_{nx}' index = all_index_data[mykey][point] zgroup = store.open_mask_group() @@ -593,22 +609,6 @@ def _is_vgrid(self, vname): dims = self.var_metadata[vname]['dims'] return len(dims)==1 and dims[0][0]=='k' - def _get_variable_point(self, vname, mask_override): - # fix for https://github.com/MITgcm/xmitgcm/issues/191 - if vname in mask_override: - return mask_override[vname] - dims = self.var_metadata[vname]['dims'] - if 'i' in dims and 'j' in dims: - point = 'c' - elif 'i_g' in dims and 'j' in dims: - point = 'w' - elif 'i' in dims and 'j_g' in dims: - point = 's' - elif 'i_g' in dims and 'j_g' in dims: - raise ValueError("Don't have masks for corner points!") - else: - raise ValueError("Variable `%s` is not a horizontal variable." % vname) - return point def _get_scalars_and_vectors(self, varnames, type): @@ -701,10 +701,11 @@ def _key_and_task(n_k, these_klevels, n_iter=None, iternum=None): key = name, n_k, 0, 0, 0 else: key = name, n_iter, n_k, 0, 0, 0 + dims = self.var_metadata[varname]['dims'] task = (_get_facet_chunk, self.store, varname, iternum, nfacet, these_klevels, self.nx, self.nz, self.nface, dtype, self.mask_override, self.domain, - self.pad_before, self.pad_after) + self.pad_before, self.pad_after, dims) return key, task if iters is not None: From 554565cbde5759af65a7facc46a7b815ab6a7b87 Mon Sep 17 00:00:00 2001 From: timothyas Date: Fri, 5 Aug 2022 13:27:29 -0600 Subject: [PATCH 6/7] test for custom_grid_variables --- xmitgcm/test/test_mds_store.py | 35 ++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/xmitgcm/test/test_mds_store.py b/xmitgcm/test/test_mds_store.py index dc7cbe2..25149cc 100644 --- a/xmitgcm/test/test_mds_store.py +++ b/xmitgcm/test/test_mds_store.py @@ -608,6 +608,41 @@ def test_extra_variables(all_mds_datadirs): mate = ds[var].attrs['mate'] assert ds[mate].attrs['mate'] == var +def test_custom_grid_variables(all_mds_datadirs): + """Test that open_mdsdataset reads custom grid variables (i.e. no time stamp) correctly""" + dirname, expected = all_mds_datadirs + + custom_grid_variables = { + "iamgridC" : { + "dims" : ["k", "j", "i"], "attrs": {}, + }, + "iamgridW" : { + "dims" : ["k", "j", "i_g"], "attrs": {}, + }, + "iamgridS" : { + "dims" : ["k", "j_g", "i"], "attrs": {}, + }, + } + + # copy hFac to our new grid variable ... 
+ for suffix in ["C", "W", "S"]: + for ext in [".meta", ".data"]: + fname_in = os.path.join(dirname, f"hFac{suffix}{ext}") + fname_out= os.path.join(dirname, f"iamgrid{suffix}{ext}") + copyfile(fname_in, fname_out) + + ds = xmitgcm.open_mdsdataset( + dirname, + read_grid=True, + iters=None, + geometry=expected["geometry"], + prefix=list(custom_grid_variables.keys()), + custom_grid_variables=custom_grid_variables) + + for var in custom_grid_variables.keys(): + assert var in ds + assert var in ds.coords + def test_mask_values(all_mds_datadirs): """Test that open_mdsdataset generates binary masks with correct values""" From 69833acd8b484c8903b76ae99f8ab23fa7e464cc Mon Sep 17 00:00:00 2001 From: timothyas Date: Wed, 11 Oct 2023 09:15:20 -0600 Subject: [PATCH 7/7] try changing fixture scope... --- xmitgcm/test/test_xmitgcm_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xmitgcm/test/test_xmitgcm_common.py b/xmitgcm/test/test_xmitgcm_common.py index 5c7cbd8..3d44e8e 100644 --- a/xmitgcm/test/test_xmitgcm_common.py +++ b/xmitgcm/test/test_xmitgcm_common.py @@ -278,7 +278,7 @@ def file_md5_checksum(fname): # find the tar archive in the test directory # http://stackoverflow.com/questions/29627341/pytest-where-to-store-expected-data -@pytest.fixture(scope='module', params=_experiments.keys()) +@pytest.fixture(scope='function', params=_experiments.keys()) def all_mds_datadirs(tmpdir_factory, request): return setup_mds_dir(tmpdir_factory, request, _experiments)
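For reference, a minimal usage sketch of the custom_grid_variables argument introduced in PATCH 1/7, mirroring the test added in PATCH 6/7. The run directory path is a placeholder, and iamgridC stands for any grid-like file pair (iamgridC.data / iamgridC.meta) that carries no iteration number in its name:

    import xmitgcm

    # Metadata for a grid file with no time stamp, using the same
    # dims/attrs layout as extra_variables.
    custom_grid_variables = {
        "iamgridC": {
            "dims": ["k", "j", "i"],
            "attrs": {},   # optional keys: standard_name, long_name, units
        },
    }

    ds = xmitgcm.open_mdsdataset(
        "/path/to/run",                              # placeholder data directory
        iters=None,                                  # no time-stamped output needed
        read_grid=True,
        geometry="sphericalpolar",                   # match the run's actual geometry
        prefix=list(custom_grid_variables.keys()),
        custom_grid_variables=custom_grid_variables,
    )

    # The custom field is read like any other grid variable and, like the
    # standard grid variables, ends up as a coordinate of the dataset.
    assert "iamgridC" in ds.coords

The llcreader counterpart is the new extra_variables argument to get_dataset (PATCH 3/7), which accepts a dict of the same shape for variables not listed in variables.py or available_diagnostics.log.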