Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

style[cartesian]: readability improvements and more type hints #1752

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@


def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
replacement_dictionary = replace_strides(
FlorianDeconinck marked this conversation as resolved.
Show resolved Hide resolved
[array for array in sdfg.arrays.values() if array.transient], layout_map
)
sdfg.replace_dict(repldict)
sdfg.replace_dict(replacement_dictionary)
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
for k, v in repldict.items():
for k, v in replacement_dictionary.items():
if k in node.symbol_mapping:
node.symbol_mapping[k] = v
for k in repldict.keys():
for k in replacement_dictionary.keys():
if k in sdfg.symbols:
sdfg.remove_symbol(k)

Expand Down Expand Up @@ -143,7 +143,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None:
node.device = dace.DeviceType.GPU


def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
args_data = make_args_data_from_gtir(gtir_pipeline)

# stencils without effect
Expand All @@ -164,7 +164,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map)
return sdfg


def _post_expand_trafos(sdfg: dace.SDFG):
def _post_expand_transformations(sdfg: dace.SDFG):
# DaCe "standard" clean-up transformations
sdfg.simplify(validate=False)

Expand Down Expand Up @@ -355,7 +355,7 @@ def _unexpanded_sdfg(self):
sdfg = OirSDFGBuilder().visit(oir_node)

_to_device(sdfg, self.builder.backend.storage_info["device"])
_pre_expand_trafos(
_pre_expand_transformations(
self.builder.gtir_pipeline,
sdfg,
self.builder.backend.storage_info["layout_map"],
Expand All @@ -371,7 +371,7 @@ def unexpanded_sdfg(self):
def _expanded_sdfg(self):
sdfg = self._unexpanded_sdfg()
sdfg.expand_library_nodes()
_post_expand_trafos(sdfg)
_post_expand_transformations(sdfg)
return sdfg

def expanded_sdfg(self):
Expand Down
50 changes: 25 additions & 25 deletions src/gt4py/cartesian/gtc/dace/expansion_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis):
for idx, item in enumerate(expansion_order):
if isinstance(item, Iteration) and item.axis == axis:
return idx
elif isinstance(item, Map):

if isinstance(item, Map):
for it in item.iterations:
if it.kind == "contiguous" and it.axis == axis:
return idx
Expand Down Expand Up @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo):
return eo


def _order_as_spec(computation_node, expansion_order):
def _order_as_spec(
computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]]
) -> List[ExpansionItem]:
expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order)
expansion_specification = []
for item in expansion_order:
Expand Down Expand Up @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order):
return expansion_specification


def _populate_strides(node, expansion_specification):
def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]):
"""Fill in `stride` attribute of `Iteration` and `Loop` dataclasses.

For loops, stride is set to either -1 or 1, based on iteration order.
Expand All @@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification):
for it in iterations:
if isinstance(it, Loop):
if it.stride is None:
if node.oir_node.loop_order == common.LoopOrder.BACKWARD:
it.stride = -1
else:
it.stride = 1
it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1
else:
if it.stride is None:
if it.kind == "tiling":
Expand All @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification):
it.stride = 1


def _populate_storages(self, expansion_specification):
def _populate_storages(expansion_specification: List[ExpansionItem]):
assert all(isinstance(es, ExpansionItem) for es in expansion_specification)
innermost_axes = set(dcir.Axis.dims_3d())
tiled_axes = set()
Expand All @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification):
tiled_axes.remove(it.axis)


def _populate_cpu_schedules(self, expansion_specification):
def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]):
is_outermost = True
for es in expansion_specification:
if isinstance(es, Map):
Expand All @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification):
es.schedule = dace.ScheduleType.Default


def _populate_gpu_schedules(self, expansion_specification):
def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]):
# On GPU if any dimension is tiled and has a contiguous map in the same axis further in
# pick those two maps as Device/ThreadBlock maps. If not, Make just device map with
# default blocksizes
Expand Down Expand Up @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification):
es.schedule = dace.ScheduleType.Default


def _populate_schedules(self, expansion_specification):
def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]):
assert all(isinstance(es, ExpansionItem) for es in expansion_specification)
assert hasattr(self, "_device")
if self.device == dace.DeviceType.GPU:
_populate_gpu_schedules(self, expansion_specification)
assert hasattr(node, "_device")
if node.device == dace.DeviceType.GPU:
_populate_gpu_schedules(expansion_specification)
else:
_populate_cpu_schedules(self, expansion_specification)
_populate_cpu_schedules(expansion_specification)


def _collapse_maps_gpu(self, expansion_specification):
def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]:
def _union_map_items(last_item, next_item):
if last_item.schedule == next_item.schedule:
return (
Expand Down Expand Up @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item):
),
)

res_items = []
res_items: List[ExpansionItem] = []
for item in expansion_specification:
if isinstance(item, Map):
if not res_items or not isinstance(res_items[-1], Map):
Expand All @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item):
return res_items


def _collapse_maps_cpu(self, expansion_specification):
res_items = []
def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]:
res_items: List[ExpansionItem] = []
for item in expansion_specification:
if isinstance(item, Map):
if (
Expand Down Expand Up @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, expansion_specification):
return res_items


def _collapse_maps(self, expansion_specification):
assert hasattr(self, "_device")
if self.device == dace.DeviceType.GPU:
res_items = _collapse_maps_gpu(self, expansion_specification)
def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]):
assert hasattr(node, "_device")
if node.device == dace.DeviceType.GPU:
res_items = _collapse_maps_gpu(expansion_specification)
else:
res_items = _collapse_maps_cpu(self, expansion_specification)
res_items = _collapse_maps_cpu(expansion_specification)
expansion_specification.clear()
expansion_specification.extend(res_items)

Expand All @@ -387,7 +387,7 @@ def make_expansion_order(
_populate_strides(node, expansion_specification)
_populate_schedules(node, expansion_specification)
_collapse_maps(node, expansion_specification)
_populate_storages(node, expansion_specification)
_populate_storages(expansion_specification)
return expansion_specification


Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def visit_VerticalLoop(
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_out_connector("__out_" + field)
Expand All @@ -131,8 +132,6 @@ def visit_VerticalLoop(
library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset)
)

return

def visit_Stencil(self, node: oir.Stencil, **kwargs):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
for param in node.params:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def array_dimensions(array: dace.data.Array):
return dims


def replace_strides(arrays, get_layout_map):
def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]:
symbol_mapping = {}
for array in arrays:
dims = array_dimensions(array)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_]

@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_mask_with_offset_written_in_conditional(backend):
@gtscript.stencil(backend, externals={"mord": 5})
@gtscript.stencil(backend)
def stencil(outp: gtscript.Field[np.float_]):
with computation(PARALLEL), interval(...):
cond = True
Expand Down