Skip to content

Commit

Permalink
Better Name Validation (#1661)
Browse files Browse the repository at this point in the history
This PR adds checks to the SDFG to ensure that names from symbols, data
descriptors and so on are unique.
Furthermore, it also ensures that the NestedSDFG validates correctly and
ensures that no symbols can be written.
  • Loading branch information
philip-paul-mueller authored Sep 24, 2024
1 parent d0dcf1c commit 7df09c7
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 91 deletions.
21 changes: 0 additions & 21 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,6 @@ def create_datadescriptor(obj, no_custom_desc=False):
'adaptor method to the type hint or object itself.')


def find_new_name(name: str, existing_names: Sequence[str]) -> str:
    """
    Returns a name that matches the given ``name`` as a prefix, but does not
    already exist in the given existing name set. The behavior is typically
    to append an underscore followed by a unique (increasing) number. If the
    name does not already exist in the set, it is returned as-is.

    :param name: The given name to find.
    :param existing_names: The set of existing names.
    :return: A new name that is not in existing_names.
    """
    # Fast path: the requested name is still free.
    if name not in existing_names:
        return name
    # Otherwise probe `<name>_0`, `<name>_1`, ... until an unused one is found.
    counter = 0
    while f'{name}_{counter}' in existing_names:
        counter += 1
    return f'{name}_{counter}'


def _prod(sequence):
return functools.reduce(lambda a, b: a * b, sequence, 1)

Expand Down
26 changes: 13 additions & 13 deletions dace/frontend/common/distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
wnode = state.add_write(pgrid_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal))
scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return pgrid_name

Expand Down Expand Up @@ -97,9 +97,9 @@ def _cart_sub(pv: 'ProgramVisitor',
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
wnode = state.add_write(pgrid_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal))
scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return pgrid_name

Expand Down Expand Up @@ -196,7 +196,7 @@ def _intracomm_bcast(pv: 'ProgramVisitor',
if comm_obj == MPI.COMM_WORLD:
return _bcast(pv, sdfg, state, buffer, root)
# NOTE: Highly experimental
sdfg.add_scalar(comm_name, dace.int32)
scal_name, _ = sdfg.add_scalar(comm_name, dace.int32, find_new_name=True)
return _bcast(pv, sdfg, state, buffer, root, fcomm=comm_name)


Expand Down Expand Up @@ -941,9 +941,9 @@ def _subarray(pv: ProgramVisitor,
state.add_node(tasklet)

# Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations.
_, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True)
wnode = state.add_write(subarray_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(subarray_name, scal))
scal_name, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

return subarray_name

Expand Down Expand Up @@ -1078,9 +1078,9 @@ def _redistribute(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, in_buffer: s
f'int* {rdistrarray_name}_self_size;'
])
state.add_node(tasklet)
_, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
wnode = state.add_write(rdistrarray_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(rdistrarray_name, scal))
scal_name, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True, find_new_name=True)
wnode = state.add_write(scal_name)
state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal))

libnode = Redistribute('_Redistribute_', rdistrarray_name)

Expand Down
11 changes: 10 additions & 1 deletion dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3302,6 +3302,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
tokens = name.split('.')
name = tokens[0]
true_name = None
true_array = None
if name in defined_vars:
true_name = defined_vars[name]
if len(tokens) > 1:
Expand Down Expand Up @@ -3356,7 +3357,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
new_data, rng = None, None
dtype_keys = tuple(dtypes.dtype_to_typeclass().keys())
if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or
(isinstance(result, str) and result in self.sdfg.arrays)):
(isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))):
raise DaceSyntaxError(
self, node, "In assignments, the rhs may only be "
"data, numerical/boolean constants "
Expand All @@ -3380,6 +3381,14 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
_, new_data = self.sdfg.add_scalar(true_name, ttype, transient=True)
self.variables[name] = true_name
defined_vars[name] = true_name
if any(result in x for x in [self.sdfg._pgrids, self.sdfg._rdistrarrays, self.sdfg._subarrays]):
# NOTE: In previous versions some `pgrid` and subgrid related replacement functions,
# see `dace/frontend/common/distr.py`, created dummy variables with the same name
# as the entities, such as process grids, they created. Thus the frontend was
# finding them. Since this is now disallowed, we have to explicitly handle this case.
self.variables[name] = result
defined_vars[name] = result
continue
elif isinstance(result, str) and result in self.sdfg.arrays:
result_data = self.sdfg.arrays[result]
if (name.startswith('__return') and isinstance(result_data, data.Scalar)):
Expand Down
9 changes: 6 additions & 3 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
internally_used_symbols = self.sdfg.used_symbols(all_symbols=False)
keys_to_use &= internally_used_symbols

# Translate the internal symbols back to their external counterparts.
free_syms |= set().union(*(map(str,
pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items()
if k in keys_to_use))
Expand Down Expand Up @@ -662,6 +663,10 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context

connectors = self.in_connectors.keys() | self.out_connectors.keys()
for conn in connectors:
if conn in self.sdfg.symbols:
raise ValueError(
f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. '
'To pass symbols use "symbol_mapping".')
if conn not in self.sdfg.arrays:
raise NameError(
f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. '
Expand Down Expand Up @@ -795,10 +800,8 @@ def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]:
for p, rng in zip(self._map.params, self._map.range):
result[p] = dtypes.result_type_of(infer_expr_type(rng[0], symbols), infer_expr_type(rng[1], symbols))

# Add dynamic inputs
# Handle the dynamic map ranges.
dyn_inputs = set(c for c in self.in_connectors if not c.startswith('IN_'))

# Try to get connector type from connector
for e in state.in_edges(self):
if e.dst_conn in dyn_inputs:
result[e.dst_conn] = (self.in_connectors[e.dst_conn] or sdfg.arrays[e.data.data].dtype)
Expand Down
159 changes: 106 additions & 53 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,17 +746,32 @@ def replace_dict(self,

super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys)

def add_symbol(self, name, stype):
def add_symbol(self, name, stype, find_new_name: bool = False):
""" Adds a symbol to the SDFG.
:param name: Symbol name.
:param stype: Symbol type.
:param find_new_name: Find a new name.
"""
if name in self.symbols:
raise FileExistsError('Symbol "%s" already exists in SDFG' % name)
if find_new_name:
name = self._find_new_name(name)
else:
# We do not check against constants, because the constants are linked to
# the data descriptors.
if name in self.symbols:
raise FileExistsError(f'Symbol "{name}" already exists in SDFG')
if name in self.arrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a data descriptor.')
if name in self._subarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create symbol "{name}", the name is used by a ProcessGrid.')
if not isinstance(stype, dtypes.typeclass):
stype = dtypes.dtype_to_typeclass(stype)
self.symbols[name] = stype
return name

def remove_symbol(self, name):
""" Removes a symbol from the SDFG.
Expand Down Expand Up @@ -1159,14 +1174,23 @@ def cast(dtype: dt.Data, value: Any):
return result

def add_constant(self, name: str, value: Any, dtype: dt.Data = None):
""" Adds/updates a new compile-time constant to this SDFG. A constant
may either be a scalar or a numpy ndarray thereof.
"""
Adds/updates a new compile-time constant to this SDFG.
:param name: The name of the constant.
:param value: The constant value.
:param dtype: Optional data type of the symbol, or None to deduce
automatically.
A constant may either be a scalar or a numpy ndarray thereof. It is not an
error if there is already a symbol or an array with the same name inside
the SDFG. However, the data descriptors must refer to the same type.
:param name: The name of the constant.
:param value: The constant value.
:param dtype: Optional data type of the symbol, or None to deduce automatically.
"""
if name in self._subarrays:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create constant "{name}", the name is used by a ProcessGrid.')
self.constants_prop[name] = (dtype or dt.create_datadescriptor(value), value)

@property
Expand Down Expand Up @@ -1598,36 +1622,44 @@ def _find_new_name(self, name: str):
""" Tries to find a new name by adding an underscore and a number. """

names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys()
| self._rdistrarrays.keys())
| self._rdistrarrays.keys() | self.symbols.keys())
return dt.find_new_name(name, names)

def is_name_used(self, name: str) -> bool:
    """Checks if `name` is already used inside the SDFG.

    A name is considered used when it is taken by a data descriptor, a
    symbol, a constant, a process grid, a subarray, or a redistribution
    array of this SDFG.

    :param name: The name to look up.
    :return: True if the name is taken by any registry, False otherwise.
    """
    registries = (
        self._arrays,
        self.symbols,
        self.constants_prop,
        self._pgrids,
        self._subarrays,
        self._rdistrarrays,
    )
    return any(name in registry for registry in registries)

def is_name_free(self, name: str) -> bool:
    """Return True when `name` is not yet used by anything inside the SDFG.

    :param name: The name to test.
    :return: True if the name is free, False if it is already taken.
    """
    # Simply the negation of the membership test over all registries.
    used = self.is_name_used(name)
    return not used

def find_new_constant(self, name: str):
"""
Tries to find a new constant name by adding an underscore and a number.
Tries to find a new name for a constant.
"""
constants = self.constants
if name not in constants:
if self.is_name_free(name):
return name

index = 0
while (name + ('_%d' % index)) in constants:
index += 1

return name + ('_%d' % index)
return self._find_new_name(name)

def find_new_symbol(self, name: str):
"""
Tries to find a new symbol name by adding an underscore and a number.
"""
symbols = self.symbols
if name not in symbols:
if self.is_name_free(name):
return name

index = 0
while (name + ('_%d' % index)) in symbols:
index += 1

return name + ('_%d' % index)
return self._find_new_name(name)

def add_array(self,
name: str,
Expand Down Expand Up @@ -1856,13 +1888,14 @@ def add_transient(self,

def temp_data_name(self):
""" Returns a temporary data descriptor name that can be used in this SDFG. """

name = '__tmp%d' % self._temp_transients
while name in self._arrays:

# NOTE: Consider switching to `_find_new_name`
# The frontend seems to access this variable directly.
while self.is_name_used(name):
self._temp_transients += 1
name = '__tmp%d' % self._temp_transients
self._temp_transients += 1

return name

def add_temp_transient(self,
Expand Down Expand Up @@ -1917,29 +1950,47 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
"""
if not isinstance(name, str):
raise TypeError("Data descriptor name must be a string. Got %s" % type(name).__name__)
# If exists, fail
while name in self._arrays:
if find_new_name:
name = self._find_new_name(name)
else:
raise NameError(f'Array or Stream with name "{name}" already exists in SDFG')
# NOTE: Remove illegal characters, such as dots. Such characters may be introduced when creating views to
# members of Structures.
name = name.replace('.', '_')
assert name not in self._arrays
self._arrays[name] = datadesc

def _add_symbols(desc: dt.Data):
if find_new_name:
# These characters might be introduced through the creation of views to members
# of structures.
# NOTE: If `find_new_name` is `True` and the name (understood as a sequence of
# any characters) is not used, i.e. `assert self.is_name_free(name)`, then it
# is still "cleaned", i.e. dots are replaced with underscores. However, if
# `find_new_name` is `False` then this cleaning is not applied and it is possible
# to create names that are formally invalid. The above code reproduces the exact
# same behaviour and is maintained for compatibility. This behaviour is
# triggered by tests/python_frontend/structures/structure_python_test.py::test_rgf`.
name = self._find_new_name(name)
name = name.replace('.', '_')
if self.is_name_used(name):
name = self._find_new_name(name)
else:
# We do not check against constants, because the constants are linked to
# the data descriptors.
if name in self.arrays:
raise FileExistsError(f'Data descriptor "{name}" already exists in SDFG')
if name in self.symbols:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a symbol.')
if name in self._subarrays:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a subarray.')
if name in self._rdistrarrays:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a RedistrArray.')
if name in self._pgrids:
raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a ProcessGrid.')

def _add_symbols(sdfg: SDFG, desc: dt.Data):
if isinstance(desc, dt.Structure):
for v in desc.members.values():
if isinstance(v, dt.Data):
_add_symbols(v)
_add_symbols(sdfg, v)
for sym in desc.free_symbols:
if sym.name not in self.symbols:
self.add_symbol(sym.name, sym.dtype)
if sym.name not in sdfg.symbols:
sdfg.add_symbol(sym.name, sym.dtype)

# Add free symbols to the SDFG global symbol storage
_add_symbols(datadesc)
# Add the data descriptor to the SDFG and all symbols that are not yet known.
self._arrays[name] = datadesc
_add_symbols(self, datadesc)

return name

Expand Down Expand Up @@ -2044,9 +2095,10 @@ def add_subarray(self,
newshape.append(dace.symbolic.pystr_to_symbolic(s))
subshape = newshape

# No need to ensure unique test.
subarray_name = self._find_new_name('__subarray')
self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence)

self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence)
self.append_init_code(self._subarrays[subarray_name].init_code())
self.append_exit_code(self._subarrays[subarray_name].exit_code())

Expand All @@ -2060,12 +2112,13 @@ def add_rdistrarray(self, array_a: str, array_b: str):
:param array_b: Output sub-array descriptor.
:return: Name of the new redistribution descriptor.
"""
# No need to ensure unique test.
name = self._find_new_name('__rdistrarray')

rdistrarray_name = self._find_new_name('__rdistrarray')
self._rdistrarrays[rdistrarray_name] = RedistrArray(rdistrarray_name, array_a, array_b)
self.append_init_code(self._rdistrarrays[rdistrarray_name].init_code(self))
self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self))
return rdistrarray_name
self._rdistrarrays[name] = RedistrArray(name, array_a, array_b)
self.append_init_code(self._rdistrarrays[name].init_code(self))
self.append_exit_code(self._rdistrarrays[name].exit_code(self))
return name

def add_loop(
self,
Expand Down
Loading

0 comments on commit 7df09c7

Please sign in to comment.