From 7c70423507dc37e877076d39d4fd80f16e312006 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 6 Nov 2024 09:46:38 +0100 Subject: [PATCH] Quality of life: Improved error messages (#1731) Small quality of life PR that improves two error messages. Especially if the state was called `state` (as per default), it was hard to parse at first sight ![image](https://github.com/user-attachments/assets/500cf100-57b8-4140-b4f7-d6cb0091829b) For some reason, `yapf` was also changing the formatting of other code parts. I thus made two commits 1. Format the two files that I was about to edit. 2. Improve the error messages in the freshly formatted files. This allows to separate real changes from purely stylistic ones. Would it be an idea to enforce linting with a GitHub Actions workflow? I'm happy to write an issue about this (and contribute early next year). --------- Co-authored-by: Roman Cattaneo <> --- dace/codegen/targets/cpp.py | 20 +++++++-------- dace/sdfg/validation.py | 49 +++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 86942874d1..ed52bd093a 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -983,9 +983,10 @@ def unparse_tasklet(sdfg, cfg, state_id, dfg, node, function_stream, callsite_st # To prevent variables-redefinition, build dictionary with all the previously defined symbols defined_symbols = state_dfg.symbols_defined_at(node) - defined_symbols.update( - {k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v)) - for k, v in sdfg.constants.items()}) + defined_symbols.update({ + k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v)) + for k, v in sdfg.constants.items() + }) for connector, (memlet, _, _, conntype) in memlets.items(): if connector is not None: @@ -1152,7 +1153,7 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT return 
sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(elts, strides)) if len(strides) != 1: - raise SyntaxError('Missing dimensions in expression (expected %d, got one)' % len(strides)) + raise SyntaxError('Missing dimensions in expression (expected one, got %d)' % len(strides)) try: return symbolic.pystr_to_symbolic(unparse(visited_slice)) * strides[0] @@ -1289,8 +1290,7 @@ def visit_Name(self, node: ast.Name): if memlet.data in self.sdfg.arrays and self.sdfg.arrays[memlet.data].dtype == dtype: return self.generic_visit(node) return ast.parse(f"{name}[0]").body[0].value - elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray)) - and memlet.dynamic): + elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray)) and memlet.dynamic): return ast.parse(f"{name}.pop()").body[0].value else: return self.generic_visit(node) @@ -1324,8 +1324,8 @@ def visit_BinOp(self, node: ast.BinOp): evaluated_constant = symbolic.evaluate(unparsed, self.constants) evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True) value = ast.parse(evaluated).body[0].value - if isinstance(evaluated_node, numbers.Number) and evaluated_node != ( - value.value if sys.version_info >= (3, 8) else value.n): + if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info + >= (3, 8) else value.n): raise TypeError node.right = ast.parse(evaluated).body[0].value except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError): @@ -1378,8 +1378,8 @@ def visit_Call(self, node): # TODO: This should be in the CUDA code generator. 
Add appropriate conditions to node dispatch predicate -def presynchronize_streams(sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, - node: nodes.Node, callsite_stream: CodeIOStream): +def presynchronize_streams(sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Node, + callsite_stream: CodeIOStream): state_dfg: SDFGState = cfg.nodes()[state_id] if hasattr(node, "_cuda_stream") or is_devicelevel_gpu(sdfg, state_dfg, node): return diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 0b1d946798..1f5c263206 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -119,9 +119,10 @@ def validate_control_flow_region(sdfg: 'SDFG', also_assigned = (syms & edge.data.assignments.keys()) - {aname} if also_assigned: eid = region.edge_id(edge) - raise InvalidSDFGInterstateEdgeError(f'Race condition: inter-state assignment {aname} = {aval} uses ' - f'variables {also_assigned}, which are also modified in the same ' - 'edge.', sdfg, eid) + raise InvalidSDFGInterstateEdgeError( + f'Race condition: inter-state assignment {aname} = {aval} uses ' + f'variables {also_assigned}, which are also modified in the same ' + 'edge.', sdfg, eid) # Add edge symbols into defined symbols symbols.update(issyms) @@ -228,9 +229,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context # Check the names of data descriptors and co. seen_names: Set[str] = set() - for obj_names in [ - sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys() - ]: + for obj_names in [sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys()]: if not seen_names.isdisjoint(obj_names): raise InvalidSDFGError( f'Found duplicated names: "{seen_names.intersection(obj_names)}". 
Please ensure ' @@ -242,15 +241,13 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context if const_name in sdfg.arrays: if const_type != sdfg.arrays[const_name].dtype: # This should actually be an error, but there is a lots of code that depends on it. - warnings.warn( - f'Mismatch between constant and data descriptor of "{const_name}", ' - f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".') + warnings.warn(f'Mismatch between constant and data descriptor of "{const_name}", ' + f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".') elif const_name in sdfg.symbols: if const_type != sdfg.symbols[const_name]: # This should actually be an error, but there is a lots of code that depends on it. - warnings.warn( - f'Mismatch between constant and symobl type of "{const_name}", ' - f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".') + warnings.warn(f'Mismatch between constant and symobl type of "{const_name}", ' + f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".') else: warnings.warn(f'Found constant "{const_name}" that does not refer to an array or a symbol.') @@ -388,8 +385,7 @@ def validate_state(state: 'dace.sdfg.SDFGState', from dace.sdfg import SDFG from dace.sdfg import nodes as nd from dace.sdfg import utils as sdutil - from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu, - scope_contains_scope) + from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu, scope_contains_scope) sdfg = sdfg or state.parent state_id = state_id if state_id is not None else state.parent_graph.node_id(state) @@ -520,7 +516,7 @@ def validate_state(state: 'dace.sdfg.SDFGState', # Streams do not need to be initialized and not isinstance(arr, dt.Stream)): if node.setzero == False: - warnings.warn('WARNING: Use of uninitialized transient "%s" in state %s' % + warnings.warn('WARNING: Use of uninitialized transient "%s" in state "%s"' % (node.data, 
state.label)) # Register initialized transients @@ -854,19 +850,24 @@ def validate_state(state: 'dace.sdfg.SDFGState', read_accesses = defaultdict(list) for node in state.data_nodes(): node_labels.append(node.label) - write_accesses[node.label].extend( - [{'subset': e.data.dst_subset, 'node': node, 'wcr': e.data.wcr} for e in state.in_edges(node)]) - read_accesses[node.label].extend( - [{'subset': e.data.src_subset, 'node': node} for e in state.out_edges(node)]) + write_accesses[node.label].extend([{ + 'subset': e.data.dst_subset, + 'node': node, + 'wcr': e.data.wcr + } for e in state.in_edges(node)]) + read_accesses[node.label].extend([{ + 'subset': e.data.src_subset, + 'node': node + } for e in state.out_edges(node)]) for node_label in node_labels: writes = write_accesses[node_label] reads = read_accesses[node_label] # Check write-write data races. for i in range(len(writes)): - for j in range(i+1, len(writes)): - same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] or - not nx.has_path(state.nx, writes[i]['node'], writes[j]['node'])) + for j in range(i + 1, len(writes)): + same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] + or not nx.has_path(state.nx, writes[i]['node'], writes[j]['node'])) no_wcr = writes[i]['wcr'] is None and writes[j]['wcr'] is None if same_or_unreachable_nodes and no_wcr: subsets_intersect = subsets.intersects(writes[i]['subset'], writes[j]['subset']) @@ -875,8 +876,8 @@ def validate_state(state: 'dace.sdfg.SDFGState', # Check read-write data races. for write in writes: for read in reads: - if (not nx.has_path(state.nx, read['node'], write['node']) and - subsets.intersects(write['subset'], read['subset'])): + if (not nx.has_path(state.nx, read['node'], write['node']) + and subsets.intersects(write['subset'], read['subset'])): warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"') ########################################