Skip to content

Commit

Permalink
Merge branch 'main' into optimize_program
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N authored Nov 4, 2024
2 parents 98ab673 + eea1fb6 commit 8f15ee2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
4 changes: 1 addition & 3 deletions .gitpod.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
FROM gitpod/workspace-python
FROM gitpod/workspace-python-3.11
USER root
RUN apt-get update \
&& apt-get install -y libboost-dev \
&& apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/*
USER gitpod
RUN pyenv install 3.10.2
RUN pyenv global 3.10.2
22 changes: 13 additions & 9 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import codegen
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager
Expand Down Expand Up @@ -84,7 +83,7 @@ def _process_regular_arguments(
self,
program: itir.FencilDefinition | itir.Program,
arg_types: tuple[ts.TypeSpec, ...],
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: common.OffsetProvider,
) -> tuple[list[interface.Parameter], list[str]]:
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []
Expand All @@ -107,20 +106,20 @@ def _process_regular_arguments(
# translate sparse dimensions to tuple dtype
dim_name = dim.value
connectivity = offset_provider[dim_name]
assert isinstance(connectivity, Connectivity)
assert isinstance(connectivity, common.Connectivity)
size = connectivity.max_neighbors
arg = f"gridtools::sid::dimension_to_tuple_like<generated::{dim_name}_t, {size}>({arg})"
arg_exprs.append(arg)
return parameters, arg_exprs

def _process_connectivity_args(
self, offset_provider: dict[str, Connectivity | Dimension]
self, offset_provider: dict[str, common.Connectivity | common.Dimension]
) -> tuple[list[interface.Parameter], list[str]]:
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []

for name, connectivity in offset_provider.items():
if isinstance(connectivity, Connectivity):
if isinstance(connectivity, common.Connectivity):
if connectivity.index_type not in [np.int32, np.int64]:
raise ValueError(
"Neighbor table indices must be of type 'np.int32' or 'np.int64'."
Expand All @@ -131,7 +130,12 @@ def _process_connectivity_args(
interface.Parameter(
name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(),
type_=ts.FieldType(
dims=[connectivity.origin_axis, Dimension(name)],
dims=[
connectivity.origin_axis,
common.Dimension(
name, kind=common.DimensionKind.LOCAL
), # TODO(havogt): we should not use the name of the offset as the name of the local dimension
],
dtype=ts.ScalarType(
type_translation.get_scalar_kind(connectivity.index_type)
),
Expand All @@ -149,7 +153,7 @@ def _process_connectivity_args(
arg_exprs.append(
f"gridtools::hymap::keys<generated::{name}_t>::make_values({nbtbl})"
)
elif isinstance(connectivity, Dimension):
elif isinstance(connectivity, common.Dimension):
pass
else:
raise AssertionError(
Expand All @@ -162,7 +166,7 @@ def _process_connectivity_args(
def _preprocess_program(
self,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: dict[str, common.Connectivity | common.Dimension],
) -> itir.Program:
if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms:
return fencil_to_program.FencilToProgram().apply(
Expand Down Expand Up @@ -196,7 +200,7 @@ def _preprocess_program(
def generate_stencil_source(
self,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: dict[str, common.Connectivity | common.Dimension],
column_axis: Optional[common.Dimension],
) -> str:
new_program = self._preprocess_program(program, offset_provider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,12 @@ def gt_auto_optimize(
sdfg.apply_transformations_repeated(
[
dace_dataflow.TrivialMapElimination,
# TODO(phimuell): Investigate if these two are appropriate.
dace_dataflow.MapReduceFusion,
dace_dataflow.MapWCRFusion,
# TODO(phimuell): The transformation are interesting, but they have
# a bug as they assume that they are not working inside a map scope.
# Before we use them we have to fix them.
# https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc
# dace_dataflow.MapReduceFusion,
# dace_dataflow.MapWCRFusion,
],
validate=validate,
validate_all=validate_all,
Expand Down

0 comments on commit 8f15ee2

Please sign in to comment.