Skip to content

Commit

Permalink
Fix some type checking issues, remove useless code
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Aug 21, 2024
1 parent a904d9d commit 8b7a244
Show file tree
Hide file tree
Showing 14 changed files with 28 additions and 1,181 deletions.
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ types-PyYAML~=6.0.12.12
antlr4-tools~=0.2.1
pandas~=2.0.3
pandas-stubs<=2.0.3
types-PyYAML~=6.0.12
1 change: 0 additions & 1 deletion src/andromede/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# This file is part of the Antares project.

from .copy import CopyVisitor, copy_expression
from .degree import ExpressionDegreeVisitor, compute_degree
from .evaluate_parameters_efficient import ValueProvider
from .expression_efficient import (
AdditionNode,
Expand Down
120 changes: 0 additions & 120 deletions src/andromede/expression/degree.py

This file was deleted.

76 changes: 0 additions & 76 deletions src/andromede/expression/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,7 @@
from dataclasses import dataclass, field
from typing import Dict

from .expression_efficient import (
ComparisonNode,
ComponentParameterNode,
ExpressionNodeEfficient,
LiteralNode,
ParameterNode,
PortFieldAggregatorNode,
PortFieldNode,
ScenarioOperatorNode,
TimeAggregatorNode,
TimeOperatorNode,
)
from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider
from .visitor import ExpressionVisitorOperations, visit


# Used only for tests
Expand Down Expand Up @@ -70,66 +57,3 @@ def block_length() -> int:
@staticmethod
def scenarios() -> int:
raise NotImplementedError()


@dataclass(frozen=True)
class EvaluationVisitor(ExpressionVisitorOperations[float]):
"""
Evaluates the expression with respect to the provided context
(variables and parameters values).
"""

context: ValueProvider

def literal(self, node: LiteralNode) -> float:
return node.value

def comparison(self, node: ComparisonNode) -> float:
raise ValueError("Cannot evaluate comparison operator.")

def parameter(self, node: ParameterNode) -> float:
return self.context.get_parameter_value(node.name)

def comp_parameter(self, node: ComponentParameterNode) -> float:
return self.context.get_component_parameter_value(node.component_id, node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
raise NotImplementedError()

def time_aggregator(self, node: TimeAggregatorNode) -> float:
raise NotImplementedError()

def scenario_operator(self, node: ScenarioOperatorNode) -> float:
raise NotImplementedError()

def port_field(self, node: PortFieldNode) -> float:
raise NotImplementedError()

def port_field_aggregator(self, node: PortFieldAggregatorNode) -> float:
raise NotImplementedError()


def evaluate(
expression: ExpressionNodeEfficient, value_provider: ValueProvider
) -> float:
return visit(expression, EvaluationVisitor(value_provider))


@dataclass(frozen=True)
class InstancesIndexVisitor(EvaluationVisitor):
"""
Evaluates an expression given as instances index which should have no variable and constant parameter values.
"""

def parameter(self, node: ParameterNode) -> float:
if not self.context.parameter_is_constant_over_time(node.name):
raise ValueError(
"Parameter given in an instance index expression must be constant over time"
)
return self.context.get_parameter_value(node.name)

def time_operator(self, node: TimeOperatorNode) -> float:
raise ValueError("An instance index expression cannot contain time operator")

def time_aggregator(self, node: TimeAggregatorNode) -> float:
raise ValueError("An instance index expression cannot contain time aggregator")
4 changes: 3 additions & 1 deletion src/andromede/expression/expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def expec(self) -> "ExpressionNodeEfficient":

def variance(self) -> "ExpressionNodeEfficient":
return _apply_if_node(
self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.Variance)
self, lambda x: ScenarioOperatorNode(x, ScenarioOperatorName.VARIANCE)
)


Expand All @@ -141,6 +141,8 @@ def wrap_in_node(obj: Any) -> ExpressionNodeEfficient:
return obj
elif isinstance(obj, float) or isinstance(obj, int):
return LiteralNode(float(obj))
# else:
# return None
# Do not raise excpetion so that we can return NotImplemented in _apply_if_node
# raise TypeError(f"Unable to wrap {obj} into an expression node")

Expand Down
36 changes: 12 additions & 24 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
TypeVar,
Expand All @@ -31,10 +33,7 @@

from .context_adder import add_component_context
from .equality import expressions_equal
from .evaluate_parameters_efficient import (
check_resolved_expr,
resolve_coefficient,
)
from .evaluate_parameters_efficient import check_resolved_expr, resolve_coefficient
from .expression_efficient import (
ExpressionNodeEfficient,
ExpressionRange,
Expand Down Expand Up @@ -65,11 +64,7 @@
TimeShift,
TimeSum,
)
from .value_provider import (
TimeScenarioIndex,
TimeScenarioIndices,
ValueProvider,
)
from .value_provider import TimeScenarioIndex, TimeScenarioIndices, ValueProvider


@dataclass(frozen=True)
Expand Down Expand Up @@ -368,7 +363,7 @@ def __str__(self) -> str:
result += f".{str(self.aggregator)}"
return result

def sum_connections(self) -> "LinearExpressionEfficient":
def sum_connections(self) -> "PortFieldTerm":
if self.aggregator is not None:
raise ValueError(f"Port field {str(self)} already has a port aggregator")
return dataclasses.replace(self, aggregator=PortSum())
Expand All @@ -377,6 +372,10 @@ def sum_connections(self) -> "LinearExpressionEfficient":
T_val = TypeVar("T_val", bound=Union[TermEfficient, PortFieldTerm])


def _get_neutral_term(term: T_val, neutral: float) -> T_val:
return dataclasses.replace(term, coefficient=wrap_in_node(neutral))


@overload
def _merge_dicts(
lhs: Dict[TermKeyEfficient, TermEfficient],
Expand All @@ -397,10 +396,6 @@ def _merge_dicts(
...


def _get_neutral_term(term: T_val, neutral: float) -> T_val:
return dataclasses.replace(term, coefficient=neutral)


def _merge_dicts(lhs, rhs, merge_func, neutral):
res = {}
for k, v in lhs.items():
Expand Down Expand Up @@ -821,7 +816,7 @@ def sum(

def _apply_operator(
self,
sum_args: Dict[
sum_args: Mapping[
str,
Union[
int,
Expand All @@ -839,13 +834,6 @@ def _apply_operator(

return result_terms

# def sum_connections(self) -> "ExpressionNode":
# if isinstance(self, PortFieldNode):
# return PortFieldAggregatorNode(self, aggregator=PortFieldAggregatorName.PORT_SUM)
# raise ValueError(
# f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}."
# )

def shift(
self,
expressions: Union[
Expand Down Expand Up @@ -1036,7 +1024,7 @@ def linear_expressions_equal_if_present(
# TODO: Is this function useful ? Could we just rely on the sum operator overloading ? Only the case with an empty list may make the function useful
def sum_expressions(
expressions: Sequence[LinearExpressionEfficient],
) -> LinearExpressionEfficient:
) -> Union[LinearExpressionEfficient, Literal[0]]:
if len(expressions) == 0:
return wrap_in_linear_expr(literal(0))
else:
Expand All @@ -1059,7 +1047,7 @@ def __post_init__(
for bound in [self.lower_bound, self.upper_bound]:
if not bound.is_constant():
raise ValueError(
f"The bounds of a constraint should not contain variables, {print_expr(bound)} was given."
f"The bounds of a constraint should not contain variables, {str(bound)} was given."
)

def __str__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion src/andromede/expression/port_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ class PortAggregator:

@dataclass(frozen=True)
class PortSum(PortAggregator):
def __str__(self):
def __str__(self) -> str:
return "PortSum"
5 changes: 1 addition & 4 deletions src/andromede/expression/time_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ class TimeOperator(ABC):
def rolling(cls) -> bool:
raise NotImplementedError

def key(self) -> Tuple[int, ...]:
def key(self) -> InstancesTimeIndex:
return self.time_ids

def size(self) -> int:
return len(self.time_ids.expressions)


@dataclass(frozen=True)
class TimeShift(TimeOperator):
Expand Down
Loading

0 comments on commit 8b7a244

Please sign in to comment.