diff --git a/ixmp4/core/optimization/data.py b/ixmp4/core/optimization/data.py index c604dbe1..d3f8cbfb 100644 --- a/ixmp4/core/optimization/data.py +++ b/ixmp4/core/optimization/data.py @@ -29,3 +29,9 @@ def __init__(self, run: Run, **kwargs: Backend) -> None: self.scalars = ScalarRepository(_backend=self.backend, _run=run) self.tables = TableRepository(_backend=self.backend, _run=run) self.variables = VariableRepository(_backend=self.backend, _run=run) + + def remove_solution(self) -> None: + for equation in self.equations.list(): + equation.remove_data() + for variable in self.variables.list(): + variable.remove_data() diff --git a/ixmp4/core/optimization/equation.py b/ixmp4/core/optimization/equation.py index 58efbd18..09c49b1a 100644 --- a/ixmp4/core/optimization/equation.py +++ b/ixmp4/core/optimization/equation.py @@ -40,7 +40,7 @@ def data(self) -> dict[str, Any]: return self._model.data def add(self, data: dict[str, Any] | pd.DataFrame) -> None: - """Adds data to an existing Equation.""" + """Adds data to the Equation.""" self.backend.optimization.equations.add_data( equation_id=self._model.id, data=data ) @@ -49,7 +49,7 @@ def add(self, data: dict[str, Any] | pd.DataFrame) -> None: ) def remove_data(self) -> None: - """Removes data from an existing Equation.""" + """Removes all data from the Equation.""" self.backend.optimization.equations.remove_data(equation_id=self._model.id) self._model = self.backend.optimization.equations.get( run_id=self._model.run__id, name=self._model.name diff --git a/ixmp4/core/optimization/parameter.py b/ixmp4/core/optimization/parameter.py index 9d423632..ce8bf4ec 100644 --- a/ixmp4/core/optimization/parameter.py +++ b/ixmp4/core/optimization/parameter.py @@ -40,7 +40,7 @@ def data(self) -> dict[str, Any]: return self._model.data def add(self, data: dict[str, Any] | pd.DataFrame) -> None: - """Adds data to an existing Parameter.""" + """Adds data to the Parameter.""" self.backend.optimization.parameters.add_data( parameter_id=self._model.id, data=data ) diff --git a/ixmp4/core/optimization/table.py b/ixmp4/core/optimization/table.py index 3c536ec5..4869b51b 100644 --- a/ixmp4/core/optimization/table.py +++ b/ixmp4/core/optimization/table.py @@ -40,7 +40,7 @@ def data(self) -> dict[str, Any]: return self._model.data def add(self, data: dict[str, Any] | pd.DataFrame) -> None: - """Adds data to an existing Table.""" + """Adds data to the Table.""" self.backend.optimization.tables.add_data(table_id=self._model.id, data=data) self._model = self.backend.optimization.tables.get( run_id=self._model.run__id, name=self._model.name diff --git a/ixmp4/core/optimization/variable.py b/ixmp4/core/optimization/variable.py index c66b7792..0daf3359 100644 --- a/ixmp4/core/optimization/variable.py +++ b/ixmp4/core/optimization/variable.py @@ -40,7 +40,7 @@ def data(self) -> dict[str, Any]: return self._model.data def add(self, data: dict[str, Any] | pd.DataFrame) -> None: - """Adds data to an existing Variable.""" + """Adds data to the Variable.""" self.backend.optimization.variables.add_data( variable_id=self._model.id, data=data ) @@ -49,7 +49,7 @@ def add(self, data: dict[str, Any] | pd.DataFrame) -> None: ) def remove_data(self) -> None: - """Removes data from an existing Variable.""" + """Removes all data from the Variable.""" self.backend.optimization.variables.remove_data(variable_id=self._model.id) self._model = self.backend.optimization.variables.get( run_id=self._model.run__id, name=self._model.name diff --git a/ixmp4/data/db/optimization/equation/model.py b/ixmp4/data/db/optimization/equation/model.py index e7588043..5c589b63 100644 --- a/ixmp4/data/db/optimization/equation/model.py +++ b/ixmp4/data/db/optimization/equation/model.py @@ -26,16 +26,17 @@ class Equation(base.BaseModel): @validates("data") def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: - if data == {}: + if not bool(data): return data data_to_validate = copy.deepcopy(data) del data_to_validate["levels"] del data_to_validate["marginals"] - _ = utils.validate_data( - host=self, - data=data_to_validate, - columns=self.columns, - ) + if bool(data_to_validate): + _ = utils.validate_data( + host=self, + data=data_to_validate, + columns=self.columns, + ) return data __table_args__ = (db.UniqueConstraint("name", "run__id"),) diff --git a/ixmp4/data/db/optimization/equation/repository.py b/ixmp4/data/db/optimization/equation/repository.py index 34f9fdcd..393c5ae8 100644 --- a/ixmp4/data/db/optimization/equation/repository.py +++ b/ixmp4/data/db/optimization/equation/repository.py @@ -79,11 +79,7 @@ def _add_column( # type: ignore[no-untyped-def] **kwargs, ) - def add( - self, - run_id: int, - name: str, - ) -> Equation: + def add(self, run_id: int, name: str) -> Equation: equation = Equation(name=name, run__id=run_id) equation.set_creation_info(auth_context=self.backend.auth_context) self.session.add(equation) diff --git a/ixmp4/data/db/optimization/indexset/repository.py b/ixmp4/data/db/optimization/indexset/repository.py index c559ce23..69e5eac8 100644 --- a/ixmp4/data/db/optimization/indexset/repository.py +++ b/ixmp4/data/db/optimization/indexset/repository.py @@ -95,5 +95,4 @@ def add_data( Literal["float", "int", "str"], type(_data[0]).__name__ ) - self.session.add(indexset) self.session.commit() diff --git a/ixmp4/data/db/optimization/scalar/repository.py b/ixmp4/data/db/optimization/scalar/repository.py index 3ad02f51..e79af22d 100644 --- a/ixmp4/data/db/optimization/scalar/repository.py +++ b/ixmp4/data/db/optimization/scalar/repository.py @@ -89,8 +89,7 @@ def update( self.session.execute(exc) self.session.commit() - scalar: Scalar = self.get_by_id(id) - return scalar + return self.get_by_id(id) @guard("view") def list(self, **kwargs: Unpack[EnumerateKwargs]) -> Iterable[Scalar]: diff --git a/ixmp4/data/db/optimization/table/repository.py b/ixmp4/data/db/optimization/table/repository.py index 27dbef2c..dcdad5be 100644 --- a/ixmp4/data/db/optimization/table/repository.py +++ b/ixmp4/data/db/optimization/table/repository.py @@ -154,5 +154,4 @@ def add_data(self, table_id: int, data: dict[str, Any] | pd.DataFrame) -> None: orient="list" ) # type: ignore[assignment] - self.session.add(table) self.session.commit() diff --git a/ixmp4/data/db/optimization/utils.py b/ixmp4/data/db/optimization/utils.py index 3ff3bd8c..d612b4f3 100644 --- a/ixmp4/data/db/optimization/utils.py +++ b/ixmp4/data/db/optimization/utils.py @@ -10,13 +10,10 @@ def collect_indexsets_to_check( columns: list["Column"], -) -> dict[str, Any]: +) -> dict[str, list[float] | list[int] | list[str]]: """Creates a {key:value} dict from linked Column.names and their IndexSet.data.""" - collection: dict[str, Any] = {} - for column in columns: - collection[column.name] = column.indexset.data - return collection + return {column.name: column.indexset.data for column in columns} def validate_data( diff --git a/ixmp4/data/db/optimization/variable/model.py b/ixmp4/data/db/optimization/variable/model.py index 32cfae59..bae48e2e 100644 --- a/ixmp4/data/db/optimization/variable/model.py +++ b/ixmp4/data/db/optimization/variable/model.py @@ -27,16 +27,17 @@ class OptimizationVariable(base.BaseModel): @validates("data") def validate_data(self, key: Any, data: dict[str, Any]) -> dict[str, Any]: - if data == {}: + if not bool(data): return data data_to_validate = copy.deepcopy(data) del data_to_validate["levels"] del data_to_validate["marginals"] - _ = utils.validate_data( - host=self, - data=data_to_validate, - columns=self.columns, - ) + if bool(data_to_validate): + _ = utils.validate_data( + host=self, + data=data_to_validate, + columns=self.columns, + ) return data __table_args__ = (db.UniqueConstraint("name", "run__id"),) diff --git a/ixmp4/data/db/optimization/variable/repository.py b/ixmp4/data/db/optimization/variable/repository.py index 76aa13f3..c8bed49e 100644 --- a/ixmp4/data/db/optimization/variable/repository.py +++ b/ixmp4/data/db/optimization/variable/repository.py @@ -178,11 +178,14 @@ def add_data(self, variable_id: int, data: dict[str, Any] | pd.DataFrame) -> Non index_list = [column.name for column in variable.columns] existing_data = pd.DataFrame(variable.data) - if not existing_data.empty: - existing_data.set_index(index_list, inplace=True) - variable.data = ( - data.set_index(index_list).combine_first(existing_data).reset_index() - ).to_dict(orient="list") # type: ignore[assignment] + if index_list: + data = data.set_index(index_list) + if not existing_data.empty: + existing_data.set_index(index_list, inplace=True) + data = data.combine_first(existing_data) + if index_list: + data = data.reset_index() + variable.data = data.to_dict(orient="list") # type: ignore[assignment] self.session.commit() diff --git a/tests/core/test_optimization_variable.py b/tests/core/test_optimization_variable.py index 183e65f5..0dec1d45 100644 --- a/tests/core/test_optimization_variable.py +++ b/tests/core/test_optimization_variable.py @@ -273,6 +273,14 @@ def test_variable_add_data(self, platform: ixmp4.Platform) -> None: ) assert_unordered_equality(expected, pd.DataFrame(variable_4.data)) + # Test adding to scalar variable raises + with pytest.raises( + OptimizationDataValidationError, + match="Trying to add data to unknown Columns!", + ): + variable_5 = run.optimization.variables.create("Variable 5") + variable_5.add(data={"foo": ["bar"], "levels": [1], "marginals": [0]}) + def test_variable_remove_data(self, platform: ixmp4.Platform) -> None: run = platform.runs.create("Model", "Scenario") indexset = run.optimization.indexsets.create("Indexset") diff --git a/tests/core/test_run.py b/tests/core/test_run.py index e75aa4f5..1a128555 100644 --- a/tests/core/test_run.py +++ b/tests/core/test_run.py @@ -181,3 +181,29 @@ def delete_all_datapoints(self, run: ixmp4.Run) -> None: run.iamc.remove(cat, type=ixmp4.DataPoint.Type.CATEGORICAL) if not datetime.empty: run.iamc.remove(datetime, type=ixmp4.DataPoint.Type.DATETIME) + + def test_run_remove_solution(self, platform: ixmp4.Platform) -> None: + run = platform.runs.create("Model", "Scenario") + indexset = run.optimization.indexsets.create("Indexset") + indexset.add(["foo", "bar"]) + test_data = { + "Indexset": ["bar", "foo"], + "levels": [2.0, 1], + "marginals": [0, "test"], + } + run.optimization.equations.create( + "Equation", + constrained_to_indexsets=[indexset.name], + ).add(test_data) + run.optimization.variables.create( + "Variable", + constrained_to_indexsets=[indexset.name], + ).add(test_data) + + run.optimization.remove_solution() + # Need to fetch them here even if fetched before because API layer might not + # forward changes automatically + equation = run.optimization.equations.get("Equation") + variable = run.optimization.variables.get("Variable") + assert equation.data == {} + assert variable.data == {} diff --git a/tests/data/test_optimization_variable.py b/tests/data/test_optimization_variable.py index f7969134..92bbb4c9 100644 --- a/tests/data/test_optimization_variable.py +++ b/tests/data/test_optimization_variable.py @@ -83,6 +83,18 @@ def test_create_variable(self, platform: ixmp4.Platform) -> None: column_names=["Dimension 1"], ) + # Test that giving column_names, but not constrained_to_indexsets raises + with pytest.raises( + OptimizationItemUsageError, + match="Received `column_names` to name columns, but no " + "`constrained_to_indexsets`", + ): + _ = platform.backend.optimization.variables.create( + run_id=run.id, + name="Variable 0", + column_names=["Dimension 1"], + ) + # Test mismatch in constrained_to_indexsets and column_names raises with pytest.raises(OptimizationItemUsageError, match="not equal in length"): _ = platform.backend.optimization.variables.create(