Skip to content

Commit

Permalink
Add external nested SDFG capabilities (#1795)
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Dec 2, 2024
1 parent 77c5c72 commit 05e5908
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 83 deletions.
100 changes: 56 additions & 44 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ class NestedSDFG(CodeNode):

# NOTE: We cannot use SDFG as the type because of an import loop
sdfg = SDFGReferenceProperty(desc="The SDFG", allow_none=True)
ext_sdfg_path = Property(dtype=str, default=None, allow_none=True,
desc='Path to a file containing the SDFG for this nested SDFG')
schedule = EnumProperty(dtype=dtypes.ScheduleType,
desc="SDFG schedule",
allow_none=True,
Expand All @@ -569,22 +571,30 @@ class NestedSDFG(CodeNode):

def __init__(self,
label,
sdfg,
sdfg: Optional['dace.SDFG'],
inputs: Set[str],
outputs: Set[str],
symbol_mapping: Dict[str, Any] = None,
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None):
from dace.sdfg import SDFG
debuginfo=None,
path: Optional[str] = None):
super(NestedSDFG, self).__init__(label, location, inputs, outputs)

# Properties
self.sdfg: SDFG = sdfg
self.sdfg: 'dace.SDFG' = sdfg
self.ext_sdfg_path = path
self.symbol_mapping = symbol_mapping or {}
self.schedule = schedule
self.debuginfo = debuginfo

def load_external(self, context: Optional['dace.SDFGState']) -> None:
if self.sdfg is None and self.ext_sdfg_path is not None:
self.sdfg = dace.SDFG.from_file(self.ext_sdfg_path)
self.sdfg.parent_nsdfg_node = self
self.sdfg.parent = context
self.sdfg.parent_sdfg = context.sdfg if context else None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
Expand All @@ -607,14 +617,14 @@ def from_json(json_obj, context=None):

dace.serialize.set_properties_from_json(ret, json_obj, context)

if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']

ret.sdfg.parent_nsdfg_node = ret
if ret.sdfg is not None:
if context and 'sdfg_state' in context:
ret.sdfg.parent = context['sdfg_state']
if context and 'sdfg' in context:
ret.sdfg.parent_sdfg = context['sdfg']
ret.sdfg.parent_nsdfg_node = ret

ret.sdfg.update_cfg_list([])
ret.sdfg.update_cfg_list([])

return ret

Expand Down Expand Up @@ -664,28 +674,29 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
for out_conn in self.out_connectors:
if not dtypes.validate_name(out_conn):
raise NameError('Invalid output connector "%s"' % out_conn)
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)
if self.sdfg:
if self.sdfg.parent_nsdfg_node is not self:
raise ValueError('Parent nested SDFG node not properly set')
if self.sdfg.parent is not state:
raise ValueError('Parent state not properly set for nested SDFG node')
if self.sdfg.parent_sdfg is not sdfg:
raise ValueError('Parent SDFG not properly set for nested SDFG node')

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
'Example: parameter passed to a function without a matching array within it.')
for dname, desc in self.sdfg.arrays.items():
if not desc.transient and dname not in connectors:
raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname)
if dname in connectors and desc.transient:
raise NameError('"%s" is a connector but its corresponding array is transient' % dname)

# Validate inout connectors
from dace.sdfg import utils # Avoids circular import
Expand All @@ -706,17 +717,18 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context
f"output ({outputs}) arrays")

# Validate undefined symbols
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")

# Recursively validate nested SDFG
self.sdfg.validate(references, **context)
if self.sdfg:
symbols = set(k for k in self.sdfg.free_symbols if k not in connectors)
missing_symbols = [s for s in symbols if s not in self.symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG: %s' % (missing_symbols))
extra_symbols = self.symbol_mapping.keys() - symbols
if len(extra_symbols) > 0:
# TODO: Elevate to an error?
warnings.warn(f"{self.label} maps to unused symbol(s): {extra_symbols}")

# Recursively validate nested SDFG
self.sdfg.validate(references, **context)


# ------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2341,6 +2341,10 @@ def compile(self, output_file=None, validate=True,
# if the codegen modifies the SDFG (thereby changing its hash)
sdfg.build_folder = build_folder

# Ensure external nested SDFGs are loaded.
for _ in sdfg.all_sdfgs_recursive(load_ext=True):
pass

# Rename SDFG to avoid runtime issues with clashing names
index = 0
while sdfg.is_loaded():
Expand Down
85 changes: 48 additions & 37 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def edges(self) -> List[MultiConnectorEdge[mm.Memlet]]:
def all_nodes_recursive(self, predicate = None) -> Iterator[Tuple[NodeT, GraphT]]:
for node in self.nodes():
yield node, self
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg:
if predicate is None or predicate(node, self):
yield from node.sdfg.all_nodes_recursive(predicate)

Expand Down Expand Up @@ -1380,7 +1380,7 @@ def add_node(self, node):
if not isinstance(node, nd.Node):
raise TypeError("Expected Node, got " + type(node).__name__ + " (" + str(node) + ")")
# Correct nested SDFG's parent attributes
if isinstance(node, nd.NestedSDFG):
if isinstance(node, nd.NestedSDFG) and node.sdfg is not None:
node.sdfg.parent = self
node.sdfg.parent_sdfg = self.sdfg
node.sdfg.parent_nsdfg_node = node
Expand Down Expand Up @@ -1667,7 +1667,7 @@ def add_tasklet(

def add_nested_sdfg(
self,
sdfg: 'SDFG',
sdfg: Optional['SDFG'],
parent,
inputs: Union[Set[str], Dict[str, dtypes.typeclass]],
outputs: Union[Set[str], Dict[str, dtypes.typeclass]],
Expand All @@ -1676,16 +1676,21 @@ def add_nested_sdfg(
schedule=dtypes.ScheduleType.Default,
location=None,
debuginfo=None,
external_path: Optional[str] = None,
):
""" Adds a nested SDFG to the SDFG state. """
if name is None:
name = sdfg.label
debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo)

sdfg.parent = self
sdfg.parent_sdfg = self.sdfg
if sdfg is None and external_path is None:
raise ValueError('Neither an SDFG nor an external SDFG path has been provided')

if sdfg is not None:
sdfg.parent = self
sdfg.parent_sdfg = self.sdfg

sdfg.update_cfg_list([])
sdfg.update_cfg_list([])

# Make dictionary of autodetect connector types from set
if isinstance(inputs, (set, collections.abc.KeysView)):
Expand All @@ -1702,35 +1707,37 @@ def add_nested_sdfg(
schedule=schedule,
location=location,
debuginfo=debuginfo,
path=external_path,
)
self.add_node(s)

sdfg.parent_nsdfg_node = s

# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))
if sdfg is not None:
sdfg.parent_nsdfg_node = s

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))
# Add "default" undefined symbols if None are given
symbols = sdfg.free_symbols
if symbol_mapping is None:
symbol_mapping = {s: s for s in symbols}
s.symbol_mapping = symbol_mapping

# Validate missing symbols
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols and parent:
# If symbols are missing, try to get them from the parent SDFG
parent_mapping = {s: s for s in missing_symbols if s in parent.symbols}
symbol_mapping.update(parent_mapping)
s.symbol_mapping = symbol_mapping
missing_symbols = [s for s in symbols if s not in symbol_mapping]
if missing_symbols:
raise ValueError('Missing symbols on nested SDFG "%s": %s' % (name, missing_symbols))

# Add new global symbols to nested SDFG
from dace.codegen.tools.type_inference import infer_expr_type
for sym, symval in s.symbol_mapping.items():
if sym not in sdfg.symbols:
# TODO: Think of a better way to avoid calling
# symbols_defined_at in this moment
sdfg.add_symbol(sym, infer_expr_type(symval, self.sdfg.symbols) or dtypes.typeclass(int))

return s

Expand Down Expand Up @@ -2818,23 +2825,27 @@ def add_state_after(self,
###################################################################
# Traversal methods

def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegion']:
def all_control_flow_regions(self, recursive=False, load_ext=False) -> Iterator['ControlFlowRegion']:
""" Iterate over this and all nested control flow regions. """
yield self
for block in self.nodes():
if isinstance(block, SDFGState) and recursive:
for node in block.nodes():
if isinstance(node, nd.NestedSDFG):
yield from node.sdfg.all_control_flow_regions(recursive=recursive)
if node.sdfg:
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif load_ext:
node.load_external(block)
yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ControlFlowRegion):
yield from block.all_control_flow_regions(recursive=recursive)
yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext)
elif isinstance(block, ConditionalBlock):
for _, branch in block.branches:
yield from branch.all_control_flow_regions(recursive=recursive)
yield from branch.all_control_flow_regions(recursive=recursive, load_ext=load_ext)

def all_sdfgs_recursive(self) -> Iterator['SDFG']:
def all_sdfgs_recursive(self, load_ext=False) -> Iterator['SDFG']:
""" Iterate over this and all nested SDFGs. """
for cfg in self.all_control_flow_regions(recursive=True):
for cfg in self.all_control_flow_regions(recursive=True, load_ext=load_ext):
if isinstance(cfg, dace.SDFG):
yield cfg

Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/passes/fusion_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def modifies(self) -> ppl.Modifies:
def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]:
modified = 0
for node, state in sdfg.all_nodes_recursive():
if not isinstance(node, nodes.NestedSDFG):
if not isinstance(node, nodes.NestedSDFG) or node.sdfg is None:
continue
was_modified = False
if node.sdfg.parent_nsdfg_node is not node:
Expand Down
57 changes: 56 additions & 1 deletion tests/nested_sdfg_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import os
import tempfile
import numpy as np

import dace as dp
Expand Down Expand Up @@ -54,5 +56,58 @@ def do():
assert diff <= 1e-5


def test_external_nsdfg():
N = dp.symbol('N')

@dp.program
def sdfg_internal(input: dp.float32, output: dp.float32[1]):
@dp.tasklet
def init():
out >> output
out = input

for k in range(4):

@dp.tasklet
def do():
oin << output
out >> output
out = oin * input


# Construct SDFG
mysdfg = SDFG('outer_sdfg')
state = mysdfg.add_state()
A = state.add_array('A', [N, N], dp.float32)
B = state.add_array('B', [N, N], dp.float32)

map_entry, map_exit = state.add_map('elements', [('i', '0:N'), ('j', '0:N')])
internal = sdfg_internal.to_sdfg()
fd, filename = tempfile.mkstemp(suffix='.sdfg')
internal.save(filename)
nsdfg = state.add_nested_sdfg(None, mysdfg, {'input'}, {'output'}, name='sdfg_internal', external_path=filename)

# Add edges
state.add_memlet_path(A, map_entry, nsdfg, dst_conn='input', memlet=Memlet.simple(A, 'i,j'))
state.add_memlet_path(nsdfg, map_exit, B, src_conn='output', memlet=Memlet.simple(B, 'i,j'))


N = 64

input = dp.ndarray([N, N], dp.float32)
output = dp.ndarray([N, N], dp.float32)
input[:] = np.random.rand(N, N).astype(dp.float32.type)
output[:] = dp.float32(0)

mysdfg(A=input, B=output, N=N)

diff = np.linalg.norm(output - np.power(input, 5)) / (N * N)
print("Difference:", diff)
assert diff <= 1e-5

os.close(fd)


if __name__ == "__main__":
test()
test_external_nsdfg()

0 comments on commit 05e5908

Please sign in to comment.