From fd660d909fa69f1bd1e85e374c4d20bfcbcdb9be Mon Sep 17 00:00:00 2001 From: Thomas Bittar Date: Wed, 21 Aug 2024 17:36:03 +0200 Subject: [PATCH] Fix shift and eval parsing (temporary) --- .../expression/expression_efficient.py | 33 ++++++++++++++- .../expression/linear_expression_efficient.py | 27 ++++++++---- .../expression/parsing/parse_expression.py | 4 +- .../parsing/test_expression_parsing.py | 41 +++++++++++++------ 4 files changed, 81 insertions(+), 24 deletions(-) diff --git a/src/andromede/expression/expression_efficient.py b/src/andromede/expression/expression_efficient.py index a0fd86e..29bbe87 100644 --- a/src/andromede/expression/expression_efficient.py +++ b/src/andromede/expression/expression_efficient.py @@ -444,7 +444,7 @@ class DivisionNode(BinaryOperatorNode): pass -@dataclass(frozen=True, eq=False) +@dataclass(frozen=True) class ExpressionRange: start: ExpressionNodeEfficient stop: ExpressionNodeEfficient @@ -457,6 +457,14 @@ def __post_init__(self) -> None: self, attribute, wrap_in_node(value) if value is not None else value ) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, ExpressionRange) + and expressions_equal(self.start, other.start) + and expressions_equal(self.stop, other.stop) + and expressions_equal_if_present(self.step, other.step) + ) + IntOrExpr = Union[int, ExpressionNodeEfficient] @@ -515,6 +523,29 @@ def __hash__(self) -> int: else: return hash(self.expressions) + def __eq__(self, other: Any) -> bool: + if isinstance(other, InstancesTimeIndex): + if isinstance(self.expressions, list) and all( + isinstance(x, ExpressionNodeEfficient) for x in self.expressions + ): + return ( + isinstance(other.expressions, list) + and all( + isinstance(x, ExpressionNodeEfficient) + for x in other.expressions + ) + and all( + expressions_equal(left_expr, right_expr) + for left_expr, right_expr in zip( + self.expressions, other.expressions + ) + ) + ) + elif isinstance(self.expressions, ExpressionRange): + return self.expressions == other.expressions + else: + return False + def is_simple(self) -> bool: if isinstance(self.expressions, list): return len(self.expressions) == 1 diff --git a/src/andromede/expression/linear_expression_efficient.py b/src/andromede/expression/linear_expression_efficient.py index 13444d8..87766bc 100644 --- a/src/andromede/expression/linear_expression_efficient.py +++ b/src/andromede/expression/linear_expression_efficient.py @@ -29,14 +29,10 @@ overload, ) -import ortools.linear_solver.pywraplp as lp - from andromede.expression.context_adder import add_component_context from andromede.expression.equality import expressions_equal -from andromede.expression.evaluate import evaluate from andromede.expression.evaluate_parameters_efficient import ( check_resolved_expr, - get_time_ids_from_instances_index, resolve_coefficient, ) from andromede.expression.expression_efficient import ( @@ -61,10 +57,6 @@ from andromede.expression.indexing_structure import IndexingStructure, RowIndex from andromede.expression.port_operator import PortAggregator, PortSum from andromede.expression.print import print_expr -from andromede.expression.resolved_linear_expression import ( - ResolvedLinearExpression, - ResolvedTerm, -) from andromede.expression.scenario_operator import Expectation, ScenarioAggregator from andromede.expression.time_operator import ( TimeAggregator, @@ -92,6 +84,17 @@ class TermKeyEfficient: time_aggregator: Optional[TimeAggregator] scenario_aggregator: Optional[ScenarioAggregator] + # Used for test_expression_parsing + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, TermKeyEfficient) + and self.component_id == other.component_id + and self.variable_name == other.variable_name + and self.time_operator == other.time_operator + and self.time_aggregator == other.time_aggregator + and self.scenario_aggregator == other.scenario_aggregator + ) + @dataclass(frozen=True) class TermEfficient: @@ -1065,6 +1068,14 @@ def __post_init__( def __str__(self) -> str: return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}" + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, StandaloneConstraint) + and linear_expressions_equal_if_present(self.expression, other.expression) + and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound) + and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound) + ) + def wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient: if isinstance(obj, LinearExpressionEfficient): diff --git a/src/andromede/expression/parsing/parse_expression.py b/src/andromede/expression/parsing/parse_expression.py index e694144..bb20c66 100644 --- a/src/andromede/expression/parsing/parse_expression.py +++ b/src/andromede/expression/parsing/parse_expression.py @@ -168,7 +168,7 @@ def visitTimeShift( # specifics for x[t] ... if len(time_shifts) == 1 and expressions_equal(time_shifts[0], literal(0)): return shifted_expr - return shifted_expr.shift(time_shifts) + return shifted_expr.sum(shift=time_shifts) def visitTimeShiftRange( self, ctx: ExprParser.TimeShiftRangeContext @@ -176,7 +176,7 @@ def visitTimeShiftRange( shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore shift1 = ctx.shift1.accept(self) # type: ignore shift2 = ctx.shift2.accept(self) # type: ignore - return shifted_expr.shift(ExpressionRange(shift1, shift2)) + return shifted_expr.sum(shift=ExpressionRange(shift1, shift2)) # Visit a parse tree produced by ExprParser#function. def visitFunction( diff --git a/tests/unittests/expressions/parsing/test_expression_parsing.py b/tests/unittests/expressions/parsing/test_expression_parsing.py index 7992241..cec5b36 100644 --- a/tests/unittests/expressions/parsing/test_expression_parsing.py +++ b/tests/unittests/expressions/parsing/test_expression_parsing.py @@ -9,14 +9,21 @@ # SPDX-License-Identifier: MPL-2.0 # # This file is part of the Antares project. -from typing import Set +from typing import Set, Union import pytest from andromede.expression.equality import expressions_equal -from andromede.expression.expression_efficient import ExpressionRange, literal, param +from andromede.expression.expression_efficient import ( + ExpressionNodeEfficient, + ExpressionRange, + literal, + param, +) from andromede.expression.linear_expression_efficient import ( LinearExpressionEfficient, + StandaloneConstraint, + linear_expressions_equal, port_field, var, ) @@ -59,14 +66,14 @@ ( {"x"}, {}, - "x[t-1, t+4]", + "x[t-1, t+4]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=[-literal(1), literal(4)]), ), ( {"x"}, {}, "x[t-1+1]", - var("x").sum(shift=-literal(1) + literal(1)), + var("x"), # Simplifications are applied very early in parsing !!!! ), ( {"x"}, @@ -95,25 +102,25 @@ ( {"x"}, {}, - "x[t-1, t, t+4]", + "x[t-1, t, t+4]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=[-literal(1), literal(0), literal(4)]), ), ( {"x"}, {}, - "x[t-1..t+5]", + "x[t-1..t+5]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(-literal(1), literal(5))), ), ( {"x"}, {}, - "x[t-1..t]", + "x[t-1..t]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(-literal(1), literal(0))), ), ( {"x"}, {}, - "x[t..t+5]", + "x[t..t+5]", # TODO: Should raise ValueError: shift always with sum var("x").sum(shift=ExpressionRange(literal(0), literal(5))), ), ({"x"}, {}, "x[t]", var("x")), @@ -122,7 +129,7 @@ {"x"}, {}, "sum(x[-1..5])", - var("x").sum(eval=ExpressionRange(-literal(1), literal(5))).sum(), + var("x").sum(eval=ExpressionRange(-literal(1), literal(5))), ), ({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()), ( @@ -156,13 +163,21 @@ def test_parsing_visitor( variables: Set[str], parameters: Set[str], expression_str: str, - expected: LinearExpressionEfficient, -): + expected: Union[ + ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint + ], +) -> None: identifiers = ModelIdentifiers(variables, parameters) expr = parse_expression(expression_str, identifiers) print() - print(print_expr(expr)) - assert expressions_equal(expr, expected) + print(f"Expected: \n {str(expected)}") + print(f"Parsed: \n {str(expr)}") + if isinstance(expected, ExpressionNodeEfficient): + assert expressions_equal(expr, expected) + elif isinstance(expected, LinearExpressionEfficient): + assert linear_expressions_equal(expr, expected) + elif isinstance(expected, StandaloneConstraint): + assert expected == expr @pytest.mark.parametrize(