From 888fd2de1da370d5dff6346179af172bfa3d34bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Wed, 10 Apr 2024 09:30:28 +0200 Subject: [PATCH] Updated the `add_state_{after, before}()` function. (#1556) It is now possible to add conditions and assignments directly to them. Furthermore they also support now `is_start_block` flag. --- dace/sdfg/state.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a9f7071b0f..0a93d54c2c 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2477,38 +2477,56 @@ def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=No self.add_node(state, is_start_block=start_block) return state - def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + def add_state_before(self, + state: SDFGState, + label=None, + is_start_block=False, + condition: CodeBlock = None, + assignments=None, + *, + is_start_state: bool=None) -> SDFGState: """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. :param state: The state to prepend the new state before. :param label: State label. - :param is_start_state: If True, resets scope block starting state to this state. + :param is_start_block: If True, resets scope block starting state to this state. + :param condition: Transition condition of the newly created edge between state and the new state. + :param assignments: Assignments to perform upon transition. :return: A new SDFGState object. """ - new_state = self.add_state(label, is_start_state) + new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state) # Reconnect for e in self.in_edges(state): self.remove_edge(e) self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + # Add the new edge + self.add_edge(new_state, state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state - def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + def add_state_after(self, + state: SDFGState, + label=None, + is_start_block=False, + condition: CodeBlock = None, + assignments=None, + *, + is_start_state: bool=None) -> SDFGState: """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. :param state: The state to append the new state after. :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this state. + :param is_start_block: If True, resets scope block starting state to this state. + :param condition: Transition condition of the newly created edge between state and the new state. + :param assignments: Assignments to perform upon transition. :return: A new SDFGState object. """ - new_state = self.add_state(label, is_start_state) + new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state) # Reconnect for e in self.out_edges(state): self.remove_edge(e) self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + # Add the new edge + self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state @abc.abstractmethod