diff --git a/dace/data.py b/dace/data.py index 04bdc93357..a07fe42083 100644 --- a/dace/data.py +++ b/dace/data.py @@ -136,27 +136,6 @@ def create_datadescriptor(obj, no_custom_desc=False): 'adaptor method to the type hint or object itself.') -def find_new_name(name: str, existing_names: Sequence[str]) -> str: - """ - Returns a name that matches the given ``name`` as a prefix, but does not - already exist in the given existing name set. The behavior is typically - to append an underscore followed by a unique (increasing) number. If the - name does not already exist in the set, it is returned as-is. - - :param name: The given name to find. - :param existing_names: The set of existing names. - :return: A new name that is not in existing_names. - """ - if name not in existing_names: - return name - cur_offset = 0 - new_name = name + '_' + str(cur_offset) - while new_name in existing_names: - cur_offset += 1 - new_name = name + '_' + str(cur_offset) - return new_name - - def _prod(sequence): return functools.reduce(lambda a, b: a * b, sequence, 1) diff --git a/dace/frontend/common/distr.py b/dace/frontend/common/distr.py index d6f22da358..88a6b0c54a 100644 --- a/dace/frontend/common/distr.py +++ b/dace/frontend/common/distr.py @@ -42,9 +42,9 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape state.add_node(tasklet) # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. 
- _, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True) - wnode = state.add_write(pgrid_name) - state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal)) + scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True) + wnode = state.add_write(scal_name) + state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal)) return pgrid_name @@ -97,9 +97,9 @@ def _cart_sub(pv: 'ProgramVisitor', state.add_node(tasklet) # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. - _, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True) - wnode = state.add_write(pgrid_name) - state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(pgrid_name, scal)) + scal_name, scal = sdfg.add_scalar(pgrid_name, dace.int32, transient=True, find_new_name=True) + wnode = state.add_write(scal_name) + state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal)) return pgrid_name @@ -196,7 +196,7 @@ def _intracomm_bcast(pv: 'ProgramVisitor', if comm_obj == MPI.COMM_WORLD: return _bcast(pv, sdfg, state, buffer, root) # NOTE: Highly experimental - sdfg.add_scalar(comm_name, dace.int32) + scal_name, _ = sdfg.add_scalar(comm_name, dace.int32, find_new_name=True) return _bcast(pv, sdfg, state, buffer, root, fcomm=comm_name) @@ -941,9 +941,9 @@ def _subarray(pv: ProgramVisitor, state.add_node(tasklet) # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. 
- _, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True) - wnode = state.add_write(subarray_name) - state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(subarray_name, scal)) + scal_name, scal = sdfg.add_scalar(subarray_name, dace.int32, transient=True, find_new_name=True) + wnode = state.add_write(scal_name) + state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal)) return subarray_name @@ -1078,9 +1078,9 @@ def _redistribute(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, in_buffer: s f'int* {rdistrarray_name}_self_size;' ]) state.add_node(tasklet) - _, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True) - wnode = state.add_write(rdistrarray_name) - state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(rdistrarray_name, scal)) + scal_name, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True, find_new_name=True) + wnode = state.add_write(scal_name) + state.add_edge(tasklet, '__out', wnode, None, Memlet.from_array(scal_name, scal)) libnode = Redistribute('_Redistribute_', rdistrarray_name) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1b11fb00c6..60469919f5 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3302,6 +3302,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): tokens = name.split('.') name = tokens[0] true_name = None + true_array = None if name in defined_vars: true_name = defined_vars[name] if len(tokens) > 1: @@ -3356,7 +3357,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): new_data, rng = None, None dtype_keys = tuple(dtypes.dtype_to_typeclass().keys()) if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or - (isinstance(result, str) and result in self.sdfg.arrays)): + (isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, 
self.sdfg._subarrays, self.sdfg._rdistrarrays]))): raise DaceSyntaxError( self, node, "In assignments, the rhs may only be " "data, numerical/boolean constants " @@ -3380,6 +3381,14 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): _, new_data = self.sdfg.add_scalar(true_name, ttype, transient=True) self.variables[name] = true_name defined_vars[name] = true_name + if any(result in x for x in [self.sdfg._pgrids, self.sdfg._rdistrarrays, self.sdfg._subarrays]): + # NOTE: In previous versions some `pgrid` and subgrid related replacement function, + # see `dace/frontend/common/distr.py`, created dummy variables with the same name + # as the entities, such as process grids, they created. Thus the frontend was + # finding them. Since this is now disallowed, we have to explicitly handle this case. + self.variables[name] = result + defined_vars[name] = result + continue elif isinstance(result, str) and result in self.sdfg.arrays: result_data = self.sdfg.arrays[result] if (name.startswith('__return') and isinstance(result_data, data.Scalar)): diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 409d30c57a..4ae91d5ea0 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -618,6 +618,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) keys_to_use &= internally_used_symbols + # Translate the internal symbols back to their external counterparts. free_syms |= set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items() if k in keys_to_use)) @@ -662,6 +663,10 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context connectors = self.in_connectors.keys() | self.out_connectors.keys() for conn in connectors: + if conn in self.sdfg.symbols: + raise ValueError( + f'Connector "{conn}" was given, but it refers to a symbol, which is not allowed. 
' + 'To pass symbols use "symbol_mapping".') if conn not in self.sdfg.arrays: raise NameError( f'Connector "{conn}" was given but is not a registered data descriptor in the nested SDFG. ' @@ -795,10 +800,8 @@ def new_symbols(self, sdfg, state, symbols) -> Dict[str, dtypes.typeclass]: for p, rng in zip(self._map.params, self._map.range): result[p] = dtypes.result_type_of(infer_expr_type(rng[0], symbols), infer_expr_type(rng[1], symbols)) - # Add dynamic inputs + # Handle the dynamic map ranges. dyn_inputs = set(c for c in self.in_connectors if not c.startswith('IN_')) - - # Try to get connector type from connector for e in state.in_edges(self): if e.dst_conn in dyn_inputs: result[e.dst_conn] = (self.in_connectors[e.dst_conn] or sdfg.arrays[e.data.data].dtype) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 84d7189ebd..5e5df1b0a2 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -746,17 +746,32 @@ def replace_dict(self, super().replace_dict(repldict, symrepl, replace_in_graph, replace_keys) - def add_symbol(self, name, stype): + def add_symbol(self, name, stype, find_new_name: bool = False): """ Adds a symbol to the SDFG. :param name: Symbol name. :param stype: Symbol type. + :param find_new_name: Find a new name. """ - if name in self.symbols: - raise FileExistsError('Symbol "%s" already exists in SDFG' % name) + if find_new_name: + name = self._find_new_name(name) + else: + # We do not check for data constant, because there is a link between the constants and + # the data descriptors. 
+ if name in self.symbols: + raise FileExistsError(f'Symbol "{name}" already exists in SDFG') + if name in self.arrays: + raise FileExistsError(f'Can not create symbol "{name}", the name is used by a data descriptor.') + if name in self._subarrays: + raise FileExistsError(f'Can not create symbol "{name}", the name is used by a subarray.') + if name in self._rdistrarrays: + raise FileExistsError(f'Can not create symbol "{name}", the name is used by a RedistrArray.') + if name in self._pgrids: + raise FileExistsError(f'Can not create symbol "{name}", the name is used by a ProcessGrid.') if not isinstance(stype, dtypes.typeclass): stype = dtypes.dtype_to_typeclass(stype) self.symbols[name] = stype + return name def remove_symbol(self, name): """ Removes a symbol from the SDFG. @@ -1159,14 +1174,23 @@ def cast(dtype: dt.Data, value: Any): return result def add_constant(self, name: str, value: Any, dtype: dt.Data = None): - """ Adds/updates a new compile-time constant to this SDFG. A constant - may either be a scalar or a numpy ndarray thereof. + """ + Adds/updates a new compile-time constant to this SDFG. - :param name: The name of the constant. - :param value: The constant value. - :param dtype: Optional data type of the symbol, or None to deduce - automatically. + A constant may either be a scalar or a numpy ndarray thereof. It is not an + error if there is already a symbol or an array with the same name inside + the SDFG. However, the data descriptors must refer to the same type. + + :param name: The name of the constant. + :param value: The constant value. + :param dtype: Optional data type of the symbol, or None to deduce automatically. 
""" + if name in self._subarrays: + raise FileExistsError(f'Can not create constant "{name}", the name is used by a subarray.') + if name in self._rdistrarrays: + raise FileExistsError(f'Can not create constant "{name}", the name is used by a RedistrArray.') + if name in self._pgrids: + raise FileExistsError(f'Can not create constant "{name}", the name is used by a ProcessGrid.') self.constants_prop[name] = (dtype or dt.create_datadescriptor(value), value) @property @@ -1598,36 +1622,44 @@ def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys() - | self._rdistrarrays.keys()) + | self._rdistrarrays.keys() | self.symbols.keys()) return dt.find_new_name(name, names) + def is_name_used(self, name: str) -> bool: + """ Checks if `name` is already used inside the SDFG.""" + if name in self._arrays: + return True + if name in self.symbols: + return True + if name in self.constants_prop: + return True + if name in self._pgrids: + return True + if name in self._subarrays: + return True + if name in self._rdistrarrays: + return True + return False + + def is_name_free(self, name: str) -> bool: + """ Test if `name` is free, i.e. is not used by anything else.""" + return not self.is_name_used(name) + def find_new_constant(self, name: str): """ - Tries to find a new constant name by adding an underscore and a number. + Tries to find a new name for a constant. """ - constants = self.constants - if name not in constants: + if self.is_name_free(name): return name - - index = 0 - while (name + ('_%d' % index)) in constants: - index += 1 - - return name + ('_%d' % index) + return self._find_new_name(name) def find_new_symbol(self, name: str): """ Tries to find a new symbol name by adding an underscore and a number. 
""" - symbols = self.symbols - if name not in symbols: + if self.is_name_free(name): return name - - index = 0 - while (name + ('_%d' % index)) in symbols: - index += 1 - - return name + ('_%d' % index) + return self._find_new_name(name) def add_array(self, name: str, @@ -1856,13 +1888,14 @@ def add_transient(self, def temp_data_name(self): """ Returns a temporary data descriptor name that can be used in this SDFG. """ - name = '__tmp%d' % self._temp_transients - while name in self._arrays: + + # NOTE: Consider switching to `_find_new_name` + # The frontend seems to access this variable directly. + while self.is_name_used(name): self._temp_transients += 1 name = '__tmp%d' % self._temp_transients self._temp_transients += 1 - return name def add_temp_transient(self, @@ -1917,29 +1950,47 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str """ if not isinstance(name, str): raise TypeError("Data descriptor name must be a string. Got %s" % type(name).__name__) - # If exists, fail - while name in self._arrays: - if find_new_name: - name = self._find_new_name(name) - else: - raise NameError(f'Array or Stream with name "{name}" already exists in SDFG') - # NOTE: Remove illegal characters, such as dots. Such characters may be introduced when creating views to - # members of Structures. - name = name.replace('.', '_') - assert name not in self._arrays - self._arrays[name] = datadesc - def _add_symbols(desc: dt.Data): + if find_new_name: + # These characters might be introduced through the creation of views to members + # of strictures. + # NOTES: If `find_new_name` is `True` and the name (understood as a sequence of + # any characters) is not used, i.e. `assert self.is_name_free(name)`, then it + # is still "cleaned", i.e. dots are replaced with underscores. However, if + # `find_new_name` is `False` then this cleaning is not applied and it is possible + # to create names that are formally invalid. 
The code below reproduces the exact + # same behaviour and is maintained for compatibility. This behaviour is + # triggered by tests/python_frontend/structures/structure_python_test.py::test_rgf`. + name = self._find_new_name(name) + name = name.replace('.', '_') + if self.is_name_used(name): + name = self._find_new_name(name) + else: + # We do not check for constants, because there is a link between the constants and + # the data descriptors. + if name in self.arrays: + raise FileExistsError(f'Data descriptor "{name}" already exists in SDFG') + if name in self.symbols: + raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a symbol.') + if name in self._subarrays: + raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a subarray.') + if name in self._rdistrarrays: + raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a RedistrArray.') + if name in self._pgrids: + raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a ProcessGrid.') + + def _add_symbols(sdfg: SDFG, desc: dt.Data): if isinstance(desc, dt.Structure): for v in desc.members.values(): if isinstance(v, dt.Data): - _add_symbols(v) + _add_symbols(sdfg, v) for sym in desc.free_symbols: - if sym.name not in self.symbols: - self.add_symbol(sym.name, sym.dtype) + if sym.name not in sdfg.symbols: + sdfg.add_symbol(sym.name, sym.dtype) - # Add free symbols to the SDFG global symbol storage - _add_symbols(datadesc) + # Add the data descriptor to the SDFG and all symbols that are not yet known. + self._arrays[name] = datadesc + _add_symbols(self, datadesc) return name @@ -2044,9 +2095,10 @@ def add_subarray(self, newshape.append(dace.symbolic.pystr_to_symbolic(s)) subshape = newshape + # No need to check for uniqueness: `_find_new_name` always returns an unused name. 
subarray_name = self._find_new_name('__subarray') - self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence) + self._subarrays[subarray_name] = SubArray(subarray_name, dtype, shape, subshape, pgrid, correspondence) self.append_init_code(self._subarrays[subarray_name].init_code()) self.append_exit_code(self._subarrays[subarray_name].exit_code()) @@ -2060,12 +2112,13 @@ def add_rdistrarray(self, array_a: str, array_b: str): :param array_b: Output sub-array descriptor. :return: Name of the new redistribution descriptor. """ + # No need to check for uniqueness: `_find_new_name` always returns an unused name. + name = self._find_new_name('__rdistrarray') - rdistrarray_name = self._find_new_name('__rdistrarray') - self._rdistrarrays[rdistrarray_name] = RedistrArray(rdistrarray_name, array_a, array_b) - self.append_init_code(self._rdistrarrays[rdistrarray_name].init_code(self)) - self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self)) - return rdistrarray_name + self._rdistrarrays[name] = RedistrArray(name, array_a, array_b) + self.append_init_code(self._rdistrarrays[name].init_code(self)) + self.append_exit_code(self._rdistrarrays[name].exit_code(self)) + return name def add_loop( self, diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index dd936850f0..2869743dcb 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -207,6 +207,34 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context if len(blocks) != len(set([s.label for s in blocks])): raise InvalidSDFGError('Found multiple blocks with the same name in ' + cfg.name, sdfg, None) + # Check the names of data descriptors and co. + seen_names: Set[str] = set() + for obj_names in [ + sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys() + ]: + if not seen_names.isdisjoint(obj_names): + raise InvalidSDFGError( + f'Found duplicated names: "{seen_names.intersection(obj_names)}". 
Please ensure ' 'that the names of symbols, data descriptors, subarrays and rdistrarrays are unique.', sdfg, None) + seen_names.update(obj_names) + + # Ensure that each constant also appears as an array or a symbol of a matching type. + for const_name, (const_type, _) in sdfg.constants_prop.items(): + if const_name in sdfg.arrays: + if const_type != sdfg.arrays[const_name].dtype: + # This should actually be an error, but there is a lot of code that depends on it. + warnings.warn( + f'Mismatch between constant and data descriptor of "{const_name}", ' + f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".') + elif const_name in sdfg.symbols: + if const_type != sdfg.symbols[const_name]: + # This should actually be an error, but there is a lot of code that depends on it. + warnings.warn( + f'Mismatch between constant and symbol type of "{const_name}", ' + f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".') + else: + warnings.warn(f'Found constant "{const_name}" that does not refer to an array or a symbol.') + # Validate data descriptors for name, desc in sdfg._arrays.items(): if id(desc) in references: