diff --git a/src/andromede/simulation/decision_tree.py b/src/andromede/simulation/decision_tree.py index c39e2da..6f67f7d 100644 --- a/src/andromede/simulation/decision_tree.py +++ b/src/andromede/simulation/decision_tree.py @@ -23,7 +23,13 @@ from andromede.model.model import model from andromede.model.variable import Variable, float_variable from andromede.simulation.time_block import TimeBlock -from andromede.study.network import Component, Network, create_component +from andromede.study.network import ( + Component, + Network, + PortRef, + build_ports_connection, + create_component, +) @dataclass(frozen=True) @@ -81,6 +87,25 @@ def is_leaves_prob_sum_one(self) -> bool: # probability sum equal to one return all(child.is_leaves_prob_sum_one() for child in self.children) + def connect_from_parent(self, port: PortRef, parent_port: PortRef) -> None: + if self.parent is None: + raise RuntimeError("Cannot connect upwards because no parent is defined") + + ports_connection = build_ports_connection( + port, parent_port, self.id, self.parent.id + ) + self.network._connections.append(ports_connection) + + def connect_to_children(self, port: PortRef, children_port: PortRef) -> None: + if not self.children: + raise RuntimeError("Cannot connect downwards because no child is defined") + + for child in self.children: + ports_connection = build_ports_connection( + port, children_port, self.id, child.id + ) + child.network._connections.append(ports_connection) + def add_coupling_component( self, component: Component, diff --git a/src/andromede/simulation/optimization.py b/src/andromede/simulation/optimization.py index 3b22a58..9450a56 100644 --- a/src/andromede/simulation/optimization.py +++ b/src/andromede/simulation/optimization.py @@ -725,6 +725,16 @@ def _register_connection_fields_definitions(self) -> None: master_port.component.id, port_definition.definition ) + if cnx.context2: + instantiated_expression = add_decision_tree_context( + cnx.context2, instantiated_expression + ) + + elif self.context.tree_node: + instantiated_expression = add_decision_tree_context( + self.context.tree_node, instantiated_expression + ) + self.context.register_connection_fields_expressions( component_id=cnx.port1.component.id, port_name=cnx.port1.port_id, diff --git a/src/andromede/study/network.py b/src/andromede/study/network.py index c97babb..4aee87f 100644 --- a/src/andromede/study/network.py +++ b/src/andromede/study/network.py @@ -16,7 +16,7 @@ """ import itertools from dataclasses import dataclass, field, replace -from typing import Any, Dict, Iterable, List, cast +from typing import Any, Dict, Iterable, List, Optional, cast from andromede.model import PortField, PortType from andromede.model.model import Model @@ -72,7 +72,9 @@ def __repr__(self) -> str: @dataclass() class PortsConnection: + context1: Optional[str] port1: PortRef + context2: Optional[str] port2: PortRef master_port: Dict[PortField, PortRef] = field( init=False, default_factory=dict, repr=False @@ -83,18 +85,20 @@ def __post_init__(self) -> None: def __validate_ports(self) -> None: model1 = self.port1.component.model - model2 = self.port2.component.model port_1 = model1.ports.get(self.port1.port_id) + + model2 = self.port2.component.model port_2 = model2.ports.get(self.port2.port_id) if port_1 is None or port_2 is None: raise ValueError(f"Missing port: {port_1} or {port_2} ") + if port_1.port_type != port_2.port_type: raise ValueError( f"Incompatible portTypes {port_1.port_type} != {port_2.port_type}" ) - for field_name in [f.name for f in port_1.port_type.fields]: + for field_name in (f.name for f in port_1.port_type.fields): def1: bool = ( PortFieldId(port_name=port_1.port_name, field_name=field_name) in model1.port_fields_definitions @@ -103,10 +107,12 @@ def __validate_ports(self) -> None: PortFieldId(port_name=port_2.port_name, field_name=field_name) in model2.port_fields_definitions ) + if not def1 and not def2: raise ValueError( f"No definition for port field {field_name} on {port_1.port_name}." ) + if def1 and def2: raise ValueError( f"Port field {field_name} on {port_1.port_name} has 2 definitions." @@ -180,12 +186,7 @@ def all_components(self) -> Iterable[Component]: return itertools.chain(self.nodes, self.components) def connect(self, port1: PortRef, port2: PortRef) -> None: - ports_connection = PortsConnection(port1, port2) - self._connections.append(ports_connection) - - def connect2(self, port1: PortRef, parent: "Network", port2: PortRef) -> None: - ports_connection = PortsConnection(port1, port2) - self._connections.append(ports_connection) + self._connections.append(build_ports_connection(port1, port2)) @property def connections(self) -> Iterable[PortsConnection]: @@ -210,3 +211,12 @@ def replicate(self, /, **changes: Any) -> "Network": replica._connections.append(connection.replicate()) return replica + + +def build_ports_connection( + port1: PortRef, + port2: PortRef, + dt_node1: Optional[str] = None, + dt_node2: Optional[str] = None, +) -> PortsConnection: + return PortsConnection(dt_node1, port1, dt_node2, port2) diff --git a/tests/functional/test_investment_pathway.py b/tests/functional/test_investment_pathway.py index 64b9f35..b07cc6d 100644 --- a/tests/functional/test_investment_pathway.py +++ b/tests/functional/test_investment_pathway.py @@ -243,12 +243,6 @@ def test_investment_pathway_on_sequential_nodes( PortRef(candidate_chd, "balance_port"), PortRef(node, "balance_port") ) - network_chd.connect2( - PortRef(candidate_chd, "pathway_port_receive"), - network_par, - PortRef(candidate_par, "pathway_port_send"), - ) - # === Decision tree creation === config = InterDecisionTimeScenarioConfig([TimeBlock(0, [0])], 1) @@ -259,6 +253,10 @@ def test_investment_pathway_on_sequential_nodes( # === Coupling model === # decision_tree_par.add_coupling_component(candidate, "invested_capa", "delta_invest") + decision_tree_chd.connect_from_parent( + PortRef(candidate_chd, "pathway_port_receive"), + PortRef(candidate_par, "pathway_port_send"), + ) # === Build problem === xpansion = build_benders_decomposed_problem(decision_tree_par, database)