Skip to content

Commit

Permalink
Validate dimension when instantiating datamodel from shape (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
emolter authored Feb 11, 2025
1 parent 192df47 commit 14005e8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 1 deletion.
1 change: 1 addition & 0 deletions changes/395.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Validate dimension against schema when instantiating datamodel from array shape
2 changes: 1 addition & 1 deletion src/stdatamodels/jwst/datamodels/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class _NonstandardPrimaryArrayModel(JwstDataModel):
def get_primary_array_name(self):
return "wavelength"

m = _NonstandardPrimaryArrayModel((10,))
m = _NonstandardPrimaryArrayModel((10, 10))
assert "wavelength" in list(m.keys())
assert m.wavelength.sum() == 0

Expand Down
29 changes: 29 additions & 0 deletions src/stdatamodels/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _make_default_array(attr, schema, ctx):
if attr == primary_array_name:
if ctx.shape is not None:
shape = ctx.shape
_validate_primary_shape(schema, shape)
elif ndim is not None:
shape = tuple([0] * ndim)
else:
Expand Down Expand Up @@ -222,6 +223,34 @@ def _make_default_array(attr, schema, ctx):
return array


def _validate_primary_shape(schema, shape):
"""
Ensure requested shape is allowed by schema.
Parameters
----------
schema : dict
The schema for the primary array.
shape : tuple
The requested shape of the default array.
Raises
------
ValueError
If the requested has dimensions different from the schema's ndim,
or larger than the schema's max_ndim.
"""
ndim_requested = len(shape)
max_ndim = schema.get("max_ndim", None)
ndim = schema.get("ndim", None)
if (ndim is not None) and (ndim_requested != ndim):
msg = f"Array has wrong number of dimensions. Expected {ndim}, got {ndim_requested}"
raise ValueError(msg)
if (max_ndim is not None) and (ndim_requested > max_ndim):
msg = f"Array has wrong number of dimensions. Expected <= {max_ndim}, got {ndim_requested}"
raise ValueError(msg)


def _make_default(attr, schema, ctx):
if "max_ndim" in schema or "ndim" in schema or "datatype" in schema:
return _make_default_array(attr, schema, ctx)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ def test_init_with_array2():
dm.data # noqa: B018


def test_init_invalid_shape():
"""Requested some number of dimensions unequal to ndim, which is set to 2"""
with pytest.raises(ValueError):
BasicModel((50,))


def test_init_invalid_shape2():
"""Requested more dimensions than max_ndim"""
with BasicModel() as dm:
schema = dm._schema
schema["properties"]["data"]["max_ndim"] = 1
schema["properties"]["data"]["ndim"] = None
with pytest.raises(ValueError):
BasicModel((50, 50), schema=schema)


def test_set_array():
with pytest.raises(ValueError):
with BasicModel() as dm:
Expand Down

0 comments on commit 14005e8

Please sign in to comment.