diff --git a/src/andromede/expression/__init__.py b/src/andromede/expression/__init__.py index 55b51967..70c8c8ac 100644 --- a/src/andromede/expression/__init__.py +++ b/src/andromede/expression/__init__.py @@ -12,41 +12,28 @@ from .copy import CopyVisitor, copy_expression from .degree import ExpressionDegreeVisitor, compute_degree -from .evaluate import EvaluationContext, EvaluationVisitor, ValueProvider, evaluate -from .evaluate_parameters import ( - ParameterResolver, - ParameterValueProvider, - resolve_parameters, -) - -from .expression import ( - # AdditionNode, - # Comparator, - # ComparisonNode, - # DivisionNode, - ExpressionNode, - # LiteralNode, - # MultiplicationNode, - # NegationNode, - # ParameterNode, - # SubstractionNode, - VariableNode, - literal, - param, - sum_expressions, - var, -) +from .evaluate_parameters_efficient import ValueProvider from .expression_efficient import ( AdditionNode, - Comparator, ComparisonNode, + ComponentParameterNode, DivisionNode, ExpressionNodeEfficient, + ExpressionRange, + InstancesTimeIndex, LiteralNode, MultiplicationNode, NegationNode, ParameterNode, + PortFieldAggregatorNode, + PortFieldNode, + ScenarioOperatorName, + ScenarioOperatorNode, SubstractionNode, + TimeAggregatorName, + TimeAggregatorNode, + TimeOperatorName, + TimeOperatorNode, ) from .print import PrinterVisitor, print_expr from .visitor import ExpressionVisitor, visit diff --git a/src/andromede/expression/evaluate.py b/src/andromede/expression/evaluate.py index 08477070..e09f033c 100644 --- a/src/andromede/expression/evaluate.py +++ b/src/andromede/expression/evaluate.py @@ -13,7 +13,6 @@ from dataclasses import dataclass, field from typing import Dict -from andromede.expression.expression import VariableNode from andromede.expression.expression_efficient import ( ComparisonNode, ComponentParameterNode, @@ -93,18 +92,12 @@ def literal(self, node: LiteralNode) -> float: def comparison(self, node: ComparisonNode) -> float: raise ValueError("Cannot evaluate comparison operator.") - def variable(self, node: VariableNode) -> float: - return self.context.get_variable_value(node.name) - def parameter(self, node: ParameterNode) -> float: return self.context.get_parameter_value(node.name) def comp_parameter(self, node: ComponentParameterNode) -> float: return self.context.get_component_parameter_value(node.component_id, node.name) - # def comp_variable(self, node: ComponentVariableNode) -> float: - # return self.context.get_component_variable_value(node.component_id, node.name) - def time_operator(self, node: TimeOperatorNode) -> float: raise NotImplementedError() @@ -133,9 +126,6 @@ class InstancesIndexVisitor(EvaluationVisitor): Evaluates an expression given as instances index which should have no variable and constant parameter values. """ - def variable(self, node: VariableNode) -> float: - raise ValueError("An instance index expression cannot contain variable") - def parameter(self, node: ParameterNode) -> float: if not self.context.parameter_is_constant_over_time(node.name): raise ValueError( diff --git a/src/andromede/expression/evaluate_parameters.py b/src/andromede/expression/evaluate_parameters.py deleted file mode 100644 index 7c734260..00000000 --- a/src/andromede/expression/evaluate_parameters.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import List - -from andromede.expression.evaluate import InstancesIndexVisitor, ValueProvider - -from .copy import CopyVisitor -from .expression import ( - ComponentParameterNode, - ExpressionNode, - ExpressionRange, - InstancesTimeIndex, - LiteralNode, - ParameterNode, -) -from .visitor import visit - - -class ParameterValueProvider(ABC): - @abstractmethod - def get_parameter_value(self, name: str) -> float: - ... - - @abstractmethod - def get_component_parameter_value(self, component_id: str, name: str) -> float: - ... - - -@dataclass(frozen=True) -class ParameterResolver(CopyVisitor): - """ - Duplicates the AST with replacement of parameter nodes by literal nodes. - """ - - context: ParameterValueProvider - - def parameter(self, node: ParameterNode) -> ExpressionNode: - value: float = self.context.get_parameter_value(node.name) - return LiteralNode(value) - - def comp_parameter(self, node: ComponentParameterNode) -> ExpressionNode: - value: float = self.context.get_component_parameter_value( - node.component_id, node.name - ) - return LiteralNode(value) - - -def resolve_parameters( - expression: ExpressionNode, parameter_provider: ParameterValueProvider -) -> ExpressionNode: - return visit(expression, ParameterResolver(parameter_provider)) - - -def float_to_int(value: float) -> int: - if isinstance(value, int) or value.is_integer(): - return int(value) - else: - raise ValueError(f"{value} is not an integer.") - - -def evaluate_time_id(expr: ExpressionNode, value_provider: ValueProvider) -> int: - float_time_id = visit(expr, InstancesIndexVisitor(value_provider)) - try: - time_id = float_to_int(float_time_id) - except ValueError: - print(f"{expr} does not represent an integer time index.") - return time_id - - -def get_time_ids_from_instances_index( - instances_index: InstancesTimeIndex, value_provider: ValueProvider -) -> List[int]: - time_ids = [] - if isinstance(instances_index.expressions, list): # List[ExpressionNode] - for expr in instances_index.expressions: - time_ids.append(evaluate_time_id(expr, value_provider)) - - elif isinstance(instances_index.expressions, ExpressionRange): # ExpressionRange - start_id = evaluate_time_id(instances_index.expressions.start, value_provider) - stop_id = evaluate_time_id(instances_index.expressions.stop, value_provider) - step_id = 1 - if instances_index.expressions.step is not None: - step_id = evaluate_time_id(instances_index.expressions.step, value_provider) - # ExpressionRange includes stop_id whereas range excludes it - time_ids = list(range(start_id, stop_id + 1, step_id)) - - return time_ids diff --git a/src/andromede/expression/expression.py b/src/andromede/expression/expression.py deleted file mode 100644 index 01e8136b..00000000 --- a/src/andromede/expression/expression.py +++ /dev/null @@ -1,454 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -""" -Defines the model for generic expressions. -""" -import enum -import inspect -from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Union - -import andromede.expression.port_operator -import andromede.expression.scenario_operator -import andromede.expression.time_operator - - -class Instances(enum.Enum): - SIMPLE = "SIMPLE" - MULTIPLE = "MULTIPLE" - - -@dataclass(frozen=True) -class ExpressionNode: - """ - Base class for all nodes of the expression AST. - - Operators overloading is provided to help create expressions - programmatically. - - Examples - >>> expr = -var('x') + 5 / param('p') - """ - - instances: Instances = field(init=False, default=Instances.SIMPLE) - - def __neg__(self) -> "ExpressionNode": - return NegationNode(self) - - def __add__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: AdditionNode(self, x)) - - def __radd__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: AdditionNode(x, self)) - - def __sub__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: SubstractionNode(self, x)) - - def __rsub__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: SubstractionNode(x, self)) - - def __mul__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: MultiplicationNode(self, x)) - - def __rmul__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: MultiplicationNode(x, self)) - - def __truediv__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node(rhs, lambda x: DivisionNode(self, x)) - - def __rtruediv__(self, lhs: Any) -> "ExpressionNode": - return _apply_if_node(lhs, lambda x: DivisionNode(x, self)) - - def __le__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node( - rhs, lambda x: ComparisonNode(self, x, Comparator.LESS_THAN) - ) - - def __ge__(self, rhs: Any) -> "ExpressionNode": - return _apply_if_node( - rhs, lambda x: ComparisonNode(self, x, Comparator.GREATER_THAN) - ) - - def __eq__(self, rhs: Any) -> "ExpressionNode": # type: ignore - return _apply_if_node(rhs, lambda x: ComparisonNode(self, x, Comparator.EQUAL)) - - def sum(self) -> "ExpressionNode": - if isinstance(self, TimeOperatorNode): - return TimeAggregatorNode(self, "TimeSum", stay_roll=True) - else: - return _apply_if_node( - self, lambda x: TimeAggregatorNode(x, "TimeSum", stay_roll=False) - ) - - def sum_connections(self) -> "ExpressionNode": - if isinstance(self, PortFieldNode): - return PortFieldAggregatorNode(self, aggregator="PortSum") - raise ValueError( - f"sum_connections() applies only for PortFieldNode, whereas the current node is of type {type(self)}." - ) - - def shift( - self, - expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" - ], - ) -> "ExpressionNode": - return _apply_if_node( - self, - lambda x: TimeOperatorNode(x, "TimeShift", InstancesTimeIndex(expressions)), - ) - - def eval( - self, - expressions: Union[ - int, "ExpressionNode", List["ExpressionNode"], "ExpressionRange" - ], - ) -> "ExpressionNode": - return _apply_if_node( - self, - lambda x: TimeOperatorNode( - x, "TimeEvaluation", InstancesTimeIndex(expressions) - ), - ) - - def expec(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Expectation")) - - def variance(self) -> "ExpressionNode": - return _apply_if_node(self, lambda x: ScenarioOperatorNode(x, "Variance")) - - -def _wrap_in_node(obj: Any) -> ExpressionNode: - if isinstance(obj, ExpressionNode): - return obj - elif isinstance(obj, float) or isinstance(obj, int): - return LiteralNode(float(obj)) - raise TypeError(f"Unable to wrap {obj} into an expression node") - - -def _apply_if_node( - obj: Any, func: Callable[["ExpressionNode"], "ExpressionNode"] -) -> "ExpressionNode": - if as_node := _wrap_in_node(obj): - return func(as_node) - else: - return NotImplemented - - -@dataclass(frozen=True, eq=False) -class VariableNode(ExpressionNode): - name: str - - -def var(name: str) -> VariableNode: - return VariableNode(name) - - -@dataclass(frozen=True, eq=False) -class PortFieldNode(ExpressionNode): - """ - References a port field. - """ - - port_name: str - field_name: str - - -def port_field(port_name: str, field_name: str) -> PortFieldNode: - return PortFieldNode(port_name, field_name) - - -@dataclass(frozen=True, eq=False) -class ParameterNode(ExpressionNode): - name: str - - -def param(name: str) -> ParameterNode: - return ParameterNode(name) - - -@dataclass(frozen=True, eq=False) -class ComponentParameterNode(ExpressionNode): - """ - Represents one parameter of one component. - - When building actual equations for a system, - we need to associated each parameter to its - actual component, at some point. - """ - - component_id: str - name: str - - -def comp_param(component_id: str, name: str) -> ComponentParameterNode: - return ComponentParameterNode(component_id, name) - - -@dataclass(frozen=True, eq=False) -class ComponentVariableNode(ExpressionNode): - """ - Represents one variable of one component. - - When building actual equations for a system, - we need to associated each variable to its - actual component, at some point. - """ - - component_id: str - name: str - - -def comp_var(component_id: str, name: str) -> ComponentVariableNode: - return ComponentVariableNode(component_id, name) - - -@dataclass(frozen=True, eq=False) -class LiteralNode(ExpressionNode): - value: float - - -def literal(value: float) -> LiteralNode: - return LiteralNode(value) - - -@dataclass(frozen=True, eq=False) -class UnaryOperatorNode(ExpressionNode): - operand: ExpressionNode - - def __post_init__(self) -> None: - object.__setattr__(self, "instances", self.operand.instances) - - -@dataclass(frozen=True, eq=False) -class PortFieldAggregatorNode(UnaryOperatorNode): - aggregator: str - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.port_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.port_operator.PortAggregator) - ] - if self.aggregator not in valid_names: - raise NotImplementedError( - f"{self.aggregator} is not a valid port aggregator, valid port aggregators are {valid_names}" - ) - - -@dataclass(frozen=True, eq=False) -class NegationNode(UnaryOperatorNode): - pass - - -@dataclass(frozen=True, eq=False) -class BinaryOperatorNode(ExpressionNode): - left: ExpressionNode - right: ExpressionNode - - def __post_init__(self) -> None: - binary_operator_post_init(self, "apply binary operation with") - - -def binary_operator_post_init(node: BinaryOperatorNode, operation: str) -> None: - if node.left.instances != node.right.instances: - raise ValueError( - f"Cannot {operation} {node.left} and {node.right} as they do not have the same number of instances." - ) - else: - object.__setattr__(node, "instances", node.left.instances) - - -class Comparator(enum.Enum): - LESS_THAN = "LESS_THAN" - EQUAL = "EQUAL" - GREATER_THAN = "GREATER_THAN" - - -@dataclass(frozen=True, eq=False) -class ComparisonNode(BinaryOperatorNode): - comparator: Comparator - - def __post_init__(self) -> None: - binary_operator_post_init(self, "compare") - - -@dataclass(frozen=True, eq=False) -class AdditionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "add") - - -@dataclass(frozen=True, eq=False) -class SubstractionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "substract") - - -@dataclass(frozen=True, eq=False) -class MultiplicationNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "multiply") - - -@dataclass(frozen=True, eq=False) -class DivisionNode(BinaryOperatorNode): - def __post_init__(self) -> None: - binary_operator_post_init(self, "divide") - - -@dataclass(frozen=True, eq=False) -class ExpressionRange: - start: ExpressionNode - stop: ExpressionNode - step: Optional[ExpressionNode] = None - - def __post_init__(self) -> None: - for attribute in self.__dict__: - value = getattr(self, attribute) - object.__setattr__( - self, attribute, _wrap_in_node(value) if value is not None else value - ) - - -IntOrExpr = Union[int, ExpressionNode] - - -def expression_range( - start: IntOrExpr, stop: IntOrExpr, step: Optional[IntOrExpr] = None -) -> ExpressionRange: - return ExpressionRange( - start=_wrap_in_node(start), - stop=_wrap_in_node(stop), - step=None if step is None else _wrap_in_node(step), - ) - - -@dataclass -class InstancesTimeIndex: - """ - Defines a set of time indices on which a time operator operates. - - In particular, it defines time indices created by the shift operator. - - The actual indices can either be defined as a time range defined by - 2 expression, or as a list of expressions. - """ - - expressions: Union[List[ExpressionNode], ExpressionRange] - - def __init__( - self, - expressions: Union[int, ExpressionNode, List[ExpressionNode], ExpressionRange], - ) -> None: - if not isinstance(expressions, (int, ExpressionNode, list, ExpressionRange)): - raise TypeError( - f"{expressions} must be of type among {{int, ExpressionNode, List[ExpressionNode], ExpressionRange}}" - ) - if isinstance(expressions, list) and not all( - isinstance(x, ExpressionNode) for x in expressions - ): - raise TypeError( - f"All elements of {expressions} must be of type ExpressionNode" - ) - - if isinstance(expressions, (int, ExpressionNode)): - self.expressions = [_wrap_in_node(expressions)] - else: - self.expressions = expressions - - def is_simple(self) -> bool: - if isinstance(self.expressions, list): - return len(self.expressions) == 1 - else: - # TODO: We could also check that if a range only includes literal nodes, compute the length of the range, if it's one return True. This is more complicated, I do not know if we want to do this - return False - - -@dataclass(frozen=True, eq=False) -class TimeOperatorNode(UnaryOperatorNode): - name: str - instances_index: InstancesTimeIndex - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeOperator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" - ) - if self.operand.instances == Instances.SIMPLE: - if self.instances_index.is_simple(): - object.__setattr__(self, "instances", Instances.SIMPLE) - else: - object.__setattr__(self, "instances", Instances.MULTIPLE) - else: - raise ValueError( - "Cannot apply time operator on an expression that already represents multiple instances" - ) - - -@dataclass(frozen=True, eq=False) -class TimeAggregatorNode(UnaryOperatorNode): - name: str - stay_roll: bool - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.time_operator, inspect.isclass - ) - if issubclass(cls, andromede.expression.time_operator.TimeAggregator) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid time aggregator, valid time aggregators are {valid_names}" - ) - object.__setattr__(self, "instances", Instances.SIMPLE) - - -@dataclass(frozen=True, eq=False) -class ScenarioOperatorNode(UnaryOperatorNode): - name: str - - def __post_init__(self) -> None: - valid_names = [ - cls.__name__ - for _, cls in inspect.getmembers( - andromede.expression.scenario_operator, inspect.isclass - ) - if issubclass( - cls, andromede.expression.scenario_operator.ScenarioAggregator - ) - ] - if self.name not in valid_names: - raise ValueError( - f"{self.name} is not a valid scenario operator, valid scenario operators are {valid_names}" - ) - object.__setattr__(self, "instances", Instances.SIMPLE) - - -def sum_expressions(expressions: Sequence[ExpressionNode]) -> ExpressionNode: - if len(expressions) == 0: - return LiteralNode(0) - if len(expressions) == 1: - return expressions[0] - return expressions[0] + sum_expressions(expressions[1:]) diff --git a/src/andromede/expression/indexing.py b/src/andromede/expression/indexing.py index 102f4c45..73d43ff5 100644 --- a/src/andromede/expression/indexing.py +++ b/src/andromede/expression/indexing.py @@ -11,32 +11,9 @@ # This file is part of the Antares project. from abc import ABC, abstractmethod -from dataclasses import dataclass -import andromede.expression.time_operator from andromede.expression.indexing_structure import IndexingStructure -from .expression import ( - AdditionNode, - ComparisonNode, - ComponentParameterNode, - ComponentVariableNode, - DivisionNode, - ExpressionNode, - LiteralNode, - MultiplicationNode, - NegationNode, - ParameterNode, - PortFieldAggregatorNode, - PortFieldNode, - ScenarioOperatorNode, - SubstractionNode, - TimeAggregatorNode, - TimeOperatorNode, - VariableNode, -) -from .visitor import ExpressionVisitor, T, visit - class IndexingStructureProvider(ABC): @abstractmethod @@ -58,85 +35,3 @@ def get_component_parameter_structure( self, component_id: str, name: str ) -> IndexingStructure: ... - - -@dataclass(frozen=True) -class TimeScenarioIndexingVisitor(ExpressionVisitor[IndexingStructure]): - """ - Determines if the expression represents a single expression or an expression that should be instantiated for all time steps. - """ - - context: IndexingStructureProvider - - def literal(self, node: LiteralNode) -> IndexingStructure: - return IndexingStructure(False, False) - - def negation(self, node: NegationNode) -> IndexingStructure: - return visit(node.operand, self) - - def addition(self, node: AdditionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def substraction(self, node: SubstractionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def multiplication(self, node: MultiplicationNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def division(self, node: DivisionNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - def comparison(self, node: ComparisonNode) -> IndexingStructure: - return visit(node.left, self) | visit(node.right, self) - - # def variable(self, node: VariableNode) -> IndexingStructure: - # time = self.context.get_variable_structure(node.name).time == True - # scenario = self.context.get_variable_structure(node.name).scenario == True - # return IndexingStructure(time, scenario) - - def parameter(self, node: ParameterNode) -> IndexingStructure: - time = self.context.get_parameter_structure(node.name).time == True - scenario = self.context.get_parameter_structure(node.name).scenario == True - return IndexingStructure(time, scenario) - - # def comp_variable(self, node: ComponentVariableNode) -> IndexingStructure: - # return self.context.get_component_variable_structure( - # node.component_id, node.name - # ) - - def comp_parameter(self, node: ComponentParameterNode) -> IndexingStructure: - return self.context.get_component_parameter_structure( - node.component_id, node.name - ) - - def time_operator(self, node: TimeOperatorNode) -> IndexingStructure: - time_operator_cls = getattr(andromede.expression.time_operator, node.name) - if time_operator_cls.rolling(): - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def time_aggregator(self, node: TimeAggregatorNode) -> IndexingStructure: - if node.stay_roll: - return visit(node.operand, self) - else: - return IndexingStructure(False, visit(node.operand, self).scenario) - - def scenario_operator(self, node: ScenarioOperatorNode) -> IndexingStructure: - return IndexingStructure(visit(node.operand, self).time, False) - - def port_field(self, node: PortFieldNode) -> IndexingStructure: - raise ValueError( - "Port fields must be resolved before computing indexing structure." - ) - - def port_field_aggregator(self, node: PortFieldAggregatorNode) -> IndexingStructure: - raise ValueError( - "Port fields aggregators must be resolved before computing indexing structure." - ) - - -def compute_indexation( - expression: ExpressionNode, provider: IndexingStructureProvider -) -> IndexingStructure: - return visit(expression, TimeScenarioIndexingVisitor(provider)) diff --git a/src/andromede/model/model.py b/src/andromede/model/model.py index 6a856e05..e151825c 100644 --- a/src/andromede/model/model.py +++ b/src/andromede/model/model.py @@ -19,16 +19,6 @@ from dataclasses import dataclass, field from typing import Dict, Iterable, Optional -# from andromede.expression.expression import ( -# BinaryOperatorNode, -# ComponentParameterNode, -# ComponentVariableNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# ) from andromede.expression.expression_efficient import ( AdditionNode, BinaryOperatorNode, @@ -46,7 +36,7 @@ TimeAggregatorNode, TimeOperatorNode, ) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation +from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, @@ -61,38 +51,6 @@ from andromede.model.port import PortType from andromede.model.variable import Variable -# from andromede.expression import ( -# AdditionNode, -# ComparisonNode, -# DivisionNode, -# ExpressionNode, -# ExpressionVisitor, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# SubstractionNode, -# VariableNode, -# ) -# from andromede.expression.expression_efficient import ( -# AdditionNode, -# BinaryOperatorNode, -# ComparisonNode, -# ComponentParameterNode, -# DivisionNode, -# ExpressionNodeEfficient, -# LiteralNode, -# MultiplicationNode, -# NegationNode, -# ParameterNode, -# PortFieldAggregatorNode, -# PortFieldNode, -# ScenarioOperatorNode, -# SubstractionNode, -# TimeAggregatorNode, -# TimeOperatorNode, -# ) - # TODO: Introduce bool_variable ? def _make_structure_provider(model: "Model") -> IndexingStructureProvider: diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 7538b9e7..58e505d1 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -17,7 +17,6 @@ import math from dataclasses import dataclass -from typing import List, Optional import ortools.linear_solver.pywraplp as lp @@ -28,19 +27,16 @@ RowIndex, ) from andromede.expression.resolved_linear_expression import ResolvedLinearExpression -from andromede.expression.scenario_operator import Expectation -from andromede.expression.time_operator import TimeEvaluation, TimeShift, TimeSum from andromede.model.common import ValueType from andromede.model.constraint import Constraint from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import Term from andromede.simulation.linear_expression_resolver import LinearExpressionResolver from andromede.simulation.optimization_context import ( BlockBorderManagement, ComponentContext, OptimizationContext, - _make_data_structure_provider, - _make_value_provider, + make_data_structure_provider, + make_value_provider, ) from andromede.simulation.strategy import MergedProblemStrategy, ModelSelectionStrategy from andromede.simulation.time_block import TimeBlock @@ -61,7 +57,7 @@ def _get_indexing( def _compute_indexing_structure( context: ComponentContext, constraint: Constraint ) -> IndexingStructure: - data_structure_provider = _make_data_structure_provider( + data_structure_provider = make_data_structure_provider( context.opt_context.network, context.component ) constraint_indexing = _get_indexing(constraint, data_structure_provider) @@ -101,7 +97,7 @@ def _create_constraint( # instances_per_time_step = linear_expr.number_of_instances() # instances_per_time_step = 1 - value_provider = _make_value_provider(context.opt_context, context.component) + value_provider = make_value_provider(context.opt_context, context.component) expression_resolver = LinearExpressionResolver(context.opt_context, value_provider) for block_timestep in context.opt_context.get_time_indices(constraint_indexing): @@ -140,7 +136,7 @@ def _create_objective( ) # We have already checked in the model creation that the objective contribution is neither indexed by time nor by scenario - value_provider = _make_value_provider(opt_context, component) + value_provider = make_value_provider(opt_context, component) expression_resolver = LinearExpressionResolver(opt_context, value_provider) resolved_expr = expression_resolver.resolve(instantiated_expr, RowIndex(0, 0)) @@ -255,7 +251,7 @@ def _create_variables(self) -> None: component_context = self.context.get_component_context(component) model = component.model - value_provider = _make_value_provider(self.context, component) + value_provider = make_value_provider(self.context, component) expression_resolver = LinearExpressionResolver(self.context, value_provider) for model_var in self.strategy.get_variables(model): diff --git a/src/andromede/simulation/optimization_context.py b/src/andromede/simulation/optimization_context.py index 03b11ccc..85be98be 100644 --- a/src/andromede/simulation/optimization_context.py +++ b/src/andromede/simulation/optimization_context.py @@ -17,7 +17,6 @@ import ortools.linear_solver.pywraplp as lp -from andromede.expression import ParameterValueProvider, resolve_parameters from andromede.expression.evaluate_parameters_efficient import ValueProvider from andromede.expression.indexing import IndexingStructureProvider from andromede.expression.indexing_structure import IndexingStructure @@ -27,8 +26,6 @@ PortFieldKey, ) from andromede.expression.value_provider import TimeScenarioIndex, TimeScenarioIndices -from andromede.simulation.linear_expression import LinearExpression -from andromede.simulation.linearize import linearize_expression from andromede.simulation.time_block import TimeBlock from andromede.study.data import DataBase from andromede.study.network import Component, Network @@ -218,7 +215,7 @@ def _get_parameter_value( return data.get_value(absolute_timestep, scenario) -def _make_value_provider( +def make_value_provider( context: "OptimizationContext", component: Component, ) -> ValueProvider: @@ -306,39 +303,13 @@ class ExpressionTimestepValueProvider(TimestepValueProvider): # OptimizationContext has knowledge of the block, so that get_value only needs block_timestep and scenario to get the correct data value def get_value(self, block_timestep: int, scenario: int) -> float: - param_value_provider = _make_value_provider( + param_value_provider = make_value_provider( self.context, block_timestep, scenario, self.component ) return self.expression.evaluate(param_value_provider) -def _make_parameter_value_provider( - context: "OptimizationContext", - block_timestep: int, - scenario: int, -) -> ParameterValueProvider: - """ - A value provider which takes its values from - the parameter values as defined in the network data. - - Cannot evaluate expressions which contain variables. - """ - - class Provider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return _get_parameter_value( - context, block_timestep, scenario, component_id, name - ) - - def get_parameter_value(self, name: str) -> float: - raise ValueError( - "Parameters should have been associated with their component before resolution." - ) - - return Provider() - - -def _make_data_structure_provider( +def make_data_structure_provider( network: Network, component: Component ) -> IndexingStructureProvider: """ @@ -406,26 +377,6 @@ def get_variable( self.component.model.variables[variable_name].structure, ) - def linearize_expression( - self, - block_timestep: int, - scenario: int, - expression: LinearExpressionEfficient, - ) -> LinearExpression: - parameters_valued_provider = _make_parameter_value_provider( - self.opt_context, block_timestep, scenario - ) - evaluated_expr = resolve_parameters(expression, parameters_valued_provider) - - value_provider = _make_value_provider( - self.opt_context, block_timestep, scenario, self.component - ) - structure_provider = _make_data_structure_provider( - self.opt_context.network, self.component - ) - - return linearize_expression(evaluated_expr, structure_provider, value_provider) - def _get_data_time_key(block_timestep: int, data_indexing: IndexingStructure) -> int: return block_timestep if data_indexing.time else 0 diff --git a/tests/functional/test_performance.py b/tests/functional/test_performance.py deleted file mode 100644 index 1c50af1c..00000000 --- a/tests/functional/test_performance.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from typing import cast - -import pytest - -from andromede.expression.expression import ExpressionNode, literal, param, var -from andromede.expression.indexing_structure import IndexingStructure -from andromede.libs.standard import ( - BALANCE_PORT_TYPE, - DEMAND_MODEL, - GENERATOR_MODEL, - GENERATOR_MODEL_WITH_STORAGE, - NODE_BALANCE_MODEL, -) -from andromede.model import float_parameter, float_variable, model -from andromede.simulation import TimeBlock, build_problem -from andromede.study import ( - ConstantData, - DataBase, - Network, - Node, - PortRef, - create_component, -) -from tests.unittests.test_utils import generate_scalar_matrix_data - - -def test_large_sum_inside_model_with_loop() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop inside the model - - This test pass with 476 terms but fails with 477 locally due to recursion depth, - and even less terms are possible with Jenkins... - """ - nb_terms = 500 - - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - - for i in range(1, nb_terms): - database.add_data("simple_cost", f"cost_{i}", ConstantData(1 / i)) - - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter(f"cost_{i}", IndexingStructure(False, False)) - for i in range(1, nb_terms) - ], - objective_operational_contribution=cast( - ExpressionNode, sum(param(f"cost_{i}") for i in range(1, nb_terms)) - ), - ) - - # Won't run because last statement will raise the error - network = Network("test") - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == sum( - [1 / i for i in range(1, nb_terms)] - ) - - -def test_large_sum_outside_model_with_loop() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms on a for-loop outside the model - """ - nb_terms = 10_000 - - time_blocks = [TimeBlock(0, [0])] - scenarios = 1 - database = DataBase() - - obj_coeff = sum([1 / i for i in range(1, nb_terms)]) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[], - objective_operational_contribution=literal(obj_coeff), - ) - - network = Network("test") - - simple_model = create_component( - model=SIMPLE_COST_MODEL, - id="simple_cost", - ) - network.add_component(simple_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == obj_coeff - - -def test_large_sum_inside_model_with_sum_operator() -> None: - """ - Test performance when the problem involves an expression with a high number of terms. - Here the objective function is the sum over nb_terms terms with the sum() operator inside the model - """ - nb_terms = 10_000 - - scenarios = 1 - time_blocks = [TimeBlock(0, list(range(nb_terms)))] - database = DataBase() - - # Weird values when the "cost" varies over time and we use the sum() operator: - # For testing purposes, will use a const value since the problem seems to come when - # we try to linearize nb_terms variables with nb_terms distinct parameters - # TODO check the sum() operator for time-variable parameters - database.add_data("simple_cost", "cost", ConstantData(3)) - - SIMPLE_COST_MODEL = model( - id="SIMPLE_COST", - parameters=[ - float_parameter("cost", IndexingStructure(False, False)), - ], - variables=[ - float_variable( - "var", - lower_bound=literal(1), - upper_bound=literal(2), - structure=IndexingStructure(True, False), - ), - ], - objective_operational_contribution=(param("cost") * var("var")).sum(), - ) - - network = Network("test") - - cost_model = create_component(model=SIMPLE_COST_MODEL, id="simple_cost") - network.add_component(cost_model) - - problem = build_problem(network, database, time_blocks[0], scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 3 * nb_terms - - -def test_large_sum_of_port_connections() -> None: - """ - Test performance when the problem involves a model where several generators are connected to a node. - - This test pass with 470 terms but fails with 471 locally due to recursion depth, - and possibly even less terms are possible with Jenkins... - """ - nb_generators = 500 - - time_block = TimeBlock(0, [0]) - scenarios = 1 - - database = DataBase() - database.add_data("D", "demand", ConstantData(nb_generators)) - - for gen_id in range(nb_generators): - database.add_data(f"G_{gen_id}", "p_max", ConstantData(1)) - database.add_data(f"G_{gen_id}", "cost", ConstantData(5)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - generators = [ - create_component(model=GENERATOR_MODEL, id=f"G_{gen_id}") - for gen_id in range(nb_generators) - ] - - network = Network("test") - network.add_node(node) - - network.add_component(demand) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - - for gen_id in range(nb_generators): - network.add_component(generators[gen_id]) - network.connect( - PortRef(generators[gen_id], "balance_port"), PortRef(node, "balance_port") - ) - - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - problem = build_problem(network, database, time_block, scenarios) - - # Won't run because last statement will raise the error - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 5 * nb_generators - - -def test_basic_balance_on_whole_year() -> None: - """ - Balance on one node with one fixed demand and one generation, on 8760 timestep. - """ - - scenarios = 1 - horizon = 8760 - time_block = TimeBlock(1, list(range(horizon))) - - database = DataBase() - database.add_data( - "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) - ) - - database.add_data("G", "p_max", ConstantData(100)) - database.add_data("G", "cost", ConstantData(30)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - - gen = create_component(model=GENERATOR_MODEL, id="G") - - network = Network("test") - network.add_node(node) - network.add_component(demand) - network.add_component(gen) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) - - problem = build_problem(network, database, time_block, scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 30 * 100 * horizon - - -def test_basic_balance_on_whole_year_with_large_sum() -> None: - """ - Balance on one node with one fixed demand and one generation with storage, on 8760 timestep. - """ - - scenarios = 1 - horizon = 8760 - time_block = TimeBlock(1, list(range(horizon))) - - database = DataBase() - database.add_data( - "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) - ) - - database.add_data("G", "p_max", ConstantData(100)) - database.add_data("G", "cost", ConstantData(30)) - database.add_data("G", "full_storage", ConstantData(100 * horizon)) - - node = Node(model=NODE_BALANCE_MODEL, id="N") - demand = create_component(model=DEMAND_MODEL, id="D") - gen = create_component( - model=GENERATOR_MODEL_WITH_STORAGE, id="G" - ) # Limits the total generation inside a TimeBlock - - network = Network("test") - network.add_node(node) - network.add_component(demand) - network.add_component(gen) - network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) - network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) - - problem = build_problem(network, database, time_block, scenarios) - status = problem.solver.Solve() - - assert status == problem.solver.OPTIMAL - assert problem.solver.Objective().Value() == 30 * 100 * horizon diff --git a/tests/functional/test_performance_efficient.py b/tests/functional/test_performance_efficient.py index 55ef9913..ed00bd02 100644 --- a/tests/functional/test_performance_efficient.py +++ b/tests/functional/test_performance_efficient.py @@ -21,6 +21,17 @@ var, wrap_in_linear_expr, ) +from andromede.libs.standard import ( + DEMAND_MODEL, + GENERATOR_MODEL, + GENERATOR_MODEL_WITH_STORAGE, + NODE_BALANCE_MODEL, +) +from andromede.simulation.optimization import build_problem +from andromede.simulation.time_block import TimeBlock +from andromede.study.data import ConstantData, DataBase +from andromede.study.network import Network, Node, PortRef, create_component +from tests.unittests.test_utils import generate_scalar_matrix_data def test_large_number_of_parameters_sum() -> None: @@ -84,3 +95,124 @@ def test_large_number_of_variables_sum() -> None: assert expr.evaluate( EvaluationContext(variables=variables_value), RowIndex(0, 0) ) == sum(1 / i for i in range(1, nb_terms)) + + +def test_large_sum_of_port_connections() -> None: + """ + Test performance when the problem involves a model where several generators are connected to a node. + + This test pass with 470 terms but fails with 471 locally due to recursion depth, + and possibly even less terms are possible with Jenkins... + """ + nb_generators = 500 + + time_block = TimeBlock(0, [0]) + scenarios = 1 + + database = DataBase() + database.add_data("D", "demand", ConstantData(nb_generators)) + + for gen_id in range(nb_generators): + database.add_data(f"G_{gen_id}", "p_max", ConstantData(1)) + database.add_data(f"G_{gen_id}", "cost", ConstantData(5)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + generators = [ + create_component(model=GENERATOR_MODEL, id=f"G_{gen_id}") + for gen_id in range(nb_generators) + ] + + network = Network("test") + network.add_node(node) + + network.add_component(demand) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + + for gen_id in range(nb_generators): + network.add_component(generators[gen_id]) + network.connect( + PortRef(generators[gen_id], "balance_port"), PortRef(node, "balance_port") + ) + + # Raised recursion error with previous implementation + problem = build_problem(network, database, time_block, scenarios) + + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 5 * nb_generators + + +def test_basic_balance_on_whole_year() -> None: + """ + Balance on one node with one fixed demand and one generation, on 8760 timestep. + """ + + scenarios = 1 + horizon = 8760 + time_block = TimeBlock(1, list(range(horizon))) + + database = DataBase() + database.add_data( + "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) + ) + + database.add_data("G", "p_max", ConstantData(100)) + database.add_data("G", "cost", ConstantData(30)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + + gen = create_component(model=GENERATOR_MODEL, id="G") + + network = Network("test") + network.add_node(node) + network.add_component(demand) + network.add_component(gen) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) + + problem = build_problem(network, database, time_block, scenarios) + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 30 * 100 * horizon + + +def test_basic_balance_on_whole_year_with_large_sum() -> None: + """ + Balance on one node with one fixed demand and one generation with storage, on 8760 timestep. + """ + + scenarios = 1 + horizon = 8760 + time_block = TimeBlock(1, list(range(horizon))) + + database = DataBase() + database.add_data( + "D", "demand", generate_scalar_matrix_data(100, horizon, scenarios) + ) + + database.add_data("G", "p_max", ConstantData(100)) + database.add_data("G", "cost", ConstantData(30)) + database.add_data("G", "full_storage", ConstantData(100 * horizon)) + + node = Node(model=NODE_BALANCE_MODEL, id="N") + demand = create_component(model=DEMAND_MODEL, id="D") + gen = create_component( + model=GENERATOR_MODEL_WITH_STORAGE, id="G" + ) # Limits the total generation inside a TimeBlock + + network = Network("test") + network.add_node(node) + network.add_component(demand) + network.add_component(gen) + network.connect(PortRef(demand, "balance_port"), PortRef(node, "balance_port")) + network.connect(PortRef(gen, "balance_port"), PortRef(node, "balance_port")) + + problem = build_problem(network, database, time_block, scenarios) + status = problem.solver.Solve() + + assert status == problem.solver.OPTIMAL + assert problem.solver.Objective().Value() == 30 * 100 * horizon diff --git a/tests/unittests/expressions/test_expressions.py b/tests/unittests/expressions/test_expressions.py deleted file mode 100644 index 9c415c18..00000000 --- a/tests/unittests/expressions/test_expressions.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) 2024, RTE (https://www.rte-france.com) -# -# See AUTHORS.txt -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# SPDX-License-Identifier: MPL-2.0 -# -# This file is part of the Antares project. - -from dataclasses import dataclass, field -from typing import Dict - -import pandas as pd -import pytest - -from andromede.expression import ( - AdditionNode, - DivisionNode, - EvaluationContext, - EvaluationVisitor, - ExpressionDegreeVisitor, - ExpressionNode, - LiteralNode, - ParameterNode, - ParameterValueProvider, - PrinterVisitor, - ValueProvider, - VariableNode, - literal, - param, - resolve_parameters, - sum_expressions, - var, - visit, -) -from andromede.expression.equality import expressions_equal -from andromede.expression.expression import ( - ComponentParameterNode, - ComponentVariableNode, - ExpressionRange, - Instances, - comp_param, - comp_var, - port_field, -) -from andromede.expression.indexing import IndexingStructureProvider, compute_indexation -from andromede.expression.indexing_structure import IndexingStructure -from andromede.model.model import PortFieldId -from andromede.simulation.linear_expression import LinearExpression, Term -from andromede.simulation.linearize import linearize_expression - - -@dataclass(frozen=True) -class ComponentValueKey: - component_id: str - variable_name: str - - -def comp_key(component_id: str, variable_name: str) -> ComponentValueKey: - return ComponentValueKey(component_id, variable_name) - - -@dataclass(frozen=True) -class ComponentEvaluationContext(ValueProvider): - """ - Simple value provider relying on dictionaries. - Does not support component variables/parameters. - """ - - variables: Dict[ComponentValueKey, float] = field(default_factory=dict) - parameters: Dict[ComponentValueKey, float] = field(default_factory=dict) - - def get_variable_value(self, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - raise NotImplementedError() - - def get_component_variable_value(self, component_id: str, name: str) -> float: - return self.variables[comp_key(component_id, name)] - - def get_component_parameter_value(self, component_id: str, name: str) -> float: - return self.parameters[comp_key(component_id, name)] - - def parameter_is_constant_over_time(self, name: str) -> bool: - raise NotImplementedError() - - -def test_comp_parameter() -> None: - add_node = AdditionNode(LiteralNode(1), ComponentVariableNode("comp1", "x")) - expr = DivisionNode(add_node, ComponentParameterNode("comp1", "p")) - - assert visit(expr, PrinterVisitor()) == "((1 + comp1.x) / comp1.p)" - - context = ComponentEvaluationContext( - variables={comp_key("comp1", "x"): 3}, parameters={comp_key("comp1", "p"): 4} - ) - assert visit(expr, EvaluationVisitor(context)) == 1 - - -def test_ast() -> None: - add_node = AdditionNode(LiteralNode(1), VariableNode("x")) - expr = DivisionNode(add_node, ParameterNode("p")) - - assert visit(expr, PrinterVisitor()) == "((1 + x) / p)" - - context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == 1 - - -def test_operators() -> None: - x = var("x") - p = param("p") - expr: ExpressionNode = (5 * x + 3) / p - 2 - - assert visit(expr, PrinterVisitor()) == "((((5.0 * x) + 3.0) / p) - 2.0)" - - context = EvaluationContext(variables={"x": 3}, parameters={"p": 4}) - assert visit(expr, EvaluationVisitor(context)) == pytest.approx(2.5, 1e-16) - - assert visit(-expr, EvaluationVisitor(context)) == pytest.approx(-2.5, 1e-16) - - -def test_degree() -> None: - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - - assert visit(expr, ExpressionDegreeVisitor()) == 1 - - expr = x * expr - assert visit(expr, ExpressionDegreeVisitor()) == 2 - - -@pytest.mark.xfail(reason="Degree simplification not implemented") -def test_degree_computation_should_take_into_account_simplifications() -> None: - x = var("x") - expr = x - x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - expr = LiteralNode(0) * x - assert visit(expr, ExpressionDegreeVisitor()) == 0 - - -def test_parameters_resolution() -> None: - class TestParamProvider(ParameterValueProvider): - def get_component_parameter_value(self, component_id: str, name: str) -> float: - raise NotImplementedError() - - def get_parameter_value(self, name: str) -> float: - return 2 - - x = var("x") - p = param("p") - expr = (5 * x + 3) / p - assert resolve_parameters(expr, TestParamProvider()) == (5 * x + 3) / 2 - - -def test_linearization() -> None: - x = comp_var("c", "x") - expr = (5 * x + 3) / 2 - provider = StructureProvider() - - assert linearize_expression(expr, provider) == LinearExpression( - [Term(2.5, "c", "x")], 1.5 - ) - - with pytest.raises(ValueError): - linearize_expression(param("p") * x, provider) - - -def test_linearization_of_non_linear_expressions_should_raise_value_error() -> None: - x = var("x") - expr = x.variance() - - provider = StructureProvider() - with pytest.raises(ValueError) as exc: - linearize_expression(expr, provider) - assert ( - str(exc.value) - == "Cannot linearize expression with a non-linear operator: Variance" - ) - - -def test_comparison() -> None: - x = var("x") - p = param("p") - expr: ExpressionNode = (5 * x + 3) >= p - 2 - - assert visit(expr, PrinterVisitor()) == "((5.0 * x) + 3.0) >= (p - 2.0)" - - -class StructureProvider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - def get_variable_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - -def test_shift() -> None: - x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))) - - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.MULTIPLE - - -def test_shifting_sum() -> None: - x = var("x") - expr = x.shift(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - assert expr.instances == Instances.SIMPLE - - -def test_eval() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))) - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.MULTIPLE - - -def test_eval_sum() -> None: - x = var("x") - expr = x.eval(ExpressionRange(literal(1), literal(4))).sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE - - -def test_sum_over_whole_block() -> None: - x = var("x") - expr = x.sum() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(False, True) - assert expr.instances == Instances.SIMPLE - - -def test_forbidden_composition_should_raise_value_error() -> None: - x = var("x") - with pytest.raises(ValueError): - _ = x.shift(ExpressionRange(literal(1), literal(4))) + var("y") - - -def test_expectation() -> None: - x = var("x") - expr = x.expec() - provider = StructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, False) - assert expr.instances == Instances.SIMPLE - - -def test_indexing_structure_comparison() -> None: - free = IndexingStructure(True, True) - constant = IndexingStructure(False, False) - assert free | constant == IndexingStructure(True, True) - - -def test_multiplication_of_differently_indexed_terms() -> None: - x = var("x") - p = param("p") - expr = p * x - - class CustomStructureProvider(IndexingStructureProvider): - def get_component_variable_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - raise NotImplementedError() - - def get_component_parameter_structure( - self, component_id: str, name: str - ) -> IndexingStructure: - raise NotImplementedError() - - def get_parameter_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(False, False) - - def get_variable_structure(self, name: str) -> IndexingStructure: - return IndexingStructure(True, True) - - provider = CustomStructureProvider() - - assert compute_indexation(expr, provider) == IndexingStructure(True, True) - - -def test_sum_expressions() -> None: - assert expressions_equal(sum_expressions([]), literal(0)) - assert expressions_equal(sum_expressions([literal(1)]), literal(1)) - assert expressions_equal(sum_expressions([literal(1), var("x")]), 1 + var("x")) - assert expressions_equal( - sum_expressions([literal(1), var("x"), param("p")]), 1 + (var("x") + param("p")) - )