Skip to content

Commit

Permalink
Quality of live: Improved error messages (#1731)
Browse files Browse the repository at this point in the history
Small quality of live PR that improves two error messages. Especially if
the state was called `state` (as per default), it was hard to parse at
first sight


![image](https://github.com/user-attachments/assets/500cf100-57b8-4140-b4f7-d6cb0091829b)


For some reason, `yapf` was also changing the formatting of other code
parts. I thus made two commits

1. Format the two files that I was about to edit.
2. Improve the error messages in the freshly formatted files.

This allows to separate real changes from purely stylistic ones.

Would it be an idea to enforce linting with a GitHub Actions workflow?
I'm happy to write an issue about this (and contribute early next year).

---------

Co-authored-by: Roman Cattaneo <>
  • Loading branch information
romanc authored Nov 6, 2024
1 parent 163366d commit 7c70423
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 34 deletions.
20 changes: 10 additions & 10 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,9 +983,10 @@ def unparse_tasklet(sdfg, cfg, state_id, dfg, node, function_stream, callsite_st
# To prevent variables-redefinition, build dictionary with all the previously defined symbols
defined_symbols = state_dfg.symbols_defined_at(node)

defined_symbols.update(
{k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()})
defined_symbols.update({
k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()
})

for connector, (memlet, _, _, conntype) in memlets.items():
if connector is not None:
Expand Down Expand Up @@ -1152,7 +1153,7 @@ def _subscript_expr(self, slicenode: ast.AST, target: str) -> symbolic.SymbolicT
return sum(symbolic.pystr_to_symbolic(unparse(elt)) * s for elt, s in zip(elts, strides))

if len(strides) != 1:
raise SyntaxError('Missing dimensions in expression (expected %d, got one)' % len(strides))
raise SyntaxError('Missing dimensions in expression (expected one, got %d)' % len(strides))

try:
return symbolic.pystr_to_symbolic(unparse(visited_slice)) * strides[0]
Expand Down Expand Up @@ -1289,8 +1290,7 @@ def visit_Name(self, node: ast.Name):
if memlet.data in self.sdfg.arrays and self.sdfg.arrays[memlet.data].dtype == dtype:
return self.generic_visit(node)
return ast.parse(f"{name}[0]").body[0].value
elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray))
and memlet.dynamic):
elif (self.allow_casts and (defined_type in (DefinedType.Stream, DefinedType.StreamArray)) and memlet.dynamic):
return ast.parse(f"{name}.pop()").body[0].value
else:
return self.generic_visit(node)
Expand Down Expand Up @@ -1324,8 +1324,8 @@ def visit_BinOp(self, node: ast.BinOp):
evaluated_constant = symbolic.evaluate(unparsed, self.constants)
evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True)
value = ast.parse(evaluated).body[0].value
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (
value.value if sys.version_info >= (3, 8) else value.n):
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info
>= (3, 8) else value.n):
raise TypeError
node.right = ast.parse(evaluated).body[0].value
except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError):
Expand Down Expand Up @@ -1378,8 +1378,8 @@ def visit_Call(self, node):


# TODO: This should be in the CUDA code generator. Add appropriate conditions to node dispatch predicate
def presynchronize_streams(sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int,
node: nodes.Node, callsite_stream: CodeIOStream):
def presynchronize_streams(sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Node,
callsite_stream: CodeIOStream):
state_dfg: SDFGState = cfg.nodes()[state_id]
if hasattr(node, "_cuda_stream") or is_devicelevel_gpu(sdfg, state_dfg, node):
return
Expand Down
49 changes: 25 additions & 24 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@ def validate_control_flow_region(sdfg: 'SDFG',
also_assigned = (syms & edge.data.assignments.keys()) - {aname}
if also_assigned:
eid = region.edge_id(edge)
raise InvalidSDFGInterstateEdgeError(f'Race condition: inter-state assignment {aname} = {aval} uses '
f'variables {also_assigned}, which are also modified in the same '
'edge.', sdfg, eid)
raise InvalidSDFGInterstateEdgeError(
f'Race condition: inter-state assignment {aname} = {aval} uses '
f'variables {also_assigned}, which are also modified in the same '
'edge.', sdfg, eid)

# Add edge symbols into defined symbols
symbols.update(issyms)
Expand Down Expand Up @@ -228,9 +229,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context

# Check the names of data descriptors and co.
seen_names: Set[str] = set()
for obj_names in [
sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys()
]:
for obj_names in [sdfg.arrays.keys(), sdfg.symbols.keys(), sdfg._rdistrarrays.keys(), sdfg._subarrays.keys()]:
if not seen_names.isdisjoint(obj_names):
raise InvalidSDFGError(
f'Found duplicated names: "{seen_names.intersection(obj_names)}". Please ensure '
Expand All @@ -242,15 +241,13 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
if const_name in sdfg.arrays:
if const_type != sdfg.arrays[const_name].dtype:
# This should actually be an error, but there is a lots of code that depends on it.
warnings.warn(
f'Mismatch between constant and data descriptor of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".')
warnings.warn(f'Mismatch between constant and data descriptor of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.arrays[const_name]}".')
elif const_name in sdfg.symbols:
if const_type != sdfg.symbols[const_name]:
# This should actually be an error, but there is a lots of code that depends on it.
warnings.warn(
f'Mismatch between constant and symobl type of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".')
warnings.warn(f'Mismatch between constant and symobl type of "{const_name}", '
f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".')
else:
warnings.warn(f'Found constant "{const_name}" that does not refer to an array or a symbol.')

Expand Down Expand Up @@ -388,8 +385,7 @@ def validate_state(state: 'dace.sdfg.SDFGState',
from dace.sdfg import SDFG
from dace.sdfg import nodes as nd
from dace.sdfg import utils as sdutil
from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu,
scope_contains_scope)
from dace.sdfg.scope import (is_devicelevel_fpga, is_devicelevel_gpu, scope_contains_scope)

sdfg = sdfg or state.parent
state_id = state_id if state_id is not None else state.parent_graph.node_id(state)
Expand Down Expand Up @@ -520,7 +516,7 @@ def validate_state(state: 'dace.sdfg.SDFGState',
# Streams do not need to be initialized
and not isinstance(arr, dt.Stream)):
if node.setzero == False:
warnings.warn('WARNING: Use of uninitialized transient "%s" in state %s' %
warnings.warn('WARNING: Use of uninitialized transient "%s" in state "%s"' %
(node.data, state.label))

# Register initialized transients
Expand Down Expand Up @@ -854,19 +850,24 @@ def validate_state(state: 'dace.sdfg.SDFGState',
read_accesses = defaultdict(list)
for node in state.data_nodes():
node_labels.append(node.label)
write_accesses[node.label].extend(
[{'subset': e.data.dst_subset, 'node': node, 'wcr': e.data.wcr} for e in state.in_edges(node)])
read_accesses[node.label].extend(
[{'subset': e.data.src_subset, 'node': node} for e in state.out_edges(node)])
write_accesses[node.label].extend([{
'subset': e.data.dst_subset,
'node': node,
'wcr': e.data.wcr
} for e in state.in_edges(node)])
read_accesses[node.label].extend([{
'subset': e.data.src_subset,
'node': node
} for e in state.out_edges(node)])

for node_label in node_labels:
writes = write_accesses[node_label]
reads = read_accesses[node_label]
# Check write-write data races.
for i in range(len(writes)):
for j in range(i+1, len(writes)):
same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node'] or
not nx.has_path(state.nx, writes[i]['node'], writes[j]['node']))
for j in range(i + 1, len(writes)):
same_or_unreachable_nodes = (writes[i]['node'] == writes[j]['node']
or not nx.has_path(state.nx, writes[i]['node'], writes[j]['node']))
no_wcr = writes[i]['wcr'] is None and writes[j]['wcr'] is None
if same_or_unreachable_nodes and no_wcr:
subsets_intersect = subsets.intersects(writes[i]['subset'], writes[j]['subset'])
Expand All @@ -875,8 +876,8 @@ def validate_state(state: 'dace.sdfg.SDFGState',
# Check read-write data races.
for write in writes:
for read in reads:
if (not nx.has_path(state.nx, read['node'], write['node']) and
subsets.intersects(write['subset'], read['subset'])):
if (not nx.has_path(state.nx, read['node'], write['node'])
and subsets.intersects(write['subset'], read['subset'])):
warnings.warn(f'Memlet range overlap while writing to "{node}" in state "{state.label}"')

########################################
Expand Down

0 comments on commit 7c70423

Please sign in to comment.