Skip to content

Commit

Permalink
Use own errors for Equation
Browse files Browse the repository at this point in the history
  • Loading branch information
glatterf42 committed Oct 3, 2024
1 parent 2358835 commit 0ce1557
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 19 deletions.
4 changes: 3 additions & 1 deletion ixmp4/data/db/optimization/equation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy.orm import validates

from ixmp4 import db
from ixmp4.core.exceptions import OptimizationDataValidationError
from ixmp4.data import types
from ixmp4.data.abstract import optimization as abstract

Expand All @@ -14,6 +15,7 @@ class Equation(base.BaseModel):
# NOTE: These might be mixin-able, but would require some abstraction
NotFound: ClassVar = abstract.Equation.NotFound
NotUnique: ClassVar = abstract.Equation.NotUnique
DataInvalid: ClassVar = OptimizationDataValidationError
DeletionPrevented: ClassVar = abstract.Equation.DeletionPrevented

# constrained_to_indexsets: ClassVar[list[str] | None] = None
Expand All @@ -30,7 +32,7 @@ def validate_data(self, key, data: dict[str, Any]):
del data_to_validate["levels"]
del data_to_validate["marginals"]
_ = utils.validate_data(
key=key,
host=self,
data=data_to_validate,
columns=self.columns,
)
Expand Down
24 changes: 18 additions & 6 deletions ixmp4/data/db/optimization/equation/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd

from ixmp4 import db
from ixmp4.core.exceptions import OptimizationItemUsageError
from ixmp4.data.abstract import optimization as abstract
from ixmp4.data.auth.decorators import guard

Expand All @@ -19,6 +20,8 @@ class EquationRepository(
):
model_class = Equation

UsageError = OptimizationItemUsageError

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.docs = EquationDocsRepository(*args, **kwargs)
Expand Down Expand Up @@ -111,7 +114,8 @@ def create(
if isinstance(constrained_to_indexsets, str):
constrained_to_indexsets = list(constrained_to_indexsets)
if column_names and len(column_names) != len(constrained_to_indexsets):
raise ValueError(
raise OptimizationItemUsageError(
f"While processing Equation {name}: \n"
"`constrained_to_indexsets` and `column_names` not equal in length! "
"Please provide the same number of entries for both!"
)
Expand All @@ -120,7 +124,10 @@ def create(
# if len(constrained_to_indexsets) != len(set(constrained_to_indexsets)):
# raise ValueError("Each dimension must be constrained to a unique indexset!") # noqa
if column_names and len(column_names) != len(set(column_names)):
raise ValueError("The given `column_names` are not unique!")
raise OptimizationItemUsageError(
f"While processing Equation {name}: \n"
"The given `column_names` are not unique!"
)

equation = super().create(
run_id=run_id,
Expand Down Expand Up @@ -148,13 +155,18 @@ def tabulate(self, *args, **kwargs) -> pd.DataFrame:
@guard("edit")
def add_data(self, equation_id: int, data: dict[str, Any] | pd.DataFrame) -> None:
if isinstance(data, dict):
data = pd.DataFrame.from_dict(data=data)
try:
data = pd.DataFrame.from_dict(data=data)
except ValueError as e:
raise Equation.DataInvalid(str(e)) from e
equation = self.get_by_id(id=equation_id)

missing_columns = set(["levels", "marginals"]) - set(data.columns)
assert (
not missing_columns
), f"Equation.data must include the column(s): {', '.join(missing_columns)}!"
if missing_columns:
raise OptimizationItemUsageError(
f"Equation.data must include the column(s): "
f"{', '.join(missing_columns)}!"
)

index_list = [column.name for column in equation.columns]
existing_data = pd.DataFrame(equation.data)
Expand Down
24 changes: 18 additions & 6 deletions tests/core/test_optimization_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

import ixmp4
from ixmp4.core import Equation, IndexSet
from ixmp4.core.exceptions import (
OptimizationDataValidationError,
OptimizationItemUsageError,
)

from ..utils import assert_unordered_equality, create_indexsets_for_run

Expand Down Expand Up @@ -60,7 +64,7 @@ def test_create_equation(self, platform: ixmp4.Platform):
)

# Test mismatch in constrained_to_indexsets and column_names raises
with pytest.raises(ValueError, match="not equal in length"):
with pytest.raises(OptimizationItemUsageError, match="not equal in length"):
_ = run.optimization.equations.create(
"Equation 2",
constrained_to_indexsets=[indexset.name],
Expand All @@ -76,7 +80,9 @@ def test_create_equation(self, platform: ixmp4.Platform):
assert equation_2.columns[0].name == "Column 1"

# Test duplicate column_names raise
with pytest.raises(ValueError, match="`column_names` are not unique"):
with pytest.raises(
OptimizationItemUsageError, match="`column_names` are not unique"
):
_ = run.optimization.equations.create(
name="Equation 3",
constrained_to_indexsets=[indexset.name, indexset.name],
Expand Down Expand Up @@ -148,7 +154,8 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
)

with pytest.raises(
AssertionError, match=r"must include the column\(s\): marginals!"
OptimizationItemUsageError,
match=r"must include the column\(s\): marginals!",
):
equation_2.add(
pd.DataFrame(
Expand All @@ -161,7 +168,7 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
)

with pytest.raises(
AssertionError, match=r"must include the column\(s\): levels!"
OptimizationItemUsageError, match=r"must include the column\(s\): levels!"
):
equation_2.add(
data=pd.DataFrame(
Expand All @@ -175,7 +182,10 @@ def test_equation_add_data(self, platform: ixmp4.Platform):

# By converting data to pd.DataFrame, we automatically enforce equal length
# of new columns, raises All arrays must be of the same length otherwise:
with pytest.raises(ValueError, match="All arrays must be of the same length"):
with pytest.raises(
OptimizationDataValidationError,
match="All arrays must be of the same length",
):
equation_2.add(
data={
indexset.name: ["foo", "foo"],
Expand All @@ -185,7 +195,9 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
},
)

with pytest.raises(ValueError, match="contains duplicate rows"):
with pytest.raises(
OptimizationDataValidationError, match="contains duplicate rows"
):
equation_2.add(
data={
indexset.name: ["foo", "foo"],
Expand Down
24 changes: 18 additions & 6 deletions tests/data/test_optimization_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import pytest

import ixmp4
from ixmp4.core.exceptions import (
OptimizationDataValidationError,
OptimizationItemUsageError,
)
from ixmp4.data.abstract import Equation

from ..utils import assert_unordered_equality, create_indexsets_for_run
Expand Down Expand Up @@ -58,7 +62,7 @@ def test_create_equation(self, platform: ixmp4.Platform):
)

# Test mismatch in constrained_to_indexsets and column_names raises
with pytest.raises(ValueError, match="not equal in length"):
with pytest.raises(OptimizationItemUsageError, match="not equal in length"):
_ = platform.backend.optimization.equations.create(
run_id=run.id,
name="Equation 2",
Expand All @@ -76,7 +80,9 @@ def test_create_equation(self, platform: ixmp4.Platform):
assert equation_2.columns[0].name == "Column 1"

# Test duplicate column_names raise
with pytest.raises(ValueError, match="`column_names` are not unique"):
with pytest.raises(
OptimizationItemUsageError, match="`column_names` are not unique"
):
_ = platform.backend.optimization.equations.create(
run_id=run.id,
name="Equation 3",
Expand Down Expand Up @@ -160,7 +166,7 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
)

with pytest.raises(
AssertionError, match=r"must include the column\(s\): levels!"
OptimizationItemUsageError, match=r"must include the column\(s\): levels!"
):
platform.backend.optimization.equations.add_data(
equation_id=equation_2.id,
Expand All @@ -174,7 +180,8 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
)

with pytest.raises(
AssertionError, match=r"must include the column\(s\): marginals!"
OptimizationItemUsageError,
match=r"must include the column\(s\): marginals!",
):
platform.backend.optimization.equations.add_data(
equation_id=equation_2.id,
Expand All @@ -189,7 +196,10 @@ def test_equation_add_data(self, platform: ixmp4.Platform):

# By converting data to pd.DataFrame, we automatically enforce equal length
# of new columns, raises All arrays must be of the same length otherwise:
with pytest.raises(ValueError, match="All arrays must be of the same length"):
with pytest.raises(
OptimizationDataValidationError,
match="All arrays must be of the same length",
):
platform.backend.optimization.equations.add_data(
equation_id=equation_2.id,
data={
Expand All @@ -200,7 +210,9 @@ def test_equation_add_data(self, platform: ixmp4.Platform):
},
)

with pytest.raises(ValueError, match="contains duplicate rows"):
with pytest.raises(
OptimizationDataValidationError, match="contains duplicate rows"
):
platform.backend.optimization.equations.add_data(
equation_id=equation_2.id,
data={
Expand Down

0 comments on commit 0ce1557

Please sign in to comment.