Skip to content

Commit

Permalink
Fix shift and eval parsing (temporary)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbittar committed Aug 21, 2024
1 parent 4ad0c89 commit fd660d9
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 24 deletions.
33 changes: 32 additions & 1 deletion src/andromede/expression/expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class DivisionNode(BinaryOperatorNode):
pass


@dataclass(frozen=True, eq=False)
@dataclass(frozen=True)
class ExpressionRange:
start: ExpressionNodeEfficient
stop: ExpressionNodeEfficient
Expand All @@ -457,6 +457,14 @@ def __post_init__(self) -> None:
self, attribute, wrap_in_node(value) if value is not None else value
)

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, ExpressionRange)
and expressions_equal(self.start, other.start)
and expressions_equal(self.stop, other.stop)
and expressions_equal_if_present(self.step, other.step)
)


IntOrExpr = Union[int, ExpressionNodeEfficient]

Expand Down Expand Up @@ -515,6 +523,29 @@ def __hash__(self) -> int:
else:
return hash(self.expressions)

def __eq__(self, other: Any) -> bool:
if isinstance(other, InstancesTimeIndex):
if isinstance(self.expressions, list) and all(
isinstance(x, ExpressionNodeEfficient) for x in self.expressions
):
return (
isinstance(other.expressions, list)
and all(
isinstance(x, ExpressionNodeEfficient)
for x in other.expressions
)
and all(
expressions_equal(left_expr, right_expr)
for left_expr, right_expr in zip(
self.expressions, other.expressions
)
)
)
elif isinstance(self.expressions, ExpressionRange):
return self.expressions == other.expressions
else:
return False

def is_simple(self) -> bool:
if isinstance(self.expressions, list):
return len(self.expressions) == 1
Expand Down
27 changes: 19 additions & 8 deletions src/andromede/expression/linear_expression_efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,10 @@
overload,
)

import ortools.linear_solver.pywraplp as lp

from andromede.expression.context_adder import add_component_context
from andromede.expression.equality import expressions_equal
from andromede.expression.evaluate import evaluate
from andromede.expression.evaluate_parameters_efficient import (
check_resolved_expr,
get_time_ids_from_instances_index,
resolve_coefficient,
)
from andromede.expression.expression_efficient import (
Expand All @@ -61,10 +57,6 @@
from andromede.expression.indexing_structure import IndexingStructure, RowIndex
from andromede.expression.port_operator import PortAggregator, PortSum
from andromede.expression.print import print_expr
from andromede.expression.resolved_linear_expression import (
ResolvedLinearExpression,
ResolvedTerm,
)
from andromede.expression.scenario_operator import Expectation, ScenarioAggregator
from andromede.expression.time_operator import (
TimeAggregator,
Expand Down Expand Up @@ -92,6 +84,17 @@ class TermKeyEfficient:
time_aggregator: Optional[TimeAggregator]
scenario_aggregator: Optional[ScenarioAggregator]

# Used for test_expression_parsing
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, TermKeyEfficient)
and self.component_id == other.component_id
and self.variable_name == other.variable_name
and self.time_operator == other.time_operator
and self.time_aggregator == other.time_aggregator
and self.scenario_aggregator == other.scenario_aggregator
)


@dataclass(frozen=True)
class TermEfficient:
Expand Down Expand Up @@ -1065,6 +1068,14 @@ def __post_init__(
def __str__(self) -> str:
return f"{str(self.lower_bound)} <= {str(self.expression)} <= {str(self.upper_bound)}"

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, StandaloneConstraint)
and linear_expressions_equal_if_present(self.expression, other.expression)
and linear_expressions_equal_if_present(self.lower_bound, other.lower_bound)
and linear_expressions_equal_if_present(self.upper_bound, other.upper_bound)
)


def wrap_in_linear_expr(obj: Any) -> LinearExpressionEfficient:
if isinstance(obj, LinearExpressionEfficient):
Expand Down
4 changes: 2 additions & 2 deletions src/andromede/expression/parsing/parse_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def visitTimeShift(
# specifics for x[t] ...
if len(time_shifts) == 1 and expressions_equal(time_shifts[0], literal(0)):
return shifted_expr
return shifted_expr.shift(time_shifts)
return shifted_expr.sum(shift=time_shifts)

def visitTimeShiftRange(
self, ctx: ExprParser.TimeShiftRangeContext
) -> LinearExpressionEfficient:
shifted_expr = self._convert_identifier(ctx.IDENTIFIER().getText()) # type: ignore
shift1 = ctx.shift1.accept(self) # type: ignore
shift2 = ctx.shift2.accept(self) # type: ignore
return shifted_expr.shift(ExpressionRange(shift1, shift2))
return shifted_expr.sum(shift=ExpressionRange(shift1, shift2))

# Visit a parse tree produced by ExprParser#function.
def visitFunction(
Expand Down
41 changes: 28 additions & 13 deletions tests/unittests/expressions/parsing/test_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.
from typing import Set
from typing import Set, Union

import pytest

from andromede.expression.equality import expressions_equal
from andromede.expression.expression_efficient import ExpressionRange, literal, param
from andromede.expression.expression_efficient import (
ExpressionNodeEfficient,
ExpressionRange,
literal,
param,
)
from andromede.expression.linear_expression_efficient import (
LinearExpressionEfficient,
StandaloneConstraint,
linear_expressions_equal,
port_field,
var,
)
Expand Down Expand Up @@ -59,14 +66,14 @@
(
{"x"},
{},
"x[t-1, t+4]",
"x[t-1, t+4]", # TODO: Should raise ValueError: shift always with sum
var("x").sum(shift=[-literal(1), literal(4)]),
),
(
{"x"},
{},
"x[t-1+1]",
var("x").sum(shift=-literal(1) + literal(1)),
var("x"), # Simplifications are applied very early in parsing !!!!
),
(
{"x"},
Expand Down Expand Up @@ -95,25 +102,25 @@
(
{"x"},
{},
"x[t-1, t, t+4]",
"x[t-1, t, t+4]", # TODO: Should raise ValueError: shift always with sum
var("x").sum(shift=[-literal(1), literal(0), literal(4)]),
),
(
{"x"},
{},
"x[t-1..t+5]",
"x[t-1..t+5]", # TODO: Should raise ValueError: shift always with sum
var("x").sum(shift=ExpressionRange(-literal(1), literal(5))),
),
(
{"x"},
{},
"x[t-1..t]",
"x[t-1..t]", # TODO: Should raise ValueError: shift always with sum
var("x").sum(shift=ExpressionRange(-literal(1), literal(0))),
),
(
{"x"},
{},
"x[t..t+5]",
"x[t..t+5]", # TODO: Should raise ValueError: shift always with sum
var("x").sum(shift=ExpressionRange(literal(0), literal(5))),
),
({"x"}, {}, "x[t]", var("x")),
Expand All @@ -122,7 +129,7 @@
{"x"},
{},
"sum(x[-1..5])",
var("x").sum(eval=ExpressionRange(-literal(1), literal(5))).sum(),
var("x").sum(eval=ExpressionRange(-literal(1), literal(5))),
),
({}, {}, "sum_connections(port.f)", port_field("port", "f").sum_connections()),
(
Expand Down Expand Up @@ -156,13 +163,21 @@ def test_parsing_visitor(
variables: Set[str],
parameters: Set[str],
expression_str: str,
expected: LinearExpressionEfficient,
):
expected: Union[
ExpressionNodeEfficient, LinearExpressionEfficient, StandaloneConstraint
],
) -> None:
identifiers = ModelIdentifiers(variables, parameters)
expr = parse_expression(expression_str, identifiers)
print()
print(print_expr(expr))
assert expressions_equal(expr, expected)
print(f"Expected: \n {str(expected)}")
print(f"Parsed: \n {str(expr)}")
if isinstance(expected, ExpressionNodeEfficient):
assert expressions_equal(expr, expected)
elif isinstance(expected, LinearExpressionEfficient):
assert linear_expressions_equal(expr, expected)
elif isinstance(expected, StandaloneConstraint):
assert expected == expr


@pytest.mark.parametrize(
Expand Down

0 comments on commit fd660d9

Please sign in to comment.