This commit fixes an error that was reported by Edoardo (@edopao).
The bug occurred because the `DistributedBufferRelocator` transformation did not check whether its insertion would create a read-write conflict.
This commit adds such a check; it is, however, not very sophisticated and needs some improvements.
Moreover, the example (`model/atmosphere/dycore/tests/dycore_stencil_tests/test_compute_exner_from_rhotheta.py`) where the bug surfaced still holds more challenges.
The main purpose of this commit is to unblock further development in ICON4Py.

Link to ICON4Py PR: C2SM/icon4py#638
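
For context, a minimal sketch of the kind of SDFG in which the unchecked relocation is dangerous (hypothetical names, assuming DaCe's standard SDFG-building API; this is not the ICON4Py stencil itself): the global buffer is read in the very state where the transient is defined, so moving the write back into that state can create a read-write conflict.

```python
import dace

# Hypothetical reproduction of the problematic pattern (illustration only).
sdfg = dace.SDFG("buffer_relocation_conflict")
sdfg.add_array("glob", shape=(10,), dtype=dace.float64)                 # global (non-transient)
sdfg.add_array("tmp", shape=(10,), dtype=dace.float64, transient=True)  # distributed buffer

# State 1: `tmp` is defined while `glob` is read in the same state.
st_def = sdfg.add_state("define_tmp")
st_def.add_mapped_tasklet(
    "compute",
    map_ranges={"i": "0:10"},
    inputs={"__in": dace.Memlet("glob[i]")},
    code="__out = __in + 1.0",
    outputs={"__out": dace.Memlet("tmp[i]")},
    external_edges=True,
)

# State 2: the write back `tmp -> glob` that the pass would relocate.
st_wb = sdfg.add_state_after(st_def, "write_back")
st_wb.add_nedge(
    st_wb.add_access("tmp"),
    st_wb.add_access("glob"),
    dace.Memlet(data="tmp", subset="0:10", other_subset="0:10"),
)
```

Relocating the write back into `define_tmp` would place a read of `glob` and a write to `glob` in the same state without any ordering between them; this is exactly the situation the new check rejects.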
philip-paul-mueller committed Jan 15, 2025
1 parent 17bae8e commit d7e73f5
Showing 2 changed files with 316 additions and 50 deletions.
@@ -374,7 +374,7 @@ def apply(
raise


AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode]
AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState]
"""Describes an access node and the state in which it is located.
"""

@@ -387,29 +387,38 @@ class DistributedBufferRelocator(dace_transformation.Pass):
in each branch and then in the join state written back. Thus there is some
additional storage needed.
The transformation will look for the following situation:
- A transient data container, called `src_cont`, is written into another
container, called `dst_cont`, which is not transient.
- The access node of `src_cont` has an in degree of zero and an out degree of one.
- The access node of `dst_cont` has an in degree of of one and an
- A transient data container, called `temp_storage`, is written into another
container, called `dest_storage`, which is not transient.
- The access node of `temp_storage` has an in degree of zero and an out degree of one.
- The access node of `dest_storage` has an in degree of one and an
out degree of zero (this might be lifted).
- `src_cont` is not used afterwards.
- `dst_cont` is only used to implement the buffering.
- `temp_storage` is not used afterwards.
- `dest_storage` is only used to implement the buffering.
The function will relocate the writing of `dst_cont` to where `src_cont` is
The function will relocate the writing of `dest_storage` to where `temp_storage` is
written, which might be multiple locations.
It will also remove the writing back.
It is advised that `simplify` is run again after this transformation.
The relocation will not take place if it might create a data race. A necessary
but not sufficient condition for a data race is that `dest_storage` is present
in the state where `temp_storage` is defined. In addition, at least one of the
following conditions has to be met:
- The `dest_storage` access node that exists in the state is not connected to the
`temp_storage` access node.
- There is a `dest_storage` access node, but it has an output degree larger
than one.
Note:
Essentially this transformation removes the double buffering of `dst_cont`.
Because we ensure that that `dst_cont` is non transient this is okay, as our
Essentially this transformation removes the double buffering of `dest_storage`.
Because we ensure that `dest_storage` is non-transient this is okay, as our
rule guarantees this.
Todo:
- Allow that `dst_cont` can also be transient.
- Allow that `dst_cont` does not need to be a sink node, this is most
- Allow that `dest_storage` can also be transient.
- Allow that `dest_storage` does not need to be a sink node; this is most
likely relevant if it is transient.
- Check if `dst_cont` is used between where we want to place it and
- Check if `dest_storage` is used between where we want to place it and
where it is currently used.
"""

@@ -489,10 +498,10 @@ def _find_candidates(
where the temporary is defined.
"""
# All nodes that are used as distributed buffers.
candidate_src_cont: list[AccessLocation] = []
candidate_temp_storage: list[AccessLocation] = []

# Which `src_cont` access node is written back to which global memory.
src_cont_to_global: dict[dace_nodes.AccessNode, str] = {}
# Which `temp_storage` access node is written back to which global memory.
temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {}

for state in sdfg.states():
# These are the possible targets we want to write into.
@@ -508,26 +517,26 @@
if len(candidate_dst_nodes) == 0:
continue

for src_cont in state.source_nodes():
if not isinstance(src_cont, dace_nodes.AccessNode):
for temp_storage in state.source_nodes():
if not isinstance(temp_storage, dace_nodes.AccessNode):
continue
if not src_cont.desc(sdfg).transient:
if not temp_storage.desc(sdfg).transient:
continue
if state.out_degree(src_cont) != 1:
if state.out_degree(temp_storage) != 1:
continue
dst_candidate: dace_nodes.AccessNode = next(
iter(edge.dst for edge in state.out_edges(src_cont))
iter(edge.dst for edge in state.out_edges(temp_storage))
)
if dst_candidate not in candidate_dst_nodes:
continue
candidate_src_cont.append((src_cont, state))
src_cont_to_global[src_cont] = dst_candidate.data
candidate_temp_storage.append((temp_storage, state))
temp_storage_to_global[temp_storage] = dst_candidate.data

if len(candidate_src_cont) == 0:
if len(candidate_temp_storage) == 0:
return []

# Now we have to find the places where the temporary sources are defined.
# I.e. This is also the location where the original value is defined.
# I.e. This is also the location where the temporary source was initialized.
result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = []

def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
@@ -537,39 +546,39 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
if dst_state in reachable[src_state] and dst_state is not src_state
}

for src_cont in candidate_src_cont:
for temp_storage in candidate_temp_storage:
def_locations: list[AccessLocation] = []
for upstream_state in find_upstream_states(src_cont[1]):
if src_cont[0].data in access_sets[upstream_state][1]:
for upstream_state in find_upstream_states(temp_storage[1]):
if temp_storage[0].data in access_sets[upstream_state][1]:
def_locations.extend(
(data_node, upstream_state)
for data_node in upstream_state.data_nodes()
if data_node.data == src_cont[0].data
if data_node.data == temp_storage[0].data
)
if len(def_locations) != 0:
result_candidates.append((src_cont, def_locations))
result_candidates.append((temp_storage, def_locations))

# This transformation removes `src_cont` by writing its content directly
# to `dst_cont`, at the point where it is defined.
# This transformation removes `temp_storage` by writing its content directly
# to `dest_storage`, at the point where it is defined.
# For this transformation to be valid the following conditions have to be met:
# - Between the definition of `src_cont` and the write back to `dst_cont`,
# `dst_cont` can not be accessed.
# - Between the definitions of `src_cont` and the point where it is written
# back, `src_cont` can only be accessed in the range that is written back.
# - After the write back point, `src_cont` shall not be accessed. This
# - Between the definition of `temp_storage` and the write back to `dest_storage`,
# `dest_storage` cannot be accessed.
# - Between the definitions of `temp_storage` and the point where it is written
# back, `temp_storage` can only be accessed in the range that is written back.
# - After the write back point, `temp_storage` shall not be accessed. This
# restriction could be lifted.
#
# To keep the implementation simple, we use the conditions:
# - `src_cont` is only accessed were it is defined and at the write back
# - `temp_storage` is only accessed where it is defined and at the write back
# point.
# - Between the definitions of `src_cont` and the write back point,
# `dst_cont` is not used.
# - Between the definitions of `temp_storage` and the write back point,
# `dest_storage` is not used.

result: list[tuple[AccessLocation, list[AccessLocation]]] = []

for wb_location, def_locations in result_candidates:
for def_node, def_state in def_locations:
# Test if `src_cont` is only accessed where it is defined and
# Test if `temp_storage` is only accessed where it is defined and
# where it is written back.
if gtx_transformations.util.is_accessed_downstream(
start_state=def_state,
@@ -579,30 +588,136 @@ def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]:
):
break
# check if the global data is not used between the definition of
# `dst_cont` and where its written back. We allow one exception,
# `dest_storage` and where it is written back. We allow one exception:
# if the global data is used only for reading in the state where the
# distributed temporary is defined, then it is ignored. This is
# allowed because of rule 3 of ADR0018.
glob_nodes_in_def_state = {
dnode
for dnode in def_state.data_nodes()
if dnode.data == src_cont_to_global[wb_location[0]]
if dnode.data == temp_storage_to_global[wb_location[0]]
}
if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state):
break
if gtx_transformations.util.is_accessed_downstream(
start_state=def_state,
sdfg=sdfg,
data_to_look=src_cont_to_global[wb_location[0]],
data_to_look=temp_storage_to_global[wb_location[0]],
nodes_to_ignore=glob_nodes_in_def_state,
states_to_ignore={wb_location[1]},
):
break
if self._check_read_write_dependency(sdfg, wb_location, def_locations):
break
else:
result.append((wb_location, def_locations))

return result
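
`_find_candidates` consumes precomputed analysis results: `reachable` (used by `find_upstream_states`) and `access_sets` (whose second tuple entry is the per-state write set). Below is a sketch of how such results are typically produced with DaCe's analysis passes; how they actually reach this method (e.g. via the pass's dependency mechanism) is not shown in this hunk.

```python
import dace
from dace.transformation.passes import analysis as dace_analysis


def compute_analysis(sdfg: dace.SDFG):
    """Hypothetical helper (illustration only)."""
    # Maps each state to the set of states reachable from it.
    # Note: `cfg_id` is called `sdfg_id` in older DaCe versions.
    reachable = dace_analysis.StateReachability().apply_pass(sdfg, {})[sdfg.cfg_id]
    # Maps each state to a `(read_set, write_set)` tuple of data names;
    # `access_sets[state][1]` is the write set queried above.
    access_sets = dace_analysis.AccessSets().apply_pass(sdfg, {})[sdfg.cfg_id]
    return reachable, access_sets
```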

def _check_read_write_dependency(
self,
sdfg: dace.SDFG,
write_back_location: AccessLocation,
target_locations: list[AccessLocation],
) -> bool:
"""Tests if read-write conflicts would be created.
This function ensures that the substitution of `write_back_location` into
`target_locations` will not create a read-write conflict.
The rules that are used for this are outlined in the class description.
Args:
sdfg: The SDFG on which we operate.
write_back_location: Where currently the write back occurs.
target_locations: List of the locations where we would like to perform
the write back instead.
"""
for target_location in target_locations:
if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location):
return True
return False

def _check_read_write_dependency_impl(
self,
sdfg: dace.SDFG,
write_back_location: AccessLocation,
target_location: AccessLocation,
) -> bool:
"""Tests if read-write conflict would be created for a single location.
Args:
sdfg: The SDFG on which we operate.
write_back_location: Where currently the write back occurs.
target_location: Location where the new write back should be performed.
Todo:
Refine these checks later.
"""
assert write_back_location[0].data == target_location[0].data

write_back_state: dace.SDFGState = write_back_location[1]
write_back_node = write_back_location[0]
write_back_edge = next(iter(write_back_state.out_edges(write_back_node)))
global_data_name = write_back_edge.dst.data
assert not sdfg.arrays[global_data_name].transient

# This is the state in which we have to look for possible data races.
state_to_inspect: dace.SDFGState = target_location[1]

# This is the access node on which we will perform the write back.
def_location_of_intermediate: dace_nodes.AccessNode = target_location[0]
assert state_to_inspect.out_degree(def_location_of_intermediate) == 0

# These are all access nodes that refer to the global data that we want
# to move into the state `state_to_inspect`. We need them to do the
# second test.
access_to_global_data_in_this_state: set[dace_nodes.AccessNode] = set()

# The first simple test is to look if the global data is used and has more
# than one output edge. Further analysis could show that such cases could
# be handled, but we do not do it.
for dnode in state_to_inspect.data_nodes():
if dnode.data != global_data_name:
continue
if state_to_inspect.out_degree(dnode) >= 2:
return True
access_to_global_data_in_this_state.add(dnode)

# There is no reference to the global data, so no need to do more tests.
if len(access_to_global_data_in_this_state) == 0:
return False

# For the second test we look if `global_data_name` is referred to in
# another dataflow graph in this state that is, however, not connected
# to the graph that contains `def_location_of_intermediate`. We do this
# by exploring the dataflow graph from that node; if we find a node that
# refers to the global data, we remove it from
# `access_to_global_data_in_this_state`. If that set is empty at the end
# then the data is not used in another component.
to_process: list[dace_nodes.Node] = [def_location_of_intermediate]
seen: set[dace_nodes.Node] = set()
while len(to_process) != 0:
node = to_process.pop()
seen.add(node)

if isinstance(node, dace_nodes.AccessNode):
if node.data == global_data_name:
access_to_global_data_in_this_state.discard(node)

# Note that we only explore the ingoing edges, thus we will not necessarily
# explore the whole graph. However, this is fine, because we will see the
# relevant parts. To see this, assume that we also had to check the
# outgoing edges; this would mean that there was some branching point,
# which is a serialization point, so the dataflow would have been invalid
# before.
to_process.extend(
iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen
)

if len(access_to_global_data_in_this_state) == 0:
return False
return True
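
Condensed, the two tests implemented above amount to the following predicate (a restatement for illustration, not the code added by the commit; `would_conflict` and its parameter names are made up here):

```python
import dace
from dace.sdfg import nodes as dace_nodes


def would_conflict(
    state: dace.SDFGState,
    global_data_name: str,
    write_back_target: dace_nodes.AccessNode,
) -> bool:
    """Restatement of `_check_read_write_dependency_impl` for one state."""
    globals_here = {n for n in state.data_nodes() if n.data == global_data_name}
    # Test 1: a global access node feeding more than one consumer is rejected.
    if any(state.out_degree(n) >= 2 for n in globals_here):
        return True
    # Test 2: every access to the global data must lie in the backward
    # closure (over incoming edges) of the write back target.
    seen: set[dace_nodes.Node] = set()
    stack: list[dace_nodes.Node] = [write_back_target]
    while stack:
        node = stack.pop()
        seen.add(node)
        stack.extend(e.src for e in state.in_edges(node) if e.src not in seen)
    return not globals_here <= seen
```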


@dace_properties.make_properties
class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
