Skip to content

Commit

Permalink
Reference-to-View pass and comprehensive reference test suite (#1485)
Browse files Browse the repository at this point in the history
      Implements a reference-to-view pass (converting references to views if
they are only set to one particular subset). Also improves the simplify
pipeline in the presence of Reference data descriptors and adds multiple
tests that use references.
  • Loading branch information
tbennun authored Dec 21, 2023
1 parent 7c06755 commit 509ee0f
Show file tree
Hide file tree
Showing 9 changed files with 864 additions and 30 deletions.
3 changes: 3 additions & 0 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,9 @@ def visit_Name(self, node: ast.Name):
except KeyError:
defined_type = None
if (self.allow_casts and isinstance(dtype, dtypes.pointer) and memlet.subset.num_elements() == 1):
# Special case for pointer to pointer assignment
if memlet.data in self.sdfg.arrays and self.sdfg.arrays[memlet.data].dtype == dtype:
return self.generic_visit(node)
return ast.parse(f"{name}[0]").body[0].value
elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray))
and memlet.dynamic):
Expand Down
39 changes: 25 additions & 14 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,10 +717,8 @@ def _emit_copy(

state_dfg = sdfg.nodes()[state_id]

copy_shape, src_strides, dst_strides, src_expr, dst_expr = \
cpp.memlet_copy_to_absolute_strides(
self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node,
self._packed_types)
copy_shape, src_strides, dst_strides, src_expr, dst_expr = cpp.memlet_copy_to_absolute_strides(
self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._packed_types)

# Which numbers to include in the variable argument part
dynshape, dynsrc, dyndst = 1, 1, 1
Expand Down Expand Up @@ -904,7 +902,8 @@ def process_out_memlets(self,
_, uconn, v, _, memlet = edge
if skip_wcr and memlet.wcr is not None:
continue
dst_node = dfg.memlet_path(edge)[-1].dst
dst_edge = dfg.memlet_path(edge)[-1]
dst_node = dst_edge.dst

# Target is neither a data nor a tasklet node
if isinstance(node, nodes.AccessNode) and (not isinstance(dst_node, nodes.AccessNode)
Expand Down Expand Up @@ -952,9 +951,12 @@ def process_out_memlets(self,

conntype = node.out_connectors[uconn]
is_scalar = not isinstance(conntype, dtypes.pointer)
if isinstance(conntype, dtypes.pointer) and sdfg.arrays[memlet.data].dtype == conntype:
is_scalar = True # Pointer to pointer assignment
is_stream = isinstance(sdfg.arrays[memlet.data], data.Stream)
is_refset = isinstance(sdfg.arrays[memlet.data], data.Reference) and dst_edge.dst_conn == 'set'

if is_scalar and not memlet.dynamic and not is_stream:
if (is_scalar and not memlet.dynamic and not is_stream) or is_refset:
out_local_name = " __" + uconn
in_local_name = uconn
if not locals_defined:
Expand Down Expand Up @@ -987,6 +989,9 @@ def process_out_memlets(self,
if defined_type == DefinedType.Scalar:
mname = cpp.ptr(memlet.data, desc, sdfg, self._frame)
write_expr = f"{mname} = {in_local_name};"
elif defined_type == DefinedType.Pointer and is_refset:
mname = cpp.ptr(memlet.data, desc, sdfg, self._frame)
write_expr = f"{mname} = {in_local_name};"
elif (defined_type == DefinedType.ArrayInterface and not isinstance(desc, data.View)):
# Special case: No need to write anything between
# array interfaces going out
Expand Down Expand Up @@ -1473,15 +1478,21 @@ def define_out_memlet(self, sdfg, state_dfg, state_id, src_node, dst_node, edge,
cdtype = src_node.out_connectors[edge.src_conn]
if isinstance(sdfg.arrays[edge.data.data], data.Stream):
pass
elif isinstance(cdtype, dtypes.pointer):
# If pointer, also point to output
elif isinstance(cdtype, dtypes.pointer): # If pointer, also point to output
desc = sdfg.arrays[edge.data.data]
ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame)
is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent,
dtypes.AllocationLifetime.External)
defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global)
base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self._frame)
callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', sdfg, state_id, src_node)

# If reference set, do not emit initial assignment
is_refset = isinstance(desc, data.Reference) and state_dfg.memlet_path(edge)[-1].dst_conn == 'set'

if not is_refset and not isinstance(desc.dtype, dtypes.pointer):
ptrname = cpp.ptr(edge.data.data, desc, sdfg, self._frame)
is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent,
dtypes.AllocationLifetime.External)
defined_type, _ = self._dispatcher.defined_vars.get(ptrname, is_global=is_global)
base_ptr = cpp.cpp_ptr_expr(sdfg, edge.data, defined_type, codegen=self._frame)
callsite_stream.write(f'{cdtype.ctype} {edge.src_conn} = {base_ptr};', sdfg, state_id, src_node)
else:
callsite_stream.write(f'{cdtype.as_arg(edge.src_conn)};', sdfg, state_id, src_node)
else:
callsite_stream.write(f'{cdtype.ctype} {edge.src_conn};', sdfg, state_id, src_node)

Expand Down
8 changes: 7 additions & 1 deletion dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,15 @@ def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[
# 2. Check for reference sets
if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set':
assert isinstance(e.dst.desc(sdfg), dace.data.Reference)

# Determine source
if isinstance(mtree.root().edge.src, dace.nodes.CodeNode):
src_desc = mtree.root().edge.src
else:
src_desc = sdfg.arrays[e.data.data]
result[e] = tn.RefSetNode(target=e.dst.data,
memlet=e.data,
src_desc=sdfg.arrays[e.data.data],
src_desc=src_desc,
ref_desc=sdfg.arrays[e.dst.data])
scope = state.entry_node(e.dst if mtree.downwards else e.src)
scope_to_edges[scope].append(e)
Expand Down
7 changes: 4 additions & 3 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class ScheduleTreeScope(ScheduleTreeNode):
containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False)
symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False)

def __init__(self,
children: Optional[List['ScheduleTreeNode']] = None):
def __init__(self, children: Optional[List['ScheduleTreeNode']] = None):
self.children = children or []
if self.children:
for child in children:
Expand Down Expand Up @@ -350,10 +349,12 @@ class RefSetNode(ScheduleTreeNode):
"""
target: str
memlet: Memlet
src_desc: data.Data
src_desc: Union[data.Data, nodes.CodeNode]
ref_desc: data.Data

def as_string(self, indent: int = 0):
    """
    Renders this reference-set node as a single line of text.

    :param indent: Number of times the ``INDENTATION`` prefix is repeated
                   in front of the line.
    :return: The prefixed description. When the source is a code node, the
             line names the lower-cased class of that node; otherwise it
             shows the memlet that the reference is set to.
    """
    prefix = indent * INDENTATION

    # Any source that is not a code node is described through its memlet
    if not isinstance(self.src_desc, nodes.CodeNode):
        return prefix + f'{self.target} = refset to {self.memlet}'

    # Code-node sources have no memlet subset to show, so report the node kind
    source_kind = type(self.src_desc).__name__.lower()
    return prefix + f'{self.target} = refset from {source_kind}'


Expand Down
21 changes: 14 additions & 7 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG',
in_default_scope = False
if sdfg.parent_nsdfg_node is not None:
if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node,
[dtypes.ScheduleType.Default]):
[dtypes.ScheduleType.Default]):
in_default_scope = True
if in_default_scope is False:
eid = region.edge_id(edge)
Expand All @@ -126,9 +126,9 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG',
if start_block not in visited:
if isinstance(start_block, SDFGState):
validate_state(start_block, region.node_id(start_block), sdfg, symbols, initialized_transients, references,
**context)
**context)
else:
validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context)
validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context)

# Validate all inter-state edges (including self-loops not found by DFS)
for eid, edge in enumerate(region.edges()):
Expand Down Expand Up @@ -162,7 +162,7 @@ def validate_control_flow_region(sdfg: 'dace.sdfg.SDFG',
in_default_scope = False
if sdfg.parent_nsdfg_node is not None:
if is_in_scope(sdfg.parent_sdfg, sdfg.parent, sdfg.parent_nsdfg_node,
[dtypes.ScheduleType.Default]):
[dtypes.ScheduleType.Default]):
in_default_scope = True
if in_default_scope is False:
raise InvalidSDFGInterstateEdgeError(
Expand Down Expand Up @@ -453,9 +453,16 @@ def validate_state(state: 'dace.sdfg.SDFGState',

# Find uninitialized transients
if node.data not in initialized_transients:
if (arr.transient and state.in_degree(node) == 0 and state.out_degree(node) > 0
# Streams do not need to be initialized
and not isinstance(arr, dt.Stream)):
if isinstance(arr, dt.Reference): # References are considered more conservatively
if any(e.dst_conn == 'set' for e in state.in_edges(node)):
initialized_transients.add(node.data)
else:
raise InvalidSDFGNodeError(
'Reference data descriptor was used before it was set. Set '
'it with an incoming memlet to the "set" connector', sdfg, state_id, nid)
elif (arr.transient and state.in_degree(node) == 0 and state.out_degree(node) > 0
# Streams do not need to be initialized
and not isinstance(arr, dt.Stream)):
if node.setzero == False:
warnings.warn('WARNING: Use of uninitialized transient "%s" in state %s' %
(node.data, state.label))
Expand Down
9 changes: 9 additions & 0 deletions dace/transformation/passes/dead_dataflow_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod

# Check incoming edges
for e in state.in_edges(node):
# A reference set should not be removed
if e.dst_conn == 'set':
return False

for l in state.memlet_tree(e).leaves():
# If data is connected to a side-effect tasklet/library node, cannot remove
if _has_side_effects(l.src, sdfg):
Expand All @@ -245,6 +249,11 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod
if isinstance(desc, data.Stream) and node.data in access_set[0]:
return False

# If it is a reference, it may point to other data containers,
# be conservative for now
if isinstance(desc, data.Reference):
return False

# Any other case can be marked as dead
return True

Expand Down
Loading

0 comments on commit 509ee0f

Please sign in to comment.