Skip to content

Commit

Permalink
refactor(local): hide local implementation from constraints model (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
salemsd authored Dec 11, 2024
1 parent f24a078 commit 06bcc16
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 95 deletions.
60 changes: 0 additions & 60 deletions src/antares/model/binding_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,46 +120,6 @@ class BindingConstraintProperties(DefaultBindingConstraintProperties):
pass


class BindingConstraintPropertiesLocal(DefaultBindingConstraintProperties):
"""
Used to create the entries for the bindingconstraints.ini file
Attributes:
constraint_name: The constraint name
constraint_id: The constraint id
properties (BindingConstraintProperties): The BindingConstraintProperties to set
terms (dict[str, ConstraintTerm]]): The terms applying to the binding constraint
"""

constraint_name: str
constraint_id: str
terms: dict[str, ConstraintTerm] = {}

@property
def list_ini_fields(self) -> dict[str, str]:
ini_dict = {
"name": self.constraint_name,
"id": self.constraint_id,
"enabled": f"{self.enabled}".lower(),
"type": self.time_step.value,
"operator": self.operator.value,
"comments": self.comments,
"filter-year-by-year": self.filter_year_by_year,
"filter-synthesis": self.filter_synthesis,
"group": self.group,
} | {term_id: term.weight_offset() for term_id, term in self.terms.items()}
return {key: value for key, value in ini_dict.items() if value not in [None, ""]}

def yield_binding_constraint_properties(self) -> BindingConstraintProperties:
excludes = {
"constraint_name",
"constraint_id",
"terms",
"list_ini_fields",
}
return BindingConstraintProperties.model_validate(self.model_dump(mode="json", exclude=excludes))


class BindingConstraint:
def __init__( # type: ignore # TODO: Find a way to avoid circular imports
self,
Expand All @@ -173,9 +133,6 @@ def __init__( # type: ignore # TODO: Find a way to avoid circular imports
self._id = transform_name_to_id(name)
self._properties = properties or BindingConstraintProperties()
self._terms = {term.id: term for term in terms} if terms else {}
self._local_properties = BindingConstraintPropertiesLocal.model_validate(
self._create_local_property_args(self._properties)
)

@property
def name(self) -> str:
Expand All @@ -191,25 +148,8 @@ def properties(self) -> BindingConstraintProperties:

@properties.setter
def properties(self, new_properties: BindingConstraintProperties) -> None:
self._local_properties = BindingConstraintPropertiesLocal.model_validate(
self._create_local_property_args(new_properties)
)
self._properties = new_properties

def _create_local_property_args(
self, properties: BindingConstraintProperties
) -> dict[str, Union[str, dict[str, ConstraintTerm]]]:
return {
"constraint_name": self._name,
"constraint_id": self._id,
"terms": self._terms,
**properties.model_dump(mode="json", exclude_none=True),
}

@property
def local_properties(self) -> BindingConstraintPropertiesLocal:
return self._local_properties

def get_terms(self) -> Dict[str, ConstraintTerm]:
return self._terms

Expand Down
84 changes: 66 additions & 18 deletions src/antares/service/local_services/binding_constraint_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,70 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
from typing import Any, Optional
from typing import Any, Optional, Union

import numpy as np
import pandas as pd

from pydantic import Field

from antares.config.local_configuration import LocalConfiguration
from antares.exceptions.exceptions import BindingConstraintCreationError
from antares.model.binding_constraint import (
BindingConstraint,
BindingConstraintFrequency,
BindingConstraintOperator,
BindingConstraintProperties,
BindingConstraintPropertiesLocal,
ConstraintMatrixName,
ConstraintTerm,
DefaultBindingConstraintProperties,
)
from antares.service.base_services import BaseBindingConstraintService
from antares.tools.ini_tool import IniFile, IniFileTypes
from antares.tools.matrix_tool import df_save
from antares.tools.time_series_tool import TimeSeriesFileType


class BindingConstraintPropertiesLocal(DefaultBindingConstraintProperties):
"""
Used to create the entries for the bindingconstraints.ini file
Attributes:
constraint_name: The constraint name
constraint_id: The constraint id
properties (BindingConstraintProperties): The BindingConstraintProperties to set
terms (dict[str, ConstraintTerm]]): The terms applying to the binding constraint
"""

constraint_name: str
constraint_id: str
terms: dict[str, ConstraintTerm] = Field(default_factory=dict[str, ConstraintTerm])

@property
def list_ini_fields(self) -> dict[str, str]:
ini_dict = {
"name": self.constraint_name,
"id": self.constraint_id,
"enabled": f"{self.enabled}".lower(),
"type": self.time_step.value,
"operator": self.operator.value,
"comments": self.comments,
"filter-year-by-year": self.filter_year_by_year,
"filter-synthesis": self.filter_synthesis,
"group": self.group,
} | {term_id: term.weight_offset() for term_id, term in self.terms.items()}
return {key: value for key, value in ini_dict.items() if value not in [None, ""]}

def yield_binding_constraint_properties(self) -> BindingConstraintProperties:
excludes = {
"constraint_name",
"constraint_id",
"terms",
"list_ini_fields",
}
return BindingConstraintProperties(**self.model_dump(mode="json", exclude=excludes))


class BindingConstraintLocalService(BaseBindingConstraintService):
def __init__(self, config: LocalConfiguration, study_name: str, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand All @@ -53,20 +95,34 @@ def create_binding_constraint(
properties=properties,
terms=terms,
)
constraint.properties = constraint.local_properties.yield_binding_constraint_properties()

local_properties = self._generate_local_properties(constraint)
constraint.properties = local_properties.yield_binding_constraint_properties()

current_ini_content = self.ini_file.ini_dict_binding_constraints or {}
if any(values.get("id") == constraint.id for values in current_ini_content.values()):
raise BindingConstraintCreationError(
constraint_name=name, message=f"A binding constraint with the name {name} already exists."
)

self._write_binding_constraint_ini(constraint.properties, name, name, terms)
self._write_binding_constraint_ini(local_properties, name, name, terms)

self._store_time_series(constraint, less_term_matrix, equal_term_matrix, greater_term_matrix)

return constraint

@staticmethod
def _create_local_property_args(constraint: BindingConstraint) -> dict[str, Union[str, dict[str, ConstraintTerm]]]:
return {
"constraint_name": constraint.name,
"constraint_id": constraint.id,
"terms": constraint.get_terms(),
**constraint.properties.model_dump(mode="json", exclude_none=True),
}

def _generate_local_properties(self, constraint: BindingConstraint) -> BindingConstraintPropertiesLocal:
return BindingConstraintPropertiesLocal.model_validate(self._create_local_property_args(constraint))

def _store_time_series(
self,
constraint: BindingConstraint,
Expand Down Expand Up @@ -103,7 +159,7 @@ def _check_if_empty_ts(time_step: BindingConstraintFrequency, time_series: Optio

def _write_binding_constraint_ini(
self,
properties: BindingConstraintProperties,
local_properties: BindingConstraintPropertiesLocal,
constraint_name: str,
constraint_id: str,
terms: Optional[list[ConstraintTerm]] = None,
Expand Down Expand Up @@ -131,17 +187,8 @@ def _write_binding_constraint_ini(
# Persist the updated INI content
self.ini_file.write_ini_file()
else:
terms_dict = {term.id: term for term in terms} if terms else {}

full_properties = BindingConstraintPropertiesLocal(
constraint_name=constraint_name,
constraint_id=constraint_id,
terms=terms_dict,
**properties.model_dump(),
)

section_index = len(current_ini_content)
current_ini_content[str(section_index)] = full_properties.list_ini_fields
current_ini_content[str(section_index)] = local_properties.list_ini_fields

self.ini_file.ini_dict_binding_constraints = current_ini_content
self.ini_file.write_ini_file()
Expand All @@ -158,7 +205,7 @@ def add_constraint_terms(self, constraint: BindingConstraint, terms: list[Constr
list[ConstraintTerm]: The updated list of terms.
"""

new_terms = constraint.local_properties.terms.copy()
new_terms = constraint.get_terms().copy()

for term in terms:
if term.id in constraint.get_terms():
Expand All @@ -167,12 +214,13 @@ def add_constraint_terms(self, constraint: BindingConstraint, terms: list[Constr
)
new_terms[term.id] = term

constraint.local_properties.terms = new_terms
local_properties = self._generate_local_properties(constraint)
local_properties.terms = new_terms

terms_values = list(new_terms.values())

self._write_binding_constraint_ini(
properties=constraint.properties,
local_properties=local_properties,
constraint_name=constraint.name,
constraint_id=constraint.id,
terms=terms_values,
Expand Down
17 changes: 0 additions & 17 deletions tests/antares/services/local_services/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
BindingConstraintFrequency,
BindingConstraintOperator,
BindingConstraintProperties,
BindingConstraintPropertiesLocal,
ConstraintTerm,
)
from antares.model.commons import FilterOption
Expand Down Expand Up @@ -2229,19 +2228,3 @@ def test_submitted_time_series_is_saved(self, local_study):

# Then
assert actual_time_series.equals(expected_time_series)

def test_updating_binding_constraint_properties_updates_local(self, local_study_with_constraint, test_constraint):
# Given
new_properties = BindingConstraintProperties(comments="testing update")
local_property_args = {
"constraint_name": test_constraint.name,
"constraint_id": test_constraint.id,
"terms": test_constraint._terms,
**new_properties.model_dump(mode="json", exclude_none=True),
}

# When
test_constraint.properties = new_properties

# Then
assert test_constraint.local_properties == BindingConstraintPropertiesLocal.model_validate(local_property_args)

0 comments on commit 06bcc16

Please sign in to comment.