Skip to content

Commit

Permalink
Updated the add_state_{after, before}() function. (#1556)
Browse files Browse the repository at this point in the history
It is now possible to add conditions and assignments directly to them.
Furthermore they also support now `is_start_block` flag.
  • Loading branch information
philip-paul-mueller authored Apr 10, 2024
1 parent d0db188 commit 888fd2d
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2477,38 +2477,56 @@ def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=No
self.add_node(state, is_start_block=start_block)
return state

def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState:
def add_state_before(self,
state: SDFGState,
label=None,
is_start_block=False,
condition: CodeBlock = None,
assignments=None,
*,
is_start_state: bool=None) -> SDFGState:
""" Adds a new SDFG state before an existing state, reconnecting predecessors to it instead.
:param state: The state to prepend the new state before.
:param label: State label.
:param is_start_state: If True, resets scope block starting state to this state.
:param is_start_block: If True, resets scope block starting state to this state.
:param condition: Transition condition of the newly created edge between state and the new state.
:param assignments: Assignments to perform upon transition.
:return: A new SDFGState object.
"""
new_state = self.add_state(label, is_start_state)
new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state)
# Reconnect
for e in self.in_edges(state):
self.remove_edge(e)
self.add_edge(e.src, new_state, e.data)
# Add unconditional connection between the new state and the current
self.add_edge(new_state, state, dace.sdfg.InterstateEdge())
# Add the new edge
self.add_edge(new_state, state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments))
return new_state

def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState:
def add_state_after(self,
state: SDFGState,
label=None,
is_start_block=False,
condition: CodeBlock = None,
assignments=None,
*,
is_start_state: bool=None) -> SDFGState:
""" Adds a new SDFG state after an existing state, reconnecting it to the successors instead.
:param state: The state to append the new state after.
:param label: State label.
:param is_start_state: If True, resets SDFG starting state to this state.
:param is_start_block: If True, resets scope block starting state to this state.
:param condition: Transition condition of the newly created edge between state and the new state.
:param assignments: Assignments to perform upon transition.
:return: A new SDFGState object.
"""
new_state = self.add_state(label, is_start_state)
new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state)
# Reconnect
for e in self.out_edges(state):
self.remove_edge(e)
self.add_edge(new_state, e.dst, e.data)
# Add unconditional connection between the current and the new state
self.add_edge(state, new_state, dace.sdfg.InterstateEdge())
# Add the new edge
self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments))
return new_state

@abc.abstractmethod
Expand Down

0 comments on commit 888fd2d

Please sign in to comment.