Skip to content

Commit

Permalink
bug[next]: fix missing local kind in gtfn connectivity (#1715)
Browse files Browse the repository at this point in the history
The second dimension of a connectivity is a local dimension. Before we
defaulted to make this dimension horizontal. Currently, this information
is not used.
  • Loading branch information
havogt authored Nov 4, 2024
1 parent 725b6ba commit eea1fb6
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import codegen
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager
Expand Down Expand Up @@ -84,7 +83,7 @@ def _process_regular_arguments(
self,
program: itir.FencilDefinition | itir.Program,
arg_types: tuple[ts.TypeSpec, ...],
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: common.OffsetProvider,
) -> tuple[list[interface.Parameter], list[str]]:
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []
Expand All @@ -107,20 +106,20 @@ def _process_regular_arguments(
# translate sparse dimensions to tuple dtype
dim_name = dim.value
connectivity = offset_provider[dim_name]
assert isinstance(connectivity, Connectivity)
assert isinstance(connectivity, common.Connectivity)
size = connectivity.max_neighbors
arg = f"gridtools::sid::dimension_to_tuple_like<generated::{dim_name}_t, {size}>({arg})"
arg_exprs.append(arg)
return parameters, arg_exprs

def _process_connectivity_args(
self, offset_provider: dict[str, Connectivity | Dimension]
self, offset_provider: dict[str, common.Connectivity | common.Dimension]
) -> tuple[list[interface.Parameter], list[str]]:
parameters: list[interface.Parameter] = []
arg_exprs: list[str] = []

for name, connectivity in offset_provider.items():
if isinstance(connectivity, Connectivity):
if isinstance(connectivity, common.Connectivity):
if connectivity.index_type not in [np.int32, np.int64]:
raise ValueError(
"Neighbor table indices must be of type 'np.int32' or 'np.int64'."
Expand All @@ -131,7 +130,12 @@ def _process_connectivity_args(
interface.Parameter(
name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(),
type_=ts.FieldType(
dims=[connectivity.origin_axis, Dimension(name)],
dims=[
connectivity.origin_axis,
common.Dimension(
name, kind=common.DimensionKind.LOCAL
), # TODO(havogt): we should not use the name of the offset as the name of the local dimension
],
dtype=ts.ScalarType(
type_translation.get_scalar_kind(connectivity.index_type)
),
Expand All @@ -149,7 +153,7 @@ def _process_connectivity_args(
arg_exprs.append(
f"gridtools::hymap::keys<generated::{name}_t>::make_values({nbtbl})"
)
elif isinstance(connectivity, Dimension):
elif isinstance(connectivity, common.Dimension):
pass
else:
raise AssertionError(
Expand All @@ -162,7 +166,7 @@ def _process_connectivity_args(
def _preprocess_program(
self,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: dict[str, common.Connectivity | common.Dimension],
) -> itir.Program:
if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms:
return fencil_to_program.FencilToProgram().apply(
Expand Down Expand Up @@ -196,7 +200,7 @@ def _preprocess_program(
def generate_stencil_source(
self,
program: itir.FencilDefinition | itir.Program,
offset_provider: dict[str, Connectivity | Dimension],
offset_provider: dict[str, common.Connectivity | common.Dimension],
column_axis: Optional[common.Dimension],
) -> str:
new_program = self._preprocess_program(program, offset_provider)
Expand Down

0 comments on commit eea1fb6

Please sign in to comment.