From bf1f1f8143217a2337dd613d7eb80c4d63759ce9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Jan 2024 21:47:31 +0100 Subject: [PATCH 01/50] skip value connectivity --- src/gt4py/next/embedded/nd_array_field.py | 32 ++- src/gt4py/next/ffront/fbuiltins.py | 1 - src/gt4py/next/iterator/atlas_utils.py | 11 +- .../ffront_tests/test_ffront_fvm_nabla.py | 126 +++++++++++ .../ffront_tests/test_skip_value.py | 50 +++++ .../multi_feature_tests/fvm_nabla_setup.py | 212 ++++++++++++++++++ .../iterator_tests/test_fvm_nabla.py | 2 +- 7 files changed, 428 insertions(+), 6 deletions(-) create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9fc1b42038..f3a4c9409a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -26,7 +26,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common -from gt4py.next.embedded import common as embedded_common +from gt4py.next.embedded import common as embedded_common, context as embedded_context from gt4py.next.ffront import fbuiltins @@ -160,6 +160,8 @@ def from_array( def remap( self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset ) -> NdArrayField: + # TODO skip values: if the skip_value is -1 we don't need special treatment, we'll just select a random value (the wrapped around one) + # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField if not common.is_connectivity_field(connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) @@ -384,7 +386,16 @@ def inverse_image( assert isinstance(image_range, common.UnitRange) assert common.UnitRange.is_finite(image_range) - restricted_mask = (self._ndarray >= image_range.start) & ( + + # HANNES HACK + # map all negative indices (skip_values) to the smallest positiv index + smallest = xp.min( + self._ndarray, initial=xp.iinfo(xp.int32).max, where=self._ndarray >= 0 + ) + clipped_array = xp.where(self._ndarray < 0, smallest, self._ndarray) + # END HANNES HACK + + restricted_mask = (clipped_array >= image_range.start) & ( self._ndarray < image_range.stop ) # indices of non-zero elements in each dimension @@ -485,9 +496,24 @@ def _builtin_op( if axis not in field.domain.dims: raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) + # HANNES HACK + current_offset_provider = embedded_context.offset_provider.get(None) + assert current_offset_provider is not None + offset_definition = current_offset_provider[ + axis.value + ] # assumes offset and local dimension have same name + # TODO mapping of connectivity to field dimensions + # TODO unclear in case of multiple local dimensions (probably just chain) + # END HANNES HACK new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + return field.__class__.from_array( - getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index), + getattr(field.array_ns, array_builtin_name)( + field.ndarray, + axis=reduce_dim_index, + initial=0, # set proper inital value + where=offset_definition.table >= 0, + ), domain=new_domain, ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 278dde9180..4787112e7d 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -364,7 +364,6 @@ def as_connectivity_field(self): if common.is_connectivity_field(offset_definition): connectivity = offset_definition elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - assert not offset_definition.has_skip_values connectivity = gtx.as_connectivity( domain=self.target, codomain=self.source, diff --git a/src/gt4py/next/iterator/atlas_utils.py b/src/gt4py/next/iterator/atlas_utils.py index 500c9253ff..aad77ccc9a 100644 --- a/src/gt4py/next/iterator/atlas_utils.py +++ b/src/gt4py/next/iterator/atlas_utils.py @@ -29,7 +29,7 @@ def __getitem__(self, indices): if neigh_index < self.atlas_connectivity.cols(primary_index): return self.atlas_connectivity[primary_index, neigh_index] else: - return None + return -1 else: if neigh_index < 2: return self.atlas_connectivity[primary_index, neigh_index] @@ -53,3 +53,12 @@ def max(self): # noqa: A003 if v is not None: maximum = max(maximum, v) return maximum + + def asnumpy(self): + import numpy as np + + res = np.empty(self.shape, dtype=self.dtype) + for i in range(self.shape[0]): + for j in range(self.shape[1]): + res[i, j] = self[i, j] + return res diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py new file mode 100644 index 0000000000..efbf4ddbc7 --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -0,0 +1,126 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Tuple + +import numpy as np +import pytest + +from gt4py import next as gtx +from gt4py.next import allocators, neighbor_sum +from gt4py.next.iterator import atlas_utils +from gt4py.next.program_processors import processor_interface as ppi + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + fieldview_backend, +) +from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + assert_close, + nabla_setup, +) + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + + +@gtx.field_operator +def compute_zavgS( + pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] +) -> gtx.Field[[Edge], float]: + zavg = 0.5 * (pp(E2V[0]) + pp(E2V[1])) + return S_M * zavg + + +@gtx.field_operator +def compute_pnabla( + pp: gtx.Field[[Vertex], float], + S_M: gtx.Field[[Edge], float], + sign: gtx.Field[[Vertex, V2EDim], float], + vol: gtx.Field[[Vertex], float], +) -> gtx.Field[[Vertex], float]: + zavgS = compute_zavgS(pp, S_M) + pnabla_M = neighbor_sum(zavgS(V2E) * sign, axis=V2EDim) + return pnabla_M / vol + + +@gtx.field_operator +def pnabla( + pp: gtx.Field[[Vertex], float], + S_M: Tuple[gtx.Field[[Edge], float], gtx.Field[[Edge], float]], + sign: gtx.Field[[Vertex, V2EDim], float], + vol: gtx.Field[[Vertex], float], +) -> Tuple[gtx.Field[[Vertex], float], gtx.Field[[Vertex], float]]: + return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol) + + +def test_ffront_compute_zavgS(fieldview_backend): + allocator = fieldview_backend + fieldview_backend = ( + fieldview_backend if hasattr(fieldview_backend, "executor") else None + ) # TODO this pattern is dangerous + + setup = nabla_setup() + + pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) + S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + + zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) + + e2v = gtx.NeighborTableOffsetProvider( + atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False + ) + + compute_zavgS.with_backend(fieldview_backend)( + pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + ) + + assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) + assert_close(388241977.58389181, np.max(zavgS.asnumpy())) + + +def test_ffront_nabla(fieldview_backend): + fieldview_backend = fieldview_backend if hasattr(fieldview_backend, "executor") else None + + setup = nabla_setup() + + sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) + pp = gtx.as_field([Vertex], setup.input_field) + S_M = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + vol = gtx.as_field([Vertex], setup.vol_field) + + pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + + e2v = gtx.NeighborTableOffsetProvider( + atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False + ) + v2e = gtx.NeighborTableOffsetProvider( + atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 + ) + + pnabla.with_backend(fieldview_backend)( + pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} + ) + + # TODO this check is not sensitive enough, need to implement a proper numpy reference! + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py new file mode 100644 index 0000000000..7c3e69eade --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py @@ -0,0 +1,50 @@ +import numpy as np + +import gt4py.next as gtx +from gt4py.next import float64, neighbor_sum + + +CellDim = gtx.Dimension("Cell") +EdgeDim = gtx.Dimension("Edge") +E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) +E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) + +edge_to_cell_table = np.array([ + [0, -1], # edge 0 (neighbours: cell 0) + [2, -1], # edge 1 + [2, -1], # edge 2 + [3, -1], # edge 3 + [4, -1], # edge 4 + [5, -1], # edge 5 + [0, 5], # edge 6 (neighbours: cell 0, cell 5) + [0, 1], # edge 7 + [1, 2], # edge 8 + [1, 3], # edge 9 + [3, 4], # edge 10 + [4, 5] # edge 11 +]) + +E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2, has_skip_values=True) + +@gtx.field_operator +def sum_adjacent_cells(cells : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: + # type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64] + return neighbor_sum(cells(E2C), axis=E2CDim) + +@gtx.program +def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]): + sum_adjacent_cells(cells, out=out) + +cell_values = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) +edge_values = gtx.as_field([EdgeDim], np.zeros((12,))) + +run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) + +print(np.clip(edge_to_cell_table, 0,5)) +ref = cell_values.ndarray[np.clip(edge_to_cell_table, 0,5)] +print(ref) +ref = np.sum(ref, initial=0, where=edge_to_cell_table>=0, axis=1) + + +print("sum of adjacent cells: {}".format(edge_values.asnumpy())) +print(ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py new file mode 100644 index 0000000000..03e1af27dd --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -0,0 +1,212 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import math + +import numpy as np +from atlas4py import ( + Config, + StructuredGrid, + StructuredMeshGenerator, + Topology, + build_edges, + build_median_dual_mesh, + build_node_to_edge_connectivity, + functionspace, +) + + +def assert_close(expected, actual): + assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) + + +class nabla_setup: + @staticmethod + def _default_config(): + config = Config() + config["triangulate"] = True + config["angle"] = 20.0 + return config + + def __init__(self, *, grid=StructuredGrid("O32"), config=None): + if config is None: + config = self._default_config() + mesh = StructuredMeshGenerator(config).generate(grid) + + fs_edges = functionspace.EdgeColumns(mesh, halo=1) + fs_nodes = functionspace.NodeColumns(mesh, halo=1) + + build_edges(mesh) + build_node_to_edge_connectivity(mesh) + build_median_dual_mesh(mesh) + + edges_per_node = max( + [mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size)] + ) + + self.mesh = mesh + self.fs_edges = fs_edges + self.fs_nodes = fs_nodes + self.edges_per_node = edges_per_node + + @property + def edges2node_connectivity(self): + return self.mesh.edges.node_connectivity + + @property + def nodes2edge_connectivity(self): + return self.mesh.nodes.edge_connectivity + + @property + def nodes_size(self): + return self.fs_nodes.size + + @property + def edges_size(self): + return self.fs_edges.size + + @staticmethod + def _is_pole_edge(e, edge_flags): + return Topology.check(edge_flags[e], Topology.POLE) + + @property + def is_pole_edge_field(self): + edge_flags = np.array(self.mesh.edges.flags()) + + pole_edge_field = np.zeros((self.edges_size,), dtype=bool) + for e in range(self.edges_size): + pole_edge_field[e] = self._is_pole_edge(e, edge_flags) + return pole_edge_field + + @property + def sign_field(self): + node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) + edge_flags = np.array(self.mesh.edges.flags()) + + for jnode in range(0, self.nodes_size): + node_edge_con = self.mesh.nodes.edge_connectivity + edge_node_con = self.mesh.edges.node_connectivity + for jedge in range(0, node_edge_con.cols(jnode)): + iedge = node_edge_con[jnode, jedge] + ip1 = edge_node_con[iedge, 0] + if jnode == ip1: + node2edge_sign[jnode, jedge] = 1.0 + else: + node2edge_sign[jnode, jedge] = -1.0 + if self._is_pole_edge(iedge, edge_flags): + node2edge_sign[jnode, jedge] = 1.0 + return node2edge_sign + + @property + def S_fields(self): + S = np.array(self.mesh.edges.field("dual_normals"), copy=False) + S_MXX = np.zeros((self.edges_size)) + S_MYY = np.zeros((self.edges_size)) + + MXX = 0 + MYY = 1 + + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + for i in range(0, self.edges_size): + S_MXX[i] = S[i, MXX] * radius * deg2rad + S_MYY[i] = S[i, MYY] * radius * deg2rad + + assert math.isclose(min(S_MXX), -103437.60479272791) + assert math.isclose(max(S_MXX), 340115.33913622628) + assert math.isclose(min(S_MYY), -2001577.7946404363) + assert math.isclose(max(S_MYY), 2001577.7946404363) + + return S_MXX, S_MYY + + @property + def vol_field(self): + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + vol_atlas = np.array(self.mesh.nodes.field("dual_volumes"), copy=False) + # dual_volumes 4.6510228700066421 68.891611253882218 12.347560975609632 + assert_close(4.6510228700066421, min(vol_atlas)) + assert_close(68.891611253882218, max(vol_atlas)) + + vol = np.zeros((vol_atlas.size)) + for i in range(0, vol_atlas.size): + vol[i] = vol_atlas[i] * pow(deg2rad, 2) * pow(radius, 2) + # VOL(min/max): 57510668192.214096 851856184496.32886 + assert_close(57510668192.214096, min(vol)) + assert_close(851856184496.32886, max(vol)) + return vol + + @property + def input_field(self): + klevel = 0 + MXX = 0 + MYY = 1 + rpi = 2.0 * math.asin(1.0) + radius = 6371.22e03 + deg2rad = 2.0 * rpi / 360.0 + + zh0 = 2000.0 + zrad = 3.0 * rpi / 4.0 * radius + zeta = rpi / 16.0 * radius + zlatc = 0.0 + zlonc = 3.0 * rpi / 2.0 + + m_rlonlatcr = self.fs_nodes.create_field( + name="m_rlonlatcr", + levels=1, + dtype=np.float64, + variables=self.edges_per_node, + ) + rlonlatcr = np.array(m_rlonlatcr, copy=False) + + m_rcoords = self.fs_nodes.create_field( + name="m_rcoords", levels=1, dtype=np.float64, variables=self.edges_per_node + ) + rcoords = np.array(m_rcoords, copy=False) + + m_rcosa = self.fs_nodes.create_field(name="m_rcosa", levels=1, dtype=np.float64) + rcosa = np.array(m_rcosa, copy=False) + + m_rsina = self.fs_nodes.create_field(name="m_rsina", levels=1, dtype=np.float64) + rsina = np.array(m_rsina, copy=False) + + m_pp = self.fs_nodes.create_field(name="m_pp", levels=1, dtype=np.float64) + rzs = np.array(m_pp, copy=False) + + rcoords_deg = np.array(self.mesh.nodes.field("lonlat")) + + for jnode in range(0, self.nodes_size): + for i in range(0, 2): + rcoords[jnode, klevel, i] = rcoords_deg[jnode, i] * deg2rad + rlonlatcr[jnode, klevel, i] = rcoords[jnode, klevel, i] # This is not my pattern! + rcosa[jnode, klevel] = math.cos(rlonlatcr[jnode, klevel, MYY]) + rsina[jnode, klevel] = math.sin(rlonlatcr[jnode, klevel, MYY]) + for jnode in range(0, self.nodes_size): + zlon = rlonlatcr[jnode, klevel, MXX] + zdist = math.sin(zlatc) * rsina[jnode, klevel] + math.cos(zlatc) * rcosa[ + jnode, klevel + ] * math.cos(zlon - zlonc) + zdist = radius * math.acos(zdist) + rzs[jnode, klevel] = 0.0 + if zdist < zrad: + rzs[jnode, klevel] = rzs[jnode, klevel] + 0.5 * zh0 * ( + 1.0 + math.cos(rpi * zdist / zrad) + ) * math.pow(math.cos(rpi * zdist / zeta), 2) + + assert_close(0.0000000000000000, min(rzs)) + assert_close(1965.4980340735883, max(rzs)) + return rzs[:, klevel] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index e1d959aba9..9ee364c014 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -38,7 +38,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.iterator.transforms.pass_manager import LiftMode -from next_tests.integration_tests.multi_feature_tests.iterator_tests.fvm_nabla_setup import ( +from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, nabla_setup, ) From 54959372d50f6e18135fd7892919f34b8949e3e1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 23 Jan 2024 16:21:12 +0100 Subject: [PATCH 02/50] fix formatting --- .../ffront_tests/test_skip_value.py | 69 ++++++++++++------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py index 7c3e69eade..6657ed1b62 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py @@ -1,3 +1,17 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import numpy as np import gt4py.next as gtx @@ -7,43 +21,52 @@ CellDim = gtx.Dimension("Cell") EdgeDim = gtx.Dimension("Edge") E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) -E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) - -edge_to_cell_table = np.array([ - [0, -1], # edge 0 (neighbours: cell 0) - [2, -1], # edge 1 - [2, -1], # edge 2 - [3, -1], # edge 3 - [4, -1], # edge 4 - [5, -1], # edge 5 - [0, 5], # edge 6 (neighbours: cell 0, cell 5) - [0, 1], # edge 7 - [1, 2], # edge 8 - [1, 3], # edge 9 - [3, 4], # edge 10 - [4, 5] # edge 11 -]) - -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2, has_skip_values=True) +E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) + +edge_to_cell_table = np.array( + [ + [0, -1], # edge 0 (neighbours: cell 0) + [2, -1], # edge 1 + [2, -1], # edge 2 + [3, -1], # edge 3 + [4, -1], # edge 4 + [5, -1], # edge 5 + [0, 5], # edge 6 (neighbours: cell 0, cell 5) + [0, 1], # edge 7 + [1, 2], # edge 8 + [1, 3], # edge 9 + [3, 4], # edge 10 + [4, 5], # edge 11 + ] +) + +E2C_offset_provider = gtx.NeighborTableOffsetProvider( + edge_to_cell_table, EdgeDim, CellDim, 2, has_skip_values=True +) + @gtx.field_operator -def sum_adjacent_cells(cells : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: +def sum_adjacent_cells(cells: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: # type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64] return neighbor_sum(cells(E2C), axis=E2CDim) + @gtx.program -def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]): +def run_sum_adjacent_cells( + cells: gtx.Field[[CellDim], float64], out: gtx.Field[[EdgeDim], float64] +): sum_adjacent_cells(cells, out=out) + cell_values = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) edge_values = gtx.as_field([EdgeDim], np.zeros((12,))) run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print(np.clip(edge_to_cell_table, 0,5)) -ref = cell_values.ndarray[np.clip(edge_to_cell_table, 0,5)] +print(np.clip(edge_to_cell_table, 0, 5)) +ref = cell_values.ndarray[np.clip(edge_to_cell_table, 0, 5)] print(ref) -ref = np.sum(ref, initial=0, where=edge_to_cell_table>=0, axis=1) +ref = np.sum(ref, initial=0, where=edge_to_cell_table >= 0, axis=1) print("sum of adjacent cells: {}".format(edge_values.asnumpy())) From d51cfcac5c8579c32d0018b430ba2e2ee365c636 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 1 Feb 2024 15:04:42 +0100 Subject: [PATCH 03/50] cleanup parts and tests --- src/gt4py/next/embedded/nd_array_field.py | 112 +++++++++--------- tests/next_tests/definitions.py | 1 - .../ffront_tests/test_ffront_fvm_nabla.py | 31 ++--- .../ffront_tests/test_skip_value.py | 73 ------------ .../embedded_tests/test_nd_array_field.py | 84 ++++++++++++- 5 files changed, 152 insertions(+), 149 deletions(-) delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 508cfde9f9..fef2577d0a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -28,6 +28,7 @@ from gt4py.next import common from gt4py.next.embedded import common as embedded_common, context as embedded_context from gt4py.next.ffront import fbuiltins +from gt4py.next.iterator import embedded as itir_embedded try: @@ -390,55 +391,21 @@ def inverse_image( assert common.UnitRange.is_finite(image_range) - # HANNES HACK - # map all negative indices (skip_values) to the smallest positiv index - smallest = xp.min( - self._ndarray, initial=xp.iinfo(xp.int32).max, where=self._ndarray >= 0 + restricted_mask = (self._ndarray >= image_range.start) & ( + self._ndarray < image_range.stop ) - clipped_array = xp.where(self._ndarray < 0, smallest, self._ndarray) - # END HANNES HACK - restricted_mask = (clipped_array >= image_range.start) & ( - self._ndarray < image_range.stop + relative_ranges = _hypercube( + restricted_mask, xp, ignore_mask=self._ndarray == common.SKIP_VALUE ) - # indices of non-zero elements in each dimension - nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask) - - new_dims = [] - non_contiguous_dims = [] - - for i, dim_nnz_indices in enumerate(nnz): - # Check if the indices are contiguous - first_data_index = dim_nnz_indices[0] - assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES) - last_data_index = dim_nnz_indices[-1] - assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES) - indices, counts = xp.unique(dim_nnz_indices, return_counts=True) - dim_range = self._domain[i] - - if len(xp.unique(counts)) == 1 and ( - len(indices) == last_data_index - first_data_index + 1 - ): - idx_offset = dim_range[1].start - start = idx_offset + first_data_index - assert common.is_int_index(start) - stop = idx_offset + last_data_index + 1 - assert common.is_int_index(stop) - new_dims.append( - common.named_range( - ( - dim_range[0], - (start, stop), - ) - ) - ) - else: - non_contiguous_dims.append(dim_range[0]) - if non_contiguous_dims: - raise ValueError( - f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." - ) + if relative_ranges is None: + raise ValueError("Restriction generates non-contiguous dimensions.") + + new_dims = [ + common.named_range((d, rr + ar.start)) + for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges) + ] self._cache[cache_key] = new_dims @@ -460,6 +427,30 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ __getitem__ = restrict +def _hypercube( + select: core_defs.NDArrayObject, + xp: ModuleType, + ignore_mask: Optional[core_defs.NDArrayObject] = None, +) -> Optional[list[common.UnitRange]]: + """ + Return the hypercube that contains all True values and no False values or `None` if no such hypercube exists. + + If `ignore_mask` is given, the selected values are ignored. + """ + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select) + + slices = tuple( + slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz + ) + hcube = select[tuple(slices)] + if ignore_mask is not None: + hcube |= ignore_mask[tuple(slices)] + if not xp.all(hcube): + return None + + return [common.UnitRange(s.start, s.stop) for s in slices] + + # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func( @@ -491,7 +482,9 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[ +def _make_reduction( + builtin_name: str, array_builtin_name: str, initial_value_op: Callable +) -> Callable[ ..., NdArrayField[common.DimsT, core_defs.ScalarT], ]: @@ -503,23 +496,26 @@ def _builtin_op( if axis not in field.domain.dims: raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) - # HANNES HACK current_offset_provider = embedded_context.offset_provider.get(None) assert current_offset_provider is not None offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - # TODO mapping of connectivity to field dimensions + assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) # TODO unclear in case of multiple local dimensions (probably just chain) - # END HANNES HACK new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + # TODO add test that requires broadcasting + masked_array = field.array_ns.where( + field.array_ns.asarray(offset_definition.table) != common.SKIP_VALUE, + field.ndarray, + initial_value_op(field), + ) + return field.__class__.from_array( getattr(field.array_ns, array_builtin_name)( - field.ndarray, + masked_array, axis=reduce_dim_index, - initial=0, # set proper inital value - where=offset_definition.table >= 0, ), domain=new_domain, ) @@ -528,9 +524,15 @@ def _builtin_op( return _builtin_op -NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum")) -NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max")) -NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min")) +NdArrayField.register_builtin_func( + fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum", lambda x: x.dtype.scalar_type(0)) +) +NdArrayField.register_builtin_func( + fbuiltins.max_over, _make_reduction("max_over", "max", lambda x: x.array_ns.min(x._ndarray)) +) +NdArrayField.register_builtin_func( + fbuiltins.min_over, _make_reduction("min_over", "min", lambda x: x.array_ns.max(x._ndarray)) +) # -- Concrete array implementations -- diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c95292d702..d9060a2474 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -171,7 +171,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args - (USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index efbf4ddbc7..c3971d6fe5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -23,7 +23,7 @@ from gt4py.next.program_processors import processor_interface as ppi from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -70,11 +70,8 @@ def pnabla( return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol) -def test_ffront_compute_zavgS(fieldview_backend): - allocator = fieldview_backend - fieldview_backend = ( - fieldview_backend if hasattr(fieldview_backend, "executor") else None - ) # TODO this pattern is dangerous +def test_ffront_compute_zavgS(exec_alloc_descriptor): + executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator setup = nabla_setup() @@ -87,26 +84,24 @@ def test_ffront_compute_zavgS(fieldview_backend): atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False ) - compute_zavgS.with_backend(fieldview_backend)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} - ) + compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) -def test_ffront_nabla(fieldview_backend): - fieldview_backend = fieldview_backend if hasattr(fieldview_backend, "executor") else None +def test_ffront_nabla(exec_alloc_descriptor): + executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator setup = nabla_setup() - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) + pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) + S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) - pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) + pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) + pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) e2v = gtx.NeighborTableOffsetProvider( atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False @@ -115,7 +110,7 @@ def test_ffront_nabla(fieldview_backend): atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 ) - pnabla.with_backend(fieldview_backend)( + pnabla.with_backend(executor)( pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py deleted file mode 100644 index 6657ed1b62..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_skip_value.py +++ /dev/null @@ -1,73 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import numpy as np - -import gt4py.next as gtx -from gt4py.next import float64, neighbor_sum - - -CellDim = gtx.Dimension("Cell") -EdgeDim = gtx.Dimension("Edge") -E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) -E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) - -edge_to_cell_table = np.array( - [ - [0, -1], # edge 0 (neighbours: cell 0) - [2, -1], # edge 1 - [2, -1], # edge 2 - [3, -1], # edge 3 - [4, -1], # edge 4 - [5, -1], # edge 5 - [0, 5], # edge 6 (neighbours: cell 0, cell 5) - [0, 1], # edge 7 - [1, 2], # edge 8 - [1, 3], # edge 9 - [3, 4], # edge 10 - [4, 5], # edge 11 - ] -) - -E2C_offset_provider = gtx.NeighborTableOffsetProvider( - edge_to_cell_table, EdgeDim, CellDim, 2, has_skip_values=True -) - - -@gtx.field_operator -def sum_adjacent_cells(cells: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - # type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64] - return neighbor_sum(cells(E2C), axis=E2CDim) - - -@gtx.program -def run_sum_adjacent_cells( - cells: gtx.Field[[CellDim], float64], out: gtx.Field[[EdgeDim], float64] -): - sum_adjacent_cells(cells, out=out) - - -cell_values = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) -edge_values = gtx.as_field([EdgeDim], np.zeros((12,))) - -run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) - -print(np.clip(edge_to_cell_table, 0, 5)) -ref = cell_values.ndarray[np.clip(edge_to_cell_table, 0, 5)] -print(ref) -ref = np.sum(ref, initial=0, where=edge_to_cell_table >= 0, axis=1) - - -print("sum of adjacent cells: {}".format(edge_values.asnumpy())) -print(ref) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 70fa274457..ded3098f58 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -683,14 +683,14 @@ def test_connectivity_field_inverse_image_2d_domain(): codomain=V, ) - # e2c_conn: + # e2v_conn: # ---E2V---- # |[[0 0 2] # E [1 1 2] # | [2 2 2]] # Test contiguous and non-contiguous ranges. - # For the 'e2c_conn' defined above, the only valid range including 2 + # For the 'e2v_conn' defined above, the only valid range including 2 # is [0, 3). Otherwise, the inverse image would be non-contiguous. image_range = UnitRange(V_START, V_STOP) result = e2v_conn.inverse_image(image_range) @@ -742,3 +742,83 @@ def test_connectivity_field_inverse_image_non_contiguous(): with pytest.raises(ValueError, match="generates non-contiguous dimensions"): e2v_conn.inverse_image(UnitRange(V_START, V_STOP)) + + +def test_connectivity_field_inverse_image_2d_domain_skip_values(): + V = Dimension("V") + E = Dimension("E") + E2V = Dimension("E2V") + + V_START, V_STOP = 0, 3 + E_START, E_STOP = 0, 4 + E2V_START, E2V_STOP = 0, 4 + + e2v_conn = common._connectivity( + np.asarray([[-1, 0, 2, -1], [1, 1, 2, 2], [2, 2, -1, -1], [-1, 2, -1, -1]]), + domain=common.domain( + [ + common.named_range((E, (E_START, E_STOP))), + common.named_range((E2V, (E2V_START, E2V_STOP))), + ] + ), + codomain=V, + ) + + # e2v_conn: + # ---E2V--------- + # |[[-1 0 2 -1] + # E [ 1 1 2 2] + # | [ 2 2 -1 -1] + # | [-1 2 -1 -1]] + + image_range = UnitRange(V_START, V_STOP) + result = e2v_conn.inverse_image(image_range) + + assert len(result) == 2 + assert result[0] == (E, UnitRange(E_START, E_STOP)) + assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP)) + + result = e2v_conn.inverse_image(UnitRange(0, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + result = e2v_conn.inverse_image(UnitRange(0, 1)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 1)) + assert result[1] == (E2V, UnitRange(1, 2)) + + result = e2v_conn.inverse_image(UnitRange(1, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(1, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(1, 3)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(2, 3)) + + +@pytest.mark.parametrize( + "select, ignore_mask, expected", + [ + ([True, True, False], None, [(0, 2)]), + ([True, False, True], None, None), + ([True, False, True], [False, True, False], [(0, 3)]), + ([[False, False, False], [False, True, True]], None, [(1, 2), (1, 3)]), + ( + [[False, True, False], [False, True, True]], + [[False, False, True], [False, False, False]], + [(0, 2), (1, 3)], + ), + ], +) +def test_hypercube(select, ignore_mask, expected): + select = np.asarray(select) + ignore_mask = np.asarray(ignore_mask) if ignore_mask is not None else None + expected = [common.unit_range(e) for e in expected] if expected is not None else None + + result = nd_array_field._hypercube(select, np, ignore_mask=ignore_mask) + + assert result == expected From fdb4423cb2a6cac19417c30e9d65112d78c24d59 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 1 Feb 2024 16:10:41 +0100 Subject: [PATCH 04/50] skip fvm test with no atlas --- .../multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index c3971d6fe5..aeed607b01 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -17,6 +17,9 @@ import numpy as np import pytest + +pytest.importorskip("atlas4py") # isort: skip + from gt4py import next as gtx from gt4py.next import allocators, neighbor_sum from gt4py.next.iterator import atlas_utils From 1e0e2283bd40408323df325ebb933e1185e2dc7b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 2 Feb 2024 11:17:51 +0100 Subject: [PATCH 05/50] testcase which requires broadcasting the mask --- src/gt4py/next/embedded/nd_array_field.py | 12 +++- src/gt4py/next/ffront/decorator.py | 8 +++ .../ffront_tests/test_gt4py_builtins.py | 67 ++++++++++++++----- 3 files changed, 68 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fef2577d0a..cf70b8bdf5 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -495,6 +495,10 @@ def _builtin_op( raise ValueError("Can only reduce local dimensions.") if axis not in field.domain.dims: raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") + if len([d for d in field.domain.dims if d.kind is common.DimensionKind.LOCAL]) > 1: + raise NotImplementedError( + "Reducing a field with more than one local dimension is not supported." + ) reduce_dim_index = field.domain.dims.index(axis) current_offset_provider = embedded_context.offset_provider.get(None) assert current_offset_provider is not None @@ -502,12 +506,14 @@ def _builtin_op( axis.value ] # assumes offset and local dimension have same name assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) - # TODO unclear in case of multiple local dimensions (probably just chain) new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) - # TODO add test that requires broadcasting + broadcast_slice = tuple( + slice(None) if d in [axis, offset_definition.origin_axis] else None + for d in field.domain.dims + ) masked_array = field.array_ns.where( - field.array_ns.asarray(offset_definition.table) != common.SKIP_VALUE, + field.array_ns.asarray(offset_definition.table[broadcast_slice]) != common.SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 6510be560e..485f9c9339 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -221,6 +221,10 @@ def __post_init__(self): f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) + @property + def __name__(self) -> str: + return self.definition.__name__ + @functools.cached_property def __gt_allocator__( self, @@ -601,6 +605,10 @@ def from_function( operator_attributes=operator_attributes, ) + @property + def __name__(self) -> str: + return self.definition.__name__ + def __gt_type__(self) -> ts.CallableType: type_ = self.foast_node.type assert isinstance(type_, ts.CallableType) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e27e73c80d..49df10bcff 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -30,6 +30,7 @@ Joff, KDim, V2EDim, + Vertex, cartesian_case, unstructured_case, ) @@ -85,26 +86,60 @@ def minover(edge_f: cases.EField) -> cases.VField: ) -@pytest.mark.uses_unstructured_shift -def test_reduction_execution(unstructured_case): - @gtx.field_operator - def reduction(edge_f: cases.EField) -> cases.VField: - return neighbor_sum(edge_f(V2E), axis=V2EDim) +@gtx.field_operator +def reduction_e_field(edge_f: cases.EField) -> cases.VField: + return neighbor_sum(edge_f(V2E), axis=V2EDim) + + +@gtx.field_operator +def reduction_ek_field( + edge_f: common.Field[[Edge, KDim], np.int32] +) -> common.Field[[Vertex, KDim], np.int32]: + return neighbor_sum(edge_f(V2E), axis=V2EDim) - @gtx.program - def fencil(edge_f: cases.EField, out: cases.VField): - reduction(edge_f, out=out) +@gtx.field_operator +def reduction_ke_field( + edge_f: common.Field[[KDim, Edge], np.int32] +) -> common.Field[[KDim, Vertex], np.int32]: + return neighbor_sum(edge_f(V2E), axis=V2EDim) + + +@pytest.mark.uses_unstructured_shift +@pytest.mark.parametrize( + "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ +) +def test_neighbor_sum(unstructured_case, fop): v2e_table = unstructured_case.offset_provider["V2E"].table - cases.verify_with_default_data( + + edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + + local_dim_idx = edge_f.domain.dims.index(Edge) + adv_indexing = tuple( + slice(None) if dim is not Edge else v2e_table for dim in edge_f.domain.dims + ) + + broadcast_slice = [] + for dim in edge_f.domain.dims: + if dim is Edge: + broadcast_slice.append(slice(None)) + broadcast_slice.append(slice(None)) + else: + broadcast_slice.append(None) + + broadcasted_table = v2e_table[tuple(broadcast_slice)] + ref = np.sum( + edge_f.asnumpy()[adv_indexing], + axis=local_dim_idx + 1, + initial=0, + where=broadcasted_table != common.SKIP_VALUE, + ) + cases.verify( unstructured_case, - fencil, - ref=lambda edge_f: np.sum( - edge_f[v2e_table], - axis=1, - initial=0, - where=v2e_table != common.SKIP_VALUE, - ), + fop, + edge_f, + out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + ref=ref, ) From 96090e854c33d6d93e943360b89663e1743b0019 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 6 Feb 2024 11:33:45 +0100 Subject: [PATCH 06/50] add comment --- src/gt4py/next/embedded/nd_array_field.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index cf70b8bdf5..63cc0470ba 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -433,9 +433,20 @@ def _hypercube( ignore_mask: Optional[core_defs.NDArrayObject] = None, ) -> Optional[list[common.UnitRange]]: """ - Return the hypercube that contains all True values and no False values or `None` if no such hypercube exists. + Return the hypercube that contains all True values and no False values, or `None` if no such hypercube exists. - If `ignore_mask` is given, the selected values are ignored. + If `ignore_mask` is given, the selected values are ignored. It returns the smallest hypercube. + A bigger hypercube could be constructed by adding lines from the ignore_mask. + Example: + select = True True False + True True False + False False True + + ignore_mask = False False True + False False True + True True True + + would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. """ nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select) From 8c408ddc077e537b33f4ac0e1bd134f9aca9855c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 6 Feb 2024 11:45:19 +0100 Subject: [PATCH 07/50] cleanup --- src/gt4py/next/embedded/nd_array_field.py | 4 +- src/gt4py/next/iterator/atlas_utils.py | 4 +- .../iterator_tests/fvm_nabla_setup.py | 212 ------------------ 3 files changed, 6 insertions(+), 214 deletions(-) delete mode 100644 tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/fvm_nabla_setup.py diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 63cc0470ba..0597761e13 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -160,7 +160,9 @@ def from_array( def remap( self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset ) -> NdArrayField: - # TODO skip values: if the skip_value is -1 we don't need special treatment, we'll just select a random value (the wrapped around one) + # Current implementation relies on SKIP_VALUE == -1: + # if we assume the indexed array has at least one element, we wrap around without out of bounds + assert common.SKIP_VALUE == -1 # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField if not common.is_connectivity_field(connectivity): diff --git a/src/gt4py/next/iterator/atlas_utils.py b/src/gt4py/next/iterator/atlas_utils.py index aad77ccc9a..7b9b60fd75 100644 --- a/src/gt4py/next/iterator/atlas_utils.py +++ b/src/gt4py/next/iterator/atlas_utils.py @@ -17,6 +17,8 @@ except ImportError: IrregularConnectivity = None +from gt4py.next import common + # TODO(tehrengruber): make this a proper Connectivity instead of faking a numpy array class AtlasTable: @@ -29,7 +31,7 @@ def __getitem__(self, indices): if neigh_index < self.atlas_connectivity.cols(primary_index): return self.atlas_connectivity[primary_index, neigh_index] else: - return -1 + return common.SKIP_VALUE else: if neigh_index < 2: return self.atlas_connectivity[primary_index, neigh_index] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/fvm_nabla_setup.py deleted file mode 100644 index 03e1af27dd..0000000000 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/fvm_nabla_setup.py +++ /dev/null @@ -1,212 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import math - -import numpy as np -from atlas4py import ( - Config, - StructuredGrid, - StructuredMeshGenerator, - Topology, - build_edges, - build_median_dual_mesh, - build_node_to_edge_connectivity, - functionspace, -) - - -def assert_close(expected, actual): - assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) - - -class nabla_setup: - @staticmethod - def _default_config(): - config = Config() - config["triangulate"] = True - config["angle"] = 20.0 - return config - - def __init__(self, *, grid=StructuredGrid("O32"), config=None): - if config is None: - config = self._default_config() - mesh = StructuredMeshGenerator(config).generate(grid) - - fs_edges = functionspace.EdgeColumns(mesh, halo=1) - fs_nodes = functionspace.NodeColumns(mesh, halo=1) - - build_edges(mesh) - build_node_to_edge_connectivity(mesh) - build_median_dual_mesh(mesh) - - edges_per_node = max( - [mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size)] - ) - - self.mesh = mesh - self.fs_edges = fs_edges - self.fs_nodes = fs_nodes - self.edges_per_node = edges_per_node - - @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity - - @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity - - @property - def nodes_size(self): - return self.fs_nodes.size - - @property - def edges_size(self): - return self.fs_edges.size - - @staticmethod - def _is_pole_edge(e, edge_flags): - return Topology.check(edge_flags[e], Topology.POLE) - - @property - def is_pole_edge_field(self): - edge_flags = np.array(self.mesh.edges.flags()) - - pole_edge_field = np.zeros((self.edges_size,), dtype=bool) - for e in range(self.edges_size): - pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field - - @property - def sign_field(self): - node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) - edge_flags = np.array(self.mesh.edges.flags()) - - for jnode in range(0, self.nodes_size): - node_edge_con = self.mesh.nodes.edge_connectivity - edge_node_con = self.mesh.edges.node_connectivity - for jedge in range(0, node_edge_con.cols(jnode)): - iedge = node_edge_con[jnode, jedge] - ip1 = edge_node_con[iedge, 0] - if jnode == ip1: - node2edge_sign[jnode, jedge] = 1.0 - else: - node2edge_sign[jnode, jedge] = -1.0 - if self._is_pole_edge(iedge, edge_flags): - node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign - - @property - def S_fields(self): - S = np.array(self.mesh.edges.field("dual_normals"), copy=False) - S_MXX = np.zeros((self.edges_size)) - S_MYY = np.zeros((self.edges_size)) - - MXX = 0 - MYY = 1 - - rpi = 2.0 * math.asin(1.0) - radius = 6371.22e03 - deg2rad = 2.0 * rpi / 360.0 - - for i in range(0, self.edges_size): - S_MXX[i] = S[i, MXX] * radius * deg2rad - S_MYY[i] = S[i, MYY] * radius * deg2rad - - assert math.isclose(min(S_MXX), -103437.60479272791) - assert math.isclose(max(S_MXX), 340115.33913622628) - assert math.isclose(min(S_MYY), -2001577.7946404363) - assert math.isclose(max(S_MYY), 2001577.7946404363) - - return S_MXX, S_MYY - - @property - def vol_field(self): - rpi = 2.0 * math.asin(1.0) - radius = 6371.22e03 - deg2rad = 2.0 * rpi / 360.0 - vol_atlas = np.array(self.mesh.nodes.field("dual_volumes"), copy=False) - # dual_volumes 4.6510228700066421 68.891611253882218 12.347560975609632 - assert_close(4.6510228700066421, min(vol_atlas)) - assert_close(68.891611253882218, max(vol_atlas)) - - vol = np.zeros((vol_atlas.size)) - for i in range(0, vol_atlas.size): - vol[i] = vol_atlas[i] * pow(deg2rad, 2) * pow(radius, 2) - # VOL(min/max): 57510668192.214096 851856184496.32886 - assert_close(57510668192.214096, min(vol)) - assert_close(851856184496.32886, max(vol)) - return vol - - @property - def input_field(self): - klevel = 0 - MXX = 0 - MYY = 1 - rpi = 2.0 * math.asin(1.0) - radius = 6371.22e03 - deg2rad = 2.0 * rpi / 360.0 - - zh0 = 2000.0 - zrad = 3.0 * rpi / 4.0 * radius - zeta = rpi / 16.0 * radius - zlatc = 0.0 - zlonc = 3.0 * rpi / 2.0 - - m_rlonlatcr = self.fs_nodes.create_field( - name="m_rlonlatcr", - levels=1, - dtype=np.float64, - variables=self.edges_per_node, - ) - rlonlatcr = np.array(m_rlonlatcr, copy=False) - - m_rcoords = self.fs_nodes.create_field( - name="m_rcoords", levels=1, dtype=np.float64, variables=self.edges_per_node - ) - rcoords = np.array(m_rcoords, copy=False) - - m_rcosa = self.fs_nodes.create_field(name="m_rcosa", levels=1, dtype=np.float64) - rcosa = np.array(m_rcosa, copy=False) - - m_rsina = self.fs_nodes.create_field(name="m_rsina", levels=1, dtype=np.float64) - rsina = np.array(m_rsina, copy=False) - - m_pp = self.fs_nodes.create_field(name="m_pp", levels=1, dtype=np.float64) - rzs = np.array(m_pp, copy=False) - - rcoords_deg = np.array(self.mesh.nodes.field("lonlat")) - - for jnode in range(0, self.nodes_size): - for i in range(0, 2): - rcoords[jnode, klevel, i] = rcoords_deg[jnode, i] * deg2rad - rlonlatcr[jnode, klevel, i] = rcoords[jnode, klevel, i] # This is not my pattern! - rcosa[jnode, klevel] = math.cos(rlonlatcr[jnode, klevel, MYY]) - rsina[jnode, klevel] = math.sin(rlonlatcr[jnode, klevel, MYY]) - for jnode in range(0, self.nodes_size): - zlon = rlonlatcr[jnode, klevel, MXX] - zdist = math.sin(zlatc) * rsina[jnode, klevel] + math.cos(zlatc) * rcosa[ - jnode, klevel - ] * math.cos(zlon - zlonc) - zdist = radius * math.acos(zdist) - rzs[jnode, klevel] = 0.0 - if zdist < zrad: - rzs[jnode, klevel] = rzs[jnode, klevel] + 0.5 * zh0 * ( - 1.0 + math.cos(rpi * zdist / zrad) - ) * math.pow(math.cos(rpi * zdist / zeta), 2) - - assert_close(0.0000000000000000, min(rzs)) - assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] From b96280e7572f7b3c81487ee556b972bc7a5e03b0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 7 Feb 2024 09:41:51 +0100 Subject: [PATCH 08/50] fix lowering issue past to itir --- src/gt4py/next/ffront/past_to_itir.py | 11 ++++++++++- .../feature_tests/ffront_tests/test_gt4py_builtins.py | 4 ++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index ed239e0436..4feda54b5c 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -207,7 +207,6 @@ def _construct_itir_domain_arg( node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: - domain_args = [] assert isinstance(out_field.type, ts.TypeSpec) out_field_types = type_info.primitive_constituents(out_field.type).to_list() @@ -222,6 +221,8 @@ def _construct_itir_domain_arg( " caught in type deduction already." ) + domain_args = [] + domain_args_kind = [] for dim_i, dim in enumerate(out_dims): # an expression for the size of a dimension dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i)) @@ -247,11 +248,19 @@ def _construct_itir_domain_arg( args=[itir.AxisLiteral(value=dim.value), lower, upper], ) ) + domain_args_kind.append(dim.kind) if self.grid_type == GridType.CARTESIAN: domain_builtin = "cartesian_domain" elif self.grid_type == GridType.UNSTRUCTURED: domain_builtin = "unstructured_domain" + assert len(domain_args) == 2 + # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) + if domain_args_kind[0] == DimensionKind.VERTICAL: + assert domain_args_kind[1] == DimensionKind.HORIZONTAL + domain_args[0], domain_args[1] = domain_args[1], domain_args[0] + else: + assert domain_args_kind[1] == DimensionKind.VERTICAL else: raise AssertionError() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 49df10bcff..a35f084677 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -114,7 +114,7 @@ def test_neighbor_sum(unstructured_case, fop): edge_f = cases.allocate(unstructured_case, fop, "edge_f")() - local_dim_idx = edge_f.domain.dims.index(Edge) + local_dim_idx = edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( slice(None) if dim is not Edge else v2e_table for dim in edge_f.domain.dims ) @@ -130,7 +130,7 @@ def test_neighbor_sum(unstructured_case, fop): broadcasted_table = v2e_table[tuple(broadcast_slice)] ref = np.sum( edge_f.asnumpy()[adv_indexing], - axis=local_dim_idx + 1, + axis=local_dim_idx, initial=0, where=broadcasted_table != common.SKIP_VALUE, ) From 8669f330f4a27edef2a5ebadab1ca9825bbfeca1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 7 Feb 2024 10:40:56 +0100 Subject: [PATCH 09/50] fix bug --- src/gt4py/next/ffront/past_to_itir.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4feda54b5c..0b967e2736 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -254,13 +254,11 @@ def _construct_itir_domain_arg( domain_builtin = "cartesian_domain" elif self.grid_type == GridType.UNSTRUCTURED: domain_builtin = "unstructured_domain" - assert len(domain_args) == 2 # for no good reason, the domain arguments for unstructured need to be in order (horizontal, vertical) if domain_args_kind[0] == DimensionKind.VERTICAL: + assert len(domain_args) == 2 assert domain_args_kind[1] == DimensionKind.HORIZONTAL domain_args[0], domain_args[1] = domain_args[1], domain_args[0] - else: - assert domain_args_kind[1] == DimensionKind.VERTICAL else: raise AssertionError() From fab118516b339ebe6660d7f31c70c4df6dfb784f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 09:23:02 +0100 Subject: [PATCH 10/50] fix connectivity names --- .../embedded_tests/test_nd_array_field.py | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index ded3098f58..54dac6b6bf 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -665,60 +665,60 @@ def test_connectivity_field_inverse_image(): def test_connectivity_field_inverse_image_2d_domain(): V = Dimension("V") - E = Dimension("E") - E2V = Dimension("E2V") + C = Dimension("C") + C2V = Dimension("C2V") V_START, V_STOP = 0, 3 - E_START, E_STOP = 0, 3 - E2V_START, E2V_STOP = 0, 3 + C_START, C_STOP = 0, 3 + C2V_START, C2V_STOP = 0, 3 - e2v_conn = common._connectivity( + c2v_conn = common._connectivity( np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]), domain=common.domain( [ - common.named_range((E, (E_START, E_STOP))), - common.named_range((E2V, (E2V_START, E2V_STOP))), + common.named_range((C, (C_START, C_STOP))), + common.named_range((C2V, (C2V_START, C2V_STOP))), ] ), codomain=V, ) - # e2v_conn: - # ---E2V---- + # c2v_conn: + # ---C2V---- # |[[0 0 2] - # E [1 1 2] + # C [1 1 2] # | [2 2 2]] # Test contiguous and non-contiguous ranges. - # For the 'e2v_conn' defined above, the only valid range including 2 + # For the 'c2v_conn' defined above, the only valid range including 2 # is [0, 3). Otherwise, the inverse image would be non-contiguous. image_range = UnitRange(V_START, V_STOP) - result = e2v_conn.inverse_image(image_range) + result = c2v_conn.inverse_image(image_range) assert len(result) == 2 - assert result[0] == (E, UnitRange(E_START, E_STOP)) - assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP)) + assert result[0] == (C, UnitRange(C_START, C_STOP)) + assert result[1] == (C2V, UnitRange(C2V_START, C2V_STOP)) - result = e2v_conn.inverse_image(UnitRange(0, 2)) + result = c2v_conn.inverse_image(UnitRange(0, 2)) assert len(result) == 2 - assert result[0] == (E, UnitRange(0, 2)) - assert result[1] == (E2V, UnitRange(0, 2)) + assert result[0] == (C, UnitRange(0, 2)) + assert result[1] == (C2V, UnitRange(0, 2)) - result = e2v_conn.inverse_image(UnitRange(0, 1)) + result = c2v_conn.inverse_image(UnitRange(0, 1)) assert len(result) == 2 - assert result[0] == (E, UnitRange(0, 1)) - assert result[1] == (E2V, UnitRange(0, 2)) + assert result[0] == (C, UnitRange(0, 1)) + assert result[1] == (C2V, UnitRange(0, 2)) - result = e2v_conn.inverse_image(UnitRange(1, 2)) + result = c2v_conn.inverse_image(UnitRange(1, 2)) assert len(result) == 2 - assert result[0] == (E, UnitRange(1, 2)) - assert result[1] == (E2V, UnitRange(0, 2)) + assert result[0] == (C, UnitRange(1, 2)) + assert result[1] == (C2V, UnitRange(0, 2)) with pytest.raises(ValueError, match="generates non-contiguous dimensions"): - result = e2v_conn.inverse_image(UnitRange(1, 3)) + result = c2v_conn.inverse_image(UnitRange(1, 3)) with pytest.raises(ValueError, match="generates non-contiguous dimensions"): - result = e2v_conn.inverse_image(UnitRange(2, 3)) + result = c2v_conn.inverse_image(UnitRange(2, 3)) def test_connectivity_field_inverse_image_non_contiguous(): @@ -746,58 +746,58 @@ def test_connectivity_field_inverse_image_non_contiguous(): def test_connectivity_field_inverse_image_2d_domain_skip_values(): V = Dimension("V") - E = Dimension("E") - E2V = Dimension("E2V") + C = Dimension("C") + C2V = Dimension("C2V") V_START, V_STOP = 0, 3 - E_START, E_STOP = 0, 4 - E2V_START, E2V_STOP = 0, 4 + C_START, C_STOP = 0, 4 + C2V_START, C2V_STOP = 0, 4 - e2v_conn = common._connectivity( + c2v_conn = common._connectivity( np.asarray([[-1, 0, 2, -1], [1, 1, 2, 2], [2, 2, -1, -1], [-1, 2, -1, -1]]), domain=common.domain( [ - common.named_range((E, (E_START, E_STOP))), - common.named_range((E2V, (E2V_START, E2V_STOP))), + common.named_range((C, (C_START, C_STOP))), + common.named_range((C2V, (C2V_START, C2V_STOP))), ] ), codomain=V, ) - # e2v_conn: - # ---E2V--------- + # c2v_conn: + # ---C2V--------- # |[[-1 0 2 -1] - # E [ 1 1 2 2] + # C [ 1 1 2 2] # | [ 2 2 -1 -1] # | [-1 2 -1 -1]] image_range = UnitRange(V_START, V_STOP) - result = e2v_conn.inverse_image(image_range) + result = c2v_conn.inverse_image(image_range) assert len(result) == 2 - assert result[0] == (E, UnitRange(E_START, E_STOP)) - assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP)) + assert result[0] == (C, UnitRange(C_START, C_STOP)) + assert result[1] == (C2V, UnitRange(C2V_START, C2V_STOP)) - result = e2v_conn.inverse_image(UnitRange(0, 2)) + result = c2v_conn.inverse_image(UnitRange(0, 2)) assert len(result) == 2 - assert result[0] == (E, UnitRange(0, 2)) - assert result[1] == (E2V, UnitRange(0, 2)) + assert result[0] == (C, UnitRange(0, 2)) + assert result[1] == (C2V, UnitRange(0, 2)) - result = e2v_conn.inverse_image(UnitRange(0, 1)) + result = c2v_conn.inverse_image(UnitRange(0, 1)) assert len(result) == 2 - assert result[0] == (E, UnitRange(0, 1)) - assert result[1] == (E2V, UnitRange(1, 2)) + assert result[0] == (C, UnitRange(0, 1)) + assert result[1] == (C2V, UnitRange(1, 2)) - result = e2v_conn.inverse_image(UnitRange(1, 2)) + result = c2v_conn.inverse_image(UnitRange(1, 2)) assert len(result) == 2 - assert result[0] == (E, UnitRange(1, 2)) - assert result[1] == (E2V, UnitRange(0, 2)) + assert result[0] == (C, UnitRange(1, 2)) + assert result[1] == (C2V, UnitRange(0, 2)) with pytest.raises(ValueError, match="generates non-contiguous dimensions"): - result = e2v_conn.inverse_image(UnitRange(1, 3)) + result = c2v_conn.inverse_image(UnitRange(1, 3)) with pytest.raises(ValueError, match="generates non-contiguous dimensions"): - result = e2v_conn.inverse_image(UnitRange(2, 3)) + result = c2v_conn.inverse_image(UnitRange(2, 3)) @pytest.mark.parametrize( From d24f6a3c11ebc0a8f8ba7307823a1f5d938725d0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 09:24:40 +0100 Subject: [PATCH 11/50] explicit xp.newaxis --- src/gt4py/next/embedded/nd_array_field.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 0597761e13..e103b06171 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -504,6 +504,8 @@ def _make_reduction( def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: + xp = field.array_ns + if not axis.kind == common.DimensionKind.LOCAL: raise ValueError("Can only reduce local dimensions.") if axis not in field.domain.dims: @@ -522,17 +524,17 @@ def _builtin_op( new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else None + slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis for d in field.domain.dims ) - masked_array = field.array_ns.where( - field.array_ns.asarray(offset_definition.table[broadcast_slice]) != common.SKIP_VALUE, + masked_array = xp.where( + xp.asarray(offset_definition.table[broadcast_slice]) != common.SKIP_VALUE, field.ndarray, initial_value_op(field), ) return field.__class__.from_array( - getattr(field.array_ns, array_builtin_name)( + getattr(xp, array_builtin_name)( masked_array, axis=reduce_dim_index, ), From dcf17e2e18a331e98de2bfbb4ce1957956cbe47f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 09:45:43 +0100 Subject: [PATCH 12/50] wrap the mask hypercube --- src/gt4py/next/common.py | 6 ++++ src/gt4py/next/embedded/nd_array_field.py | 43 ++++++++++++++++------- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 90e76d671d..bc3500f071 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -746,6 +746,12 @@ def kind(self) -> ConnectivityKind: @abc.abstractmethod def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... + @property + def skip_value(self) -> core_defs.IntegralScalar: + # TODO(havogt): This is a preparation for the future, currently we assume the skip_value is + # globally defined to be `-1`. In the future we want to make this customizable in the connectivity. + return SKIP_VALUE + # Operators def __abs__(self) -> Never: raise TypeError("'ConnectivityField' does not support this operation.") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e103b06171..9a513410c2 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -393,13 +393,7 @@ def inverse_image( assert common.UnitRange.is_finite(image_range) - restricted_mask = (self._ndarray >= image_range.start) & ( - self._ndarray < image_range.stop - ) - - relative_ranges = _hypercube( - restricted_mask, xp, ignore_mask=self._ndarray == common.SKIP_VALUE - ) + relative_ranges = _hypercube(self._ndarray, image_range, xp, self.skip_value) if relative_ranges is None: raise ValueError("Restriction generates non-contiguous dimensions.") @@ -430,7 +424,30 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ def _hypercube( - select: core_defs.NDArrayObject, + index_array: core_defs.NDArrayObject, + image_range: common.UnitRange, + xp: ModuleType, + skip_value: Optional[core_defs.IntegralScalar] = None, +) -> Optional[list[common.UnitRange]]: + """ + Return the hypercube that contains all indices in `index_array` that are within `image_range`, or `None` if no such hypercube exists. + + If `skip_value` is given, the selected values are ignored. It returns the smallest hypercube. + A bigger hypercube could be constructed by adding lines that contain only `skip_value`s. + Example: + index_array = 0 1 -1 + 3 4 -1 + -1 -1 -1 + skip_value = -1 + would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. + """ + restricted_mask = (index_array >= image_range.start) & (index_array < image_range.stop) + ignore_mask = None if skip_value is None else index_array == skip_value + return _hypercube_from_mask(restricted_mask, xp, ignore_mask) + + +def _hypercube_from_mask( + select_mask: core_defs.NDArrayObject, xp: ModuleType, ignore_mask: Optional[core_defs.NDArrayObject] = None, ) -> Optional[list[common.UnitRange]]: @@ -440,9 +457,9 @@ def _hypercube( If `ignore_mask` is given, the selected values are ignored. It returns the smallest hypercube. A bigger hypercube could be constructed by adding lines from the ignore_mask. Example: - select = True True False - True True False - False False True + select = True True False + True True False + False False False ignore_mask = False False True False False True @@ -450,12 +467,12 @@ def _hypercube( would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. """ - nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select) + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask) slices = tuple( slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz ) - hcube = select[tuple(slices)] + hcube = select_mask[tuple(slices)] if ignore_mask is not None: hcube |= ignore_mask[tuple(slices)] if not xp.all(hcube): From 3804eb6a958d1e1ba18f92eef151b0437e8f5bad Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 09:55:12 +0100 Subject: [PATCH 13/50] prepare configurable skip_value --- src/gt4py/next/common.py | 12 +++++++----- src/gt4py/next/embedded/nd_array_field.py | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index bc3500f071..ed2263bbee 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -729,7 +729,7 @@ class ConnectivityKind(enum.Flag): @extended_runtime_checkable -# type: ignore[misc] # DimT should be covariant, but break in another place +# type: ignore[misc] # DimT should be covariant, but breaks in another place class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod @@ -747,10 +747,8 @@ def kind(self) -> ConnectivityKind: def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... @property - def skip_value(self) -> core_defs.IntegralScalar: - # TODO(havogt): This is a preparation for the future, currently we assume the skip_value is - # globally defined to be `-1`. In the future we want to make this customizable in the connectivity. - return SKIP_VALUE + @abc.abstractmethod + def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... # Operators def __abs__(self) -> Never: @@ -918,6 +916,10 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: def codomain(self) -> DimT: return self.dimension + @property + def skip_value(self) -> None: + return None + @functools.cached_property def kind(self) -> ConnectivityKind: return ConnectivityKind(0) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9a513410c2..b2712b8017 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -317,6 +317,7 @@ class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT + _skip_value: Optional[core_defs.IntegralScalar] @functools.cached_property def _cache(self) -> dict: @@ -331,6 +332,10 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig def codomain(self) -> common.DimT: return self._codomain + @property + def skip_value(self) -> Optional[core_defs.IntegralScalar]: + return self._skip_value + @functools.cached_property def kind(self) -> common.ConnectivityKind: kind = common.ConnectivityKind.MODIFY_STRUCTURE @@ -369,7 +374,12 @@ def from_array( # type: ignore[override] assert isinstance(codomain, common.Dimension) - return cls(domain, array, codomain) + return cls( + domain, + array, + codomain, + _skip_value=common.SKIP_VALUE, # TODO(havogt): make skip_value configurable + ) def inverse_image( self, image_range: common.UnitRange | common.NamedRange @@ -415,7 +425,7 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ xp = cls.array_ns new_domain, buffer_slice = self._slice(index) new_buffer = xp.asarray(self.ndarray[buffer_slice]) - restricted_connectivity = cls(new_domain, new_buffer, self.codomain) + restricted_connectivity = cls(new_domain, new_buffer, self.codomain, self.skip_value) self._cache[cache_key] = restricted_connectivity return restricted_connectivity From e4a6f39618adb0a4c070e02d60f2092ceb93506d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 12:05:19 +0100 Subject: [PATCH 14/50] fix test --- .../next_tests/unit_tests/embedded_tests/test_nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 54dac6b6bf..259367100e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -819,6 +819,6 @@ def test_hypercube(select, ignore_mask, expected): ignore_mask = np.asarray(ignore_mask) if ignore_mask is not None else None expected = [common.unit_range(e) for e in expected] if expected is not None else None - result = nd_array_field._hypercube(select, np, ignore_mask=ignore_mask) + result = nd_array_field._hypercube_from_mask(select, np, ignore_mask=ignore_mask) assert result == expected From 1f9ac8506ee582c6c6f4b502c0c73c864b16a96c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 13:56:23 +0100 Subject: [PATCH 15/50] skip value refactoring --- src/gt4py/next/common.py | 1 + src/gt4py/next/constructors.py | 6 +++- src/gt4py/next/embedded/nd_array_field.py | 32 +++---------------- src/gt4py/next/ffront/fbuiltins.py | 1 + .../embedded_tests/test_nd_array_field.py | 24 +++++++------- 5 files changed, 25 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ed2263bbee..64267eed99 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -841,6 +841,7 @@ def _connectivity( *, domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, + skip_value: Optional[core_defs.IntegralScalar] = None, ) -> ConnectivityField: raise NotImplementedError diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 8b41bf7cba..6628d56e24 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -308,6 +308,7 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, + skip_value: Optional[core_defs.IntegralScalar] = None, # copy=False, TODO ) -> common.ConnectivityField: """ @@ -330,6 +331,9 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ + assert ( + skip_value is None or skip_value == common.SKIP_VALUE + ) # TODO(havogt): not yet configurable if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: @@ -359,7 +363,7 @@ def as_connectivity( # TODO(havogt): consider adding MutableNDArrayObject buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common._connectivity( - buffer.ndarray, codomain=codomain, domain=actual_domain + buffer.ndarray, codomain=codomain, domain=actual_domain, skip_value=skip_value ) assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index b2712b8017..0fd7334c04 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -356,6 +356,7 @@ def from_array( # type: ignore[override] *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, + skip_value: Optional[core_defs.IntegralScalar] = None, ) -> NdArrayConnectivityField: domain = common.domain(domain) xp = cls.array_ns @@ -378,7 +379,7 @@ def from_array( # type: ignore[override] domain, array, codomain, - _skip_value=common.SKIP_VALUE, # TODO(havogt): make skip_value configurable + _skip_value=skip_value, ) def inverse_image( @@ -451,39 +452,16 @@ def _hypercube( skip_value = -1 would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. """ - restricted_mask = (index_array >= image_range.start) & (index_array < image_range.stop) - ignore_mask = None if skip_value is None else index_array == skip_value - return _hypercube_from_mask(restricted_mask, xp, ignore_mask) + select_mask = (index_array >= image_range.start) & (index_array < image_range.stop) - -def _hypercube_from_mask( - select_mask: core_defs.NDArrayObject, - xp: ModuleType, - ignore_mask: Optional[core_defs.NDArrayObject] = None, -) -> Optional[list[common.UnitRange]]: - """ - Return the hypercube that contains all True values and no False values, or `None` if no such hypercube exists. - - If `ignore_mask` is given, the selected values are ignored. It returns the smallest hypercube. - A bigger hypercube could be constructed by adding lines from the ignore_mask. - Example: - select = True True False - True True False - False False False - - ignore_mask = False False True - False False True - True True True - - would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. - """ nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask) slices = tuple( slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz ) hcube = select_mask[tuple(slices)] - if ignore_mask is not None: + if skip_value is not None: + ignore_mask = index_array == skip_value hcube |= ignore_mask[tuple(slices)] if not xp.all(hcube): return None diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index fdf900f76f..2fb58ad699 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -365,6 +365,7 @@ def as_connectivity_field(self): codomain=self.source, data=offset_definition.table, dtype=offset_definition.index_type, + skip_value=common.SKIP_VALUE if offset_definition.has_skip_values else None, ) else: raise NotImplementedError() diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 259367100e..2e10db2270 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -762,6 +762,7 @@ def test_connectivity_field_inverse_image_2d_domain_skip_values(): ] ), codomain=V, + skip_value=-1, ) # c2v_conn: @@ -801,24 +802,25 @@ def test_connectivity_field_inverse_image_2d_domain_skip_values(): @pytest.mark.parametrize( - "select, ignore_mask, expected", + "index_array, expected", [ - ([True, True, False], None, [(0, 2)]), - ([True, False, True], None, None), - ([True, False, True], [False, True, False], [(0, 3)]), - ([[False, False, False], [False, True, True]], None, [(1, 2), (1, 3)]), + ([0, 0, 1], [(0, 2)]), + ([0, 1, 0], None), + ([0, -1, 0], [(0, 3)]), + ([[1, 1, 1], [1, 0, 0]], [(1, 2), (1, 3)]), ( - [[False, True, False], [False, True, True]], - [[False, False, True], [False, False, False]], + [[1, 0, -1], [1, 0, 0]], [(0, 2), (1, 3)], ), ], ) -def test_hypercube(select, ignore_mask, expected): - select = np.asarray(select) - ignore_mask = np.asarray(ignore_mask) if ignore_mask is not None else None +def test_hypercube(index_array, expected): + index_array = np.asarray(index_array) + image_range = common.UnitRange(0, 1) + skip_value = -1 + expected = [common.unit_range(e) for e in expected] if expected is not None else None - result = nd_array_field._hypercube_from_mask(select, np, ignore_mask=ignore_mask) + result = nd_array_field._hypercube(index_array, image_range, np, skip_value) assert result == expected From 07cffa67829f42d63767df31c419c2fb24d182b6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 9 Feb 2024 20:12:43 +0100 Subject: [PATCH 16/50] fix skip_value check --- src/gt4py/next/embedded/nd_array_field.py | 9 ++++----- .../feature_tests/ffront_tests/ffront_test_utils.py | 12 ++++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 0fd7334c04..4678b3b5d5 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -160,17 +160,16 @@ def from_array( def remap( self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset ) -> NdArrayField: - # Current implementation relies on SKIP_VALUE == -1: - # if we assume the indexed array has at least one element, we wrap around without out of bounds - assert common.SKIP_VALUE == -1 - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField if not common.is_connectivity_field(connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert common.is_connectivity_field(connectivity) + # Current implementation relies on skip_value == -1: + # if we assume the indexed array has at least one element, we wrap around without out of bounds + assert connectivity.skip_value == -1 + # Compute the new domain dim = connectivity.codomain dim_idx = self.domain.dim_index(dim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index d8c4696073..8aaba85503 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -239,12 +239,12 @@ def skip_value_mesh() -> MeshDescriptor: v2e_arr = np.array( [ - [1, 8, 7, 0, -1], - [2, 8, 1, -1, -1], - [3, 9, 8, 2, -1], - [4, 10, 3, -1, -1], - [5, 11, 4, -1, -1], - [0, 6, 4, -1, -1], + [1, 8, 7, 0, common.SKIP_VALUE], + [2, 8, 1, common.SKIP_VALUE, common.SKIP_VALUE], + [3, 9, 8, 2, common.SKIP_VALUE], + [4, 10, 3, common.SKIP_VALUE, common.SKIP_VALUE], + [5, 11, 4, common.SKIP_VALUE, common.SKIP_VALUE], + [0, 6, 4, common.SKIP_VALUE, common.SKIP_VALUE], [6, 7, 9, 10, 11], ], dtype=gtx.IndexType, From 93bf8893699ba3ee9610d134640c9c5f5cea9cb1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 10 Feb 2024 09:32:10 +0100 Subject: [PATCH 17/50] fix assert --- src/gt4py/next/embedded/nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 4678b3b5d5..57f9d7783c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -168,7 +168,7 @@ def remap( # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, we wrap around without out of bounds - assert connectivity.skip_value == -1 + assert connectivity.skip_value is None or connectivity.skip_value == -1 # Compute the new domain dim = connectivity.codomain From ce6adde387672262610f5e6712576b9e2d991579 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Feb 2024 13:05:43 +0100 Subject: [PATCH 18/50] prototype --- src/gt4py/next/embedded/nd_array_field.py | 96 ++++++++++++++++--- src/gt4py/next/embedded/operators.py | 13 ++- src/gt4py/next/ffront/fbuiltins.py | 11 +++ .../ffront/foast_passes/type_deduction.py | 8 +- src/gt4py/next/ffront/foast_to_itir.py | 6 +- 5 files changed, 112 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 5e6e4a5bf5..64c30e6f28 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -44,8 +44,12 @@ def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: - first = fields[0] - assert isinstance(first, NdArrayField) + first = None + for f in fields: + if isinstance(f, NdArrayField): + first = f + break + assert first is not None xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) @@ -411,10 +415,7 @@ def inverse_image( if relative_ranges is None: raise ValueError("Restriction generates non-contiguous dimensions.") - new_dims = [ - common.named_range((d, rr + ar.start)) - for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges) - ] + new_dims = _relative_ranges_to_domain(relative_ranges, self.domain) self._cache[cache_key] = new_dims @@ -436,6 +437,14 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field: __getitem__ = restrict +def _relative_ranges_to_domain( + relative_ranges: tuple[common.UnitRange, ...], domain: common.Domain +) -> common.Domain: + return common.Domain( + dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)] + ) + + def _hypercube( index_array: core_defs.NDArrayObject, image_range: common.UnitRange, @@ -455,15 +464,24 @@ def _hypercube( would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. """ select_mask = (index_array >= image_range.start) & (index_array < image_range.stop) + ignore_mask = None if skip_value is None else index_array == skip_value + return _hypercube_from_mask(select_mask, xp, ignore_mask) - nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask) +def _hypercube_from_mask( + select_mask: core_defs.NDArrayObject, + xp: ModuleType, + ignore_mask: Optional[core_defs.NDArrayObject] = None, +) -> list[common.UnitRange]: + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(select_mask) slices = tuple( - slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) for dim_nnz_indices in nnz + slice(xp.min(dim_nnz_indices), xp.max(dim_nnz_indices) + 1) + if dim_nnz_indices.size > 0 + else slice(0, 0) # TODO test this code path + for dim_nnz_indices in nnz ) hcube = select_mask[tuple(slices)] - if skip_value is not None: - ignore_mask = index_array == skip_value + if ignore_mask is not None: hcube |= ignore_mask[tuple(slices)] if not xp.all(hcube): return None @@ -502,12 +520,62 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) +def _concat_where(m, t, f): + xp = t.__class__.array_ns + if m.domain.ndim != 1: + raise ValueError("Can only concatenate fields with a 1-dimensional mask.") + mask_dim = m.domain.dims[0] + mask_true_domain = _relative_ranges_to_domain(_hypercube_from_mask(m.ndarray, xp), m.domain) + t_domain = t.domain if common.is_field(t) else common.Domain() + true_domain = embedded_common.intersect_domains(t_domain, mask_true_domain) + mask_false_domain = _relative_ranges_to_domain(_hypercube_from_mask(~m.ndarray, xp), m.domain) + f_domain = f.domain if common.is_field(f) else common.Domain() + false_domain = embedded_common.intersect_domains(f_domain, mask_false_domain) + + if common.is_field(t): + true_slices = _get_slices_from_domain_slice(t.domain, true_domain) + true_transformed = xp.asarray(t.ndarray[true_slices]) + else: + true_transformed = t + if common.is_field(f): + false_slices = _get_slices_from_domain_slice(f.domain, false_domain) + false_transformed = xp.asarray(f.ndarray[false_slices]) + else: + false_transformed = f + + assert true_domain.dims == false_domain.dims # TODO implement broadcasting + mask_index = true_domain.dim_index(mask_dim) + if true_domain[mask_dim][1].stop == false_domain[mask_dim][1].start: + result = xp.concatenate((true_transformed, false_transformed), axis=mask_index) + result_domain = true_domain.replace( + mask_dim, + ( + mask_dim, + common.UnitRange(true_domain[mask_dim][1].start, false_domain[mask_dim][1].stop), + ), + ) + elif false_domain[mask_dim][1].stop == true_domain[mask_dim][1].start: + result = xp.concatenate((false_transformed, true_transformed), axis=mask_index) + result_domain = true_domain.replace( + mask_dim, + ( + mask_dim, + common.UnitRange(false_domain[mask_dim][1].start, true_domain[mask_dim][1].stop), + ), + ) + # concatenate false then true + else: + raise ValueError("Mask does not split the domain into two non-overlapping parts.") + + return t.__class__.from_array(result, domain=result_domain) + + +NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) + + def _make_reduction( builtin_name: str, array_builtin_name: str, initial_value_op: Callable -) -> Callable[ - ..., - NdArrayField[common.DimsT, core_defs.ScalarT], -]: +) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index fc3ccda335..62b683aac6 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -79,6 +79,14 @@ def scan_loop(hpos): return res +def _get_out_domain( + out: common.MutableField | tuple[common.MutableField | tuple, ...] +) -> common.Domain: + return embedded_common.intersect_domains( + *[f.domain for f in utils.flatten_nested_tuple((out,))] + ) + + def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): if "out" in kwargs: # called from program or direct field_operator as program @@ -98,10 +106,7 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): domain = kwargs.pop("domain", None) - flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) - assert all(f.domain == flattened_out[0].domain for f in flattened_out) - - out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + out_domain = common.domain(domain) if domain is not None else _get_out_domain(out) new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 2fb58ad699..585ec718ac 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -202,6 +202,16 @@ def where( raise NotImplementedError() +@WhereBuiltinFunction +def concat_where( + mask: common.Field, + true_field: common.Field | core_defs.ScalarT | Tuple, + false_field: common.Field | core_defs.ScalarT | Tuple, + /, +) -> common.Field | Tuple: + raise NotImplementedError() + + @BuiltInFunction def astype( value: common.Field | core_defs.ScalarT | Tuple, @@ -295,6 +305,7 @@ def impl( "min_over", "broadcast", "where", + "concat_where", "astype", "as_offset", ] + MATH_BUILTIN_NAMES diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 64fea7935c..64bdc38e73 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -438,9 +438,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) - symtable[sym].type = new_node.annex.propagated_symbols[sym].type = ( - new_true_branch.annex.symtable[sym].type - ) + symtable[sym].type = new_node.annex.propagated_symbols[ + sym + ].type = new_true_branch.annex.symtable[sym].type return new_node @@ -931,6 +931,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: location=node.location, ) + _visit_concat_where = _visit_where + def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c0e618a42d..d4fbba8079 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -326,6 +326,9 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args) + def _visit_concat_where(self, node: foast.Call, **kwargs) -> itir.FunCall: + return self._map("if_", *node.args) + def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: return self.visit(node.args[0], **kwargs) @@ -430,4 +433,5 @@ def _process_elements( return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) -class FieldOperatorLoweringError(Exception): ... +class FieldOperatorLoweringError(Exception): + ... From 64626106ad89772a10f300e9e8eae1bd9eb4d7b6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Feb 2024 20:56:57 +0100 Subject: [PATCH 19/50] fix bug in reverse sub and div --- src/gt4py/next/embedded/nd_array_field.py | 24 ++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 64c30e6f28..17e576ff8c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -42,7 +42,9 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: +def _make_builtin( + builtin_name: str, array_builtin_name: str, reversed=False +) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: first = None for f in fields: @@ -63,7 +65,11 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if f.domain == domain_intersection: transformed.append(xp.asarray(f.ndarray)) else: - f_broadcasted = _broadcast(f, domain_intersection.dims) + f_broadcasted = ( + _broadcast(f, domain_intersection.dims) + if f.domain.dims != domain_intersection.dims + else f + ) f_slices = _get_slices_from_domain_slice( f_broadcasted.domain, domain_intersection ) @@ -71,7 +77,8 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: else: assert core_defs.is_scalar_type(f) transformed.append(f) - + if reversed: + transformed = transformed[::-1] new_data = op(*transformed) return first.__class__.from_array(new_data, domain=domain_intersection) @@ -252,13 +259,16 @@ def __setitem__( __pos__ = _make_builtin("pos", "positive") - __sub__ = __rsub__ = _make_builtin("sub", "subtract") + __sub__ = _make_builtin("sub", "subtract") + __rsub__ = _make_builtin("sub", "subtract", reversed=True) __mul__ = __rmul__ = _make_builtin("mul", "multiply") - __truediv__ = __rtruediv__ = _make_builtin("div", "divide") + __truediv__ = _make_builtin("div", "divide") + __rtruediv__ = _make_builtin("div", "divide", reversed=True) - __floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide") + __floordiv__ = _make_builtin("floordiv", "floor_divide") + __rfloordiv__ = _make_builtin("floordiv", "floor_divide", reversed=True) __pow__ = _make_builtin("pow", "power") @@ -694,7 +704,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] domain_slice.append(slice(None)) named_ranges.append((dim, field.domain[pos][1])) else: - domain_slice.append(np.newaxis) + domain_slice.append(field.__class__.array_ns.newaxis) named_ranges.append((dim, common.UnitRange.infinite())) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) From 5c924910391e03b0c271b7a1a8246ca37dadd3b1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 14 Feb 2024 23:12:09 +0100 Subject: [PATCH 20/50] alternative concat_where that deals with multiple ranges --- src/gt4py/next/embedded/nd_array_field.py | 152 ++++++++++++++++------ 1 file changed, 114 insertions(+), 38 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 17e576ff8c..636acb8705 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -530,56 +530,132 @@ def _hypercube_from_mask( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) +def _mask_ranges(mask: core_defs.NDArrayObject) -> list[tuple[bool, common.UnitRange]]: + # TODO: does it make sense to upgrade this naive algorithm to numpy? + assert mask.ndim == 1 + cur = mask[0] + ind = 0 + res = [] + for i in range(1, mask.shape[0]): + if mask[i] != cur: + res.append((cur, common.UnitRange(ind, i))) + cur = mask[i] + ind = i + res.append((cur, common.UnitRange(ind, mask.shape[0]))) + return res + + +def _trim(lst: list[tuple[bool, common.Domain]]) -> list[tuple[bool, common.Domain]]: + res = [] + found = False + for v, d in lst: + if d != common.Domain(): + res.append((v, d)) + found = True + else: + if found: + raise ValueError("Out of bounds.") + return res + + def _concat_where(m, t, f): xp = t.__class__.array_ns if m.domain.ndim != 1: raise ValueError("Can only concatenate fields with a 1-dimensional mask.") mask_dim = m.domain.dims[0] - mask_true_domain = _relative_ranges_to_domain(_hypercube_from_mask(m.ndarray, xp), m.domain) + + relative_mask_ranges = _mask_ranges(m.ndarray) + mask_domains = [ + (v, _relative_ranges_to_domain((mr,), m.domain)) for v, mr in relative_mask_ranges + ] t_domain = t.domain if common.is_field(t) else common.Domain() - true_domain = embedded_common.intersect_domains(t_domain, mask_true_domain) - mask_false_domain = _relative_ranges_to_domain(_hypercube_from_mask(~m.ndarray, xp), m.domain) f_domain = f.domain if common.is_field(f) else common.Domain() - false_domain = embedded_common.intersect_domains(f_domain, mask_false_domain) - - if common.is_field(t): - true_slices = _get_slices_from_domain_slice(t.domain, true_domain) - true_transformed = xp.asarray(t.ndarray[true_slices]) - else: - true_transformed = t - if common.is_field(f): - false_slices = _get_slices_from_domain_slice(f.domain, false_domain) - false_transformed = xp.asarray(f.ndarray[false_slices]) - else: - false_transformed = f - - assert true_domain.dims == false_domain.dims # TODO implement broadcasting - mask_index = true_domain.dim_index(mask_dim) - if true_domain[mask_dim][1].stop == false_domain[mask_dim][1].start: - result = xp.concatenate((true_transformed, false_transformed), axis=mask_index) - result_domain = true_domain.replace( - mask_dim, - ( - mask_dim, - common.UnitRange(true_domain[mask_dim][1].start, false_domain[mask_dim][1].stop), - ), - ) - elif false_domain[mask_dim][1].stop == true_domain[mask_dim][1].start: - result = xp.concatenate((false_transformed, true_transformed), axis=mask_index) - result_domain = true_domain.replace( + intersected_domains = [ + (v, embedded_common.intersect_domains(t_domain if v else f_domain, d)) + for v, d in mask_domains + ] + + assert t.domain.dims == f.domain.dims # TODO implement broadcasting + + transformed = [] + for v, d in _trim(intersected_domains): + if v: + if common.is_field(t): + slices = _get_slices_from_domain_slice(t.domain, d) + transformed.append(xp.asarray(t.ndarray[slices])) + else: + transformed.append(t) + else: + if common.is_field(f): + slices = _get_slices_from_domain_slice(f.domain, d) + transformed.append(xp.asarray(f.ndarray[slices])) + else: + transformed.append(f) + mask_index = t_domain.dim_index(mask_dim) + result = xp.concatenate(transformed, axis=mask_index) + result_domain = t_domain.replace( + mask_dim, + ( mask_dim, - ( - mask_dim, - common.UnitRange(false_domain[mask_dim][1].start, true_domain[mask_dim][1].stop), + common.UnitRange( + intersected_domains[0][1][mask_index][1].start, + intersected_domains[-1][1][mask_index][1].stop, ), - ) - # concatenate false then true - else: - raise ValueError("Mask does not split the domain into two non-overlapping parts.") - + ), + ) return t.__class__.from_array(result, domain=result_domain) +# def _concat_where(m, t, f): +# xp = t.__class__.array_ns +# if m.domain.ndim != 1: +# raise ValueError("Can only concatenate fields with a 1-dimensional mask.") +# mask_dim = m.domain.dims[0] +# mask_true_domain = _relative_ranges_to_domain(_hypercube_from_mask(m.ndarray, xp), m.domain) +# t_domain = t.domain if common.is_field(t) else common.Domain() +# true_domain = embedded_common.intersect_domains(t_domain, mask_true_domain) +# mask_false_domain = _relative_ranges_to_domain(_hypercube_from_mask(~m.ndarray, xp), m.domain) +# f_domain = f.domain if common.is_field(f) else common.Domain() +# false_domain = embedded_common.intersect_domains(f_domain, mask_false_domain) + +# if common.is_field(t): +# true_slices = _get_slices_from_domain_slice(t.domain, true_domain) +# true_transformed = xp.asarray(t.ndarray[true_slices]) +# else: +# true_transformed = t +# if common.is_field(f): +# false_slices = _get_slices_from_domain_slice(f.domain, false_domain) +# false_transformed = xp.asarray(f.ndarray[false_slices]) +# else: +# false_transformed = f + +# assert true_domain.dims == false_domain.dims # TODO implement broadcasting +# mask_index = true_domain.dim_index(mask_dim) +# if true_domain[mask_dim][1].stop == false_domain[mask_dim][1].start: +# result = xp.concatenate((true_transformed, false_transformed), axis=mask_index) +# result_domain = true_domain.replace( +# mask_dim, +# ( +# mask_dim, +# common.UnitRange(true_domain[mask_dim][1].start, false_domain[mask_dim][1].stop), +# ), +# ) +# elif false_domain[mask_dim][1].stop == true_domain[mask_dim][1].start: +# result = xp.concatenate((false_transformed, true_transformed), axis=mask_index) +# result_domain = true_domain.replace( +# mask_dim, +# ( +# mask_dim, +# common.UnitRange(false_domain[mask_dim][1].start, true_domain[mask_dim][1].stop), +# ), +# ) +# # concatenate false then true +# else: +# raise ValueError("Mask does not split the domain into two non-overlapping parts.") + +# return t.__class__.from_array(result, domain=result_domain) + + NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) From 75d1b03da71f846e36805c83dc5efe426985c1b2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 22 Feb 2024 10:41:14 +0100 Subject: [PATCH 21/50] SKIP_VALUE -> _DEFAULT_SKIP_VALUE --- src/gt4py/next/common.py | 2 +- src/gt4py/next/constructors.py | 2 +- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 4 +++- src/gt4py/next/iterator/atlas_utils.py | 2 +- src/gt4py/next/iterator/embedded.py | 4 ++-- .../runners/dace_iterator/itir_to_tasklet.py | 2 +- .../feature_tests/ffront_tests/ffront_test_utils.py | 12 ++++++------ .../feature_tests/ffront_tests/test_execution.py | 6 +++--- .../ffront_tests/test_external_local_field.py | 2 +- .../ffront_tests/test_gt4py_builtins.py | 12 ++++++------ 11 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 28ba3ab9d6..f4e35b5533 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -1092,4 +1092,4 @@ def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Call #: Numeric value used to represent missing values in connectivities. #: Equivalent to the `_FillValue` attribute in the UGRID Conventions #: (see: http://ugrid-conventions.github.io/ugrid-conventions/). -SKIP_VALUE: Final[int] = -1 +_DEFAULT_SKIP_VALUE: Final[int] = -1 diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 6628d56e24..f8e7b9bff8 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -332,7 +332,7 @@ def as_connectivity( ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ assert ( - skip_value is None or skip_value == common.SKIP_VALUE + skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 5e6e4a5bf5..1bdb7161ec 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -535,7 +535,7 @@ def _builtin_op( for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common.SKIP_VALUE, + xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 2fb58ad699..34251a36b1 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -365,7 +365,9 @@ def as_connectivity_field(self): codomain=self.source, data=offset_definition.table, dtype=offset_definition.index_type, - skip_value=common.SKIP_VALUE if offset_definition.has_skip_values else None, + skip_value=( + common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None + ), ) else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/atlas_utils.py b/src/gt4py/next/iterator/atlas_utils.py index 7b9b60fd75..a23a4b5148 100644 --- a/src/gt4py/next/iterator/atlas_utils.py +++ b/src/gt4py/next/iterator/atlas_utils.py @@ -31,7 +31,7 @@ def __getitem__(self, indices): if neigh_index < self.atlas_connectivity.cols(primary_index): return self.atlas_connectivity[primary_index, neigh_index] else: - return common.SKIP_VALUE + return common._DEFAULT_SKIP_VALUE else: if neigh_index < 2: return self.atlas_connectivity[primary_index, neigh_index] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a45b81a773..80fc539283 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -532,7 +532,7 @@ def execute_shift( assert common.is_int_index(cur_index) if offset_implementation.mapped_index(cur_index, index) in [ None, - common.SKIP_VALUE, + common._DEFAULT_SKIP_VALUE, ]: return None @@ -559,7 +559,7 @@ def execute_shift( assert common.is_int_index(cur_index) if offset_implementation.mapped_index(cur_index, index) in [ None, - common.SKIP_VALUE, + common._DEFAULT_SKIP_VALUE, ]: return None else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 3a33ee1e35..cf6d7ab047 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -22,7 +22,7 @@ import gt4py.eve.codegen from gt4py import eve from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing -from gt4py.next.common import SKIP_VALUE as neighbor_skip_value +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 8aaba85503..aca601d74e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -239,12 +239,12 @@ def skip_value_mesh() -> MeshDescriptor: v2e_arr = np.array( [ - [1, 8, 7, 0, common.SKIP_VALUE], - [2, 8, 1, common.SKIP_VALUE, common.SKIP_VALUE], - [3, 9, 8, 2, common.SKIP_VALUE], - [4, 10, 3, common.SKIP_VALUE, common.SKIP_VALUE], - [5, 11, 4, common.SKIP_VALUE, common.SKIP_VALUE], - [0, 6, 4, common.SKIP_VALUE, common.SKIP_VALUE], + [1, 8, 7, 0, common._DEFAULT_SKIP_VALUE], + [2, 8, 1, common._DEFAULT_SKIP_VALUE, common._DEFAULT_SKIP_VALUE], + [3, 9, 8, 2, common._DEFAULT_SKIP_VALUE], + [4, 10, 3, common._DEFAULT_SKIP_VALUE, common._DEFAULT_SKIP_VALUE], + [5, 11, 4, common._DEFAULT_SKIP_VALUE, common._DEFAULT_SKIP_VALUE], + [0, 6, 4, common._DEFAULT_SKIP_VALUE, common._DEFAULT_SKIP_VALUE], [6, 7, 9, 10, 11], ], dtype=gtx.IndexType, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index e499f83f86..e370c87e63 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -530,7 +530,7 @@ def testee(a: cases.VField) -> cases.VField: initial=0, )[unstructured_case.offset_provider["V2E"].table], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -595,7 +595,7 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, axis=1, initial=0, - where=v2e.table != common.SKIP_VALUE, + where=v2e.table != common._DEFAULT_SKIP_VALUE, )[unstructured_case.offset_provider["E2V"].table[:, 0]], ) @@ -735,7 +735,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: b[v2e_table], axis=1, initial=0, - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ) ), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 569d7b5631..16d40384d4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -54,7 +54,7 @@ def testee( v2e_table, axis=1, initial=0, - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 880f84274f..583246f50f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -63,7 +63,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp.asnumpy()[v2e_table], axis=1, initial=np.min(inp.asnumpy()), - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ) cases.verify(unstructured_case, testee, inp, ref=ref, out=out) @@ -83,7 +83,7 @@ def minover(edge_f: cases.EField) -> cases.VField: edge_f[v2e_table], axis=1, initial=np.max(edge_f), - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ), ) @@ -134,7 +134,7 @@ def test_neighbor_sum(unstructured_case, fop): edge_f.asnumpy()[adv_indexing], axis=local_dim_idx, initial=0, - where=broadcasted_table != common.SKIP_VALUE, + where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( unstructured_case, @@ -177,7 +177,7 @@ def fencil(edge_f: EKField, out: VKField): field.asnumpy()[:, 1][v2e_table], axis=1, initial=0, - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ).reshape(out.shape), offset_provider=unstructured_case.offset_provider | {"Koff": KDim}, ) @@ -205,7 +205,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): -edge_f[v2e_table] ** 2 * 2, axis=1, initial=0, - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ), ) @@ -224,7 +224,7 @@ def testee(flux: cases.EField) -> cases.VField: flux[v2e_table] * 2, axis=1, initial=0, - where=v2e_table != common.SKIP_VALUE, + where=v2e_table != common._DEFAULT_SKIP_VALUE, ), ) From cbfd82405511d5e1071c273c01a505ad24d2ec8c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 16:44:35 +0100 Subject: [PATCH 22/50] format --- src/gt4py/next/embedded/nd_array_field.py | 19 +++++++++++-------- .../ffront/foast_passes/type_deduction.py | 6 +++--- src/gt4py/next/ffront/foast_to_itir.py | 3 +-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e57adad38c..0042273b8e 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -43,7 +43,7 @@ def _make_builtin( - builtin_name: str, array_builtin_name: str, reversed=False + builtin_name: str, array_builtin_name: str, reverse=False ) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: first = None @@ -77,7 +77,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: else: assert core_defs.is_scalar_type(f) transformed.append(f) - if reversed: + if reverse: transformed = transformed[::-1] new_data = op(*transformed) return first.__class__.from_array(new_data, domain=domain_intersection) @@ -260,15 +260,15 @@ def __setitem__( __pos__ = _make_builtin("pos", "positive") __sub__ = _make_builtin("sub", "subtract") - __rsub__ = _make_builtin("sub", "subtract", reversed=True) + __rsub__ = _make_builtin("sub", "subtract", reverse=True) __mul__ = __rmul__ = _make_builtin("mul", "multiply") __truediv__ = _make_builtin("div", "divide") - __rtruediv__ = _make_builtin("div", "divide", reversed=True) + __rtruediv__ = _make_builtin("div", "divide", reverse=True) __floordiv__ = _make_builtin("floordiv", "floor_divide") - __rfloordiv__ = _make_builtin("floordiv", "floor_divide", reversed=True) + __rfloordiv__ = _make_builtin("floordiv", "floor_divide", reverse=True) __pow__ = _make_builtin("pow", "power") @@ -448,7 +448,7 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field: def _relative_ranges_to_domain( - relative_ranges: tuple[common.UnitRange, ...], domain: common.Domain + relative_ranges: Sequence[common.UnitRange], domain: common.Domain ) -> common.Domain: return common.Domain( dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)] @@ -523,7 +523,10 @@ def _hypercube( def _make_reduction( builtin_name: str, array_builtin_name: str, initial_value_op: Callable -) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: +) -> Callable[ + ..., + NdArrayField[common.DimsT, core_defs.ScalarT], +]: def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: @@ -642,7 +645,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] domain_slice.append(slice(None)) named_ranges.append((dim, field.domain[pos][1])) else: - domain_slice.append(field.__class__.array_ns.newaxis) + domain_slice.append(None) # np.newaxis named_ranges.append((dim, common.UnitRange.infinite())) return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 64bdc38e73..6af1570fc9 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -438,9 +438,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) - symtable[sym].type = new_node.annex.propagated_symbols[ - sym - ].type = new_true_branch.annex.symtable[sym].type + symtable[sym].type = new_node.annex.propagated_symbols[sym].type = ( + new_true_branch.annex.symtable[sym].type + ) return new_node diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index bad68890df..dfaefb7211 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -443,5 +443,4 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) -class FieldOperatorLoweringError(Exception): - ... +class FieldOperatorLoweringError(Exception): ... From 971dc448b638ef7cbdddbbc1fc0d2ae7e220dda1 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 17:08:21 +0100 Subject: [PATCH 23/50] add tests for scalar binary with field --- src/gt4py/next/embedded/nd_array_field.py | 33 ++++++++++++++----- .../embedded_tests/test_nd_array_field.py | 21 ++++++++---- 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1bdb7161ec..9e7d227019 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -42,10 +42,16 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: +def _make_builtin( + builtin_name: str, array_builtin_name: str, reverse=False +) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: - first = fields[0] - assert isinstance(first, NdArrayField) + first = None + for f in fields: + if isinstance(f, NdArrayField): + first = f + break + assert first is not None xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) @@ -59,7 +65,11 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if f.domain == domain_intersection: transformed.append(xp.asarray(f.ndarray)) else: - f_broadcasted = _broadcast(f, domain_intersection.dims) + f_broadcasted = ( + _broadcast(f, domain_intersection.dims) + if f.domain.dims != domain_intersection.dims + else f + ) f_slices = _get_slices_from_domain_slice( f_broadcasted.domain, domain_intersection ) @@ -67,7 +77,8 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: else: assert core_defs.is_scalar_type(f) transformed.append(f) - + if reverse: + transformed = transformed[::-1] new_data = op(*transformed) return first.__class__.from_array(new_data, domain=domain_intersection) @@ -248,17 +259,21 @@ def __setitem__( __pos__ = _make_builtin("pos", "positive") - __sub__ = __rsub__ = _make_builtin("sub", "subtract") + __sub__ = _make_builtin("sub", "subtract") + __rsub__ = _make_builtin("sub", "subtract", reverse=True) __mul__ = __rmul__ = _make_builtin("mul", "multiply") - __truediv__ = __rtruediv__ = _make_builtin("div", "divide") + __truediv__ = _make_builtin("div", "divide") + __rtruediv__ = _make_builtin("div", "divide", reverse=True) - __floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide") + __floordiv__ = _make_builtin("floordiv", "floor_divide") + __rfloordiv__ = _make_builtin("floordiv", "floor_divide", reverse=True) __pow__ = _make_builtin("pow", "power") - __mod__ = __rmod__ = _make_builtin("mod", "mod") + __mod__ = _make_builtin("mod", "mod") + __rmod__ = _make_builtin("mod", "mod", reverse=True) __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore # mypy wants return `bool` diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 66189cf9eb..f9131d4501 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -156,14 +156,21 @@ def test_where_builtin_with_tuple(nd_array_implementation): assert np.allclose(result[1].ndarray, expected1) -def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): - inp_a = [-1.0, 4.2, 42] - inp_b = [2.0, 3.0, -3.0] - inputs = [inp_a, inp_b] - - expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) +@pytest.mark.parametrize("lhs, rhs", [([-1.0, 4.2, 42], [2.0, 3.0, -3.0]), (1.0, [2.0, 3.0, -3.0])]) +def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation, lhs, rhs): + inputs = [lhs, rhs] + + expected = binary_arithmetic_op( + *[ + (np.asarray(inp, dtype=np.float32) if isinstance(inp, list) else np.float32(inp)) + for inp in inputs + ] + ) - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] + field_inputs = [ + (_make_field(inp, nd_array_implementation) if isinstance(inp, list) else np.float32(inp)) + for inp in inputs + ] result = binary_arithmetic_op(*field_inputs) From ceb1a099eb219ee6537dcf10b1fff6799cfed87c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 17:23:45 +0100 Subject: [PATCH 24/50] cleanup tests --- .../embedded_tests/test_nd_array_field.py | 84 ++++++++++++------- 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index f9131d4501..5d7ec3574e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -19,6 +19,7 @@ import numpy as np import pytest +from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field @@ -69,9 +70,14 @@ def unary_logical_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None): +def _make_field_or_scalar( + lst: Iterable | core_defs.Scalar, nd_array_implementation, *, domain=None, dtype=None +): + """Creates a field from an Iterable or returns a scalar.""" if not dtype: - dtype = nd_array_implementation.float32 + dtype = np.float32 + if isinstance(lst, core_defs.SCALAR_TYPES): + return dtype(lst) buffer = nd_array_implementation.asarray(lst, dtype=dtype) if domain is None: domain = tuple( @@ -83,6 +89,18 @@ def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=No ) +def _np_asarray_or_scalar(value: Iterable | core_defs.Scalar, dtype=None): + """Creates a numpy array from an Iterable or returns a scalar.""" + if not dtype: + dtype = np.float32 + + return ( + dtype(value) + if isinstance(value, core_defs.SCALAR_TYPES) + else np.asarray(value, dtype=dtype) + ) + + @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementation): if builtin_name == "gamma": @@ -94,7 +112,7 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati expected = ref_impl(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) - field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] + field_inputs = [_make_field_or_scalar(inp, nd_array_implementation) for inp in inputs] builtin = getattr(fbuiltins, builtin_name) result = builtin(*field_inputs) @@ -107,7 +125,9 @@ def test_where_builtin(nd_array_implementation): true_ = np.asarray([1.0, 2.0], dtype=np.float32) false_ = np.asarray([3.0, 4.0], dtype=np.float32) - field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]] + field_inputs = [ + _make_field_or_scalar(inp, nd_array_implementation) for inp in [cond, true_, false_] + ] expected = np.where(cond, true_, false_) result = fbuiltins.where(*field_inputs) @@ -147,9 +167,13 @@ def test_where_builtin_with_tuple(nd_array_implementation): expected0 = np.where(cond, true0, false0) expected1 = np.where(cond, true1, false1) - cond_field = _make_field(cond, nd_array_implementation, dtype=bool) - field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1]) - field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1]) + cond_field = _make_field_or_scalar(cond, nd_array_implementation, dtype=bool) + field_true = tuple( + _make_field_or_scalar(inp, nd_array_implementation) for inp in [true0, true1] + ) + field_false = tuple( + _make_field_or_scalar(inp, nd_array_implementation) for inp in [false0, false1] + ) result = fbuiltins.where(cond_field, field_true, field_false) assert np.allclose(result[0].ndarray, expected0) @@ -160,31 +184,31 @@ def test_where_builtin_with_tuple(nd_array_implementation): def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation, lhs, rhs): inputs = [lhs, rhs] - expected = binary_arithmetic_op( - *[ - (np.asarray(inp, dtype=np.float32) if isinstance(inp, list) else np.float32(inp)) - for inp in inputs - ] - ) + expected = binary_arithmetic_op(*[_np_asarray_or_scalar(inp) for inp in inputs]) - field_inputs = [ - (_make_field(inp, nd_array_implementation) if isinstance(inp, list) else np.float32(inp)) - for inp in inputs - ] + field_inputs = [_make_field_or_scalar(inp, nd_array_implementation) for inp in inputs] result = binary_arithmetic_op(*field_inputs) assert np.allclose(result.ndarray, expected) -def test_binary_logical_ops(binary_logical_op, nd_array_implementation): - inp_a = [True, True, False, False] - inp_b = [True, False, True, False] - inputs = [inp_a, inp_b] +@pytest.mark.parametrize( + "lhs, rhs", + [ + ([True, True, False, False], [True, False, True, False]), + (True, [True, False]), + (False, [True, False]), + ], +) +def test_binary_logical_ops(binary_logical_op, nd_array_implementation, lhs, rhs): + inputs = [lhs, rhs] - expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) + expected = binary_logical_op(*[_np_asarray_or_scalar(inp, dtype=bool) for inp in inputs]) - field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + field_inputs = [ + _make_field_or_scalar(inp, nd_array_implementation, dtype=bool) for inp in inputs + ] result = binary_logical_op(*field_inputs) @@ -199,7 +223,7 @@ def test_unary_logical_ops(unary_logical_op, nd_array_implementation): expected = unary_logical_op(np.asarray(inp)) - field_input = _make_field(inp, nd_array_implementation, dtype=bool) + field_input = _make_field_or_scalar(inp, nd_array_implementation, dtype=bool) result = unary_logical_op(field_input) @@ -211,7 +235,7 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) - field_input = _make_field(inp, nd_array_implementation) + field_input = _make_field_or_scalar(inp, nd_array_implementation) result = unary_arithmetic_op(field_input) @@ -262,8 +286,8 @@ def test_mixed_fields(product_nd_array_implementation): expected = np.asarray(inp_a) + np.asarray(inp_b) - field_inp_a = _make_field(inp_a, first_impl) - field_inp_b = _make_field(inp_b, second_impl) + field_inp_a = _make_field_or_scalar(inp_a, first_impl) + field_inp_b = _make_field_or_scalar(inp_b, second_impl) result = field_inp_a + field_inp_b assert np.allclose(result.ndarray, expected) @@ -280,9 +304,9 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: expected = np.asarray(inp_a) * np.asarray(inp_b) + np.asarray(inp_c) - field_inp_a = _make_field(inp_a, np) - field_inp_b = _make_field(inp_b, np) - field_inp_c = _make_field(inp_c, np) + field_inp_a = _make_field_or_scalar(inp_a, np) + field_inp_b = _make_field_or_scalar(inp_b, np) + field_inp_c = _make_field_or_scalar(inp_c, np) result = fma(field_inp_a, field_inp_b, field_inp_c) assert np.allclose(result.ndarray, expected) From ddf6667d318f244aee6c171432dcd76b6e803a6e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 17:31:07 +0100 Subject: [PATCH 25/50] cleanup --- src/gt4py/next/embedded/nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9e7d227019..b00f3f25d2 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -78,7 +78,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: assert core_defs.is_scalar_type(f) transformed.append(f) if reverse: - transformed = transformed[::-1] + transformed.reverse() new_data = op(*transformed) return first.__class__.from_array(new_data, domain=domain_intersection) From 864ddd36c94d77e9bb484a5aef79d308fac537a6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 17:48:00 +0100 Subject: [PATCH 26/50] add concat_where for embedded --- src/gt4py/next/embedded/nd_array_field.py | 81 +++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1b0a7148b0..5f72470a6b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -522,6 +522,87 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) +def _mask_ranges( + mask: core_defs.NDArrayObject, +) -> list[tuple[bool, common.UnitRange]]: + # TODO: does it make sense to upgrade this naive algorithm to numpy? + assert mask.ndim == 1 + cur = bool(mask[0].item()) + ind = 0 + res = [] + for i in range(1, mask.shape[0]): + if (mask_i := bool(mask[i].item)) != cur: + res.append((cur, common.UnitRange(ind, i))) + cur = mask_i + ind = i + res.append((cur, common.UnitRange(ind, mask.shape[0]))) + return res + + +def _trim(lst: list[tuple[bool, common.Domain]]) -> list[tuple[bool, common.Domain]]: + res = [] + found = False + for v, d in lst: + if d != common.Domain(): + res.append((v, d)) + found = True + else: + if found: + raise ValueError("Out of bounds.") + return res + + +def _concat_where(m, t, f): + xp = t.__class__.array_ns + if m.domain.ndim != 1: + raise ValueError("Can only concatenate fields with a 1-dimensional mask.") + mask_dim = m.domain.dims[0] + + relative_mask_ranges = _mask_ranges(m.ndarray) + mask_domains = [ + (v, _relative_ranges_to_domain((mr,), m.domain)) for v, mr in relative_mask_ranges + ] + t_domain = t.domain if common.is_field(t) else common.Domain() + f_domain = f.domain if common.is_field(f) else common.Domain() + intersected_domains = [ + (v, embedded_common.intersect_domains(t_domain if v else f_domain, d)) + for v, d in mask_domains + ] + + assert t.domain.dims == f.domain.dims # TODO implement broadcasting + + transformed = [] + for v, d in _trim(intersected_domains): + if v: + if common.is_field(t): + slices = _get_slices_from_domain_slice(t.domain, d) + transformed.append(xp.asarray(t.ndarray[slices])) + else: + transformed.append(t) + else: + if common.is_field(f): + slices = _get_slices_from_domain_slice(f.domain, d) + transformed.append(xp.asarray(f.ndarray[slices])) + else: + transformed.append(f) + mask_index = t_domain.dim_index(mask_dim) + result = xp.concatenate(transformed, axis=mask_index) + result_domain = t_domain.replace( + mask_dim, + ( + mask_dim, + common.UnitRange( + intersected_domains[0][1][mask_index][1].start, + intersected_domains[-1][1][mask_index][1].stop, + ), + ), + ) + return t.__class__.from_array(result, domain=result_domain) + + +NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) + + def _make_reduction( builtin_name: str, array_builtin_name: str, initial_value_op: Callable ) -> Callable[ From a028503976e3ff10adca334d03e260a8c128b0b9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 23:03:13 +0100 Subject: [PATCH 27/50] add tests and very hacked version of broadcasting --- src/gt4py/next/common.py | 4 + src/gt4py/next/embedded/exceptions.py | 8 + src/gt4py/next/embedded/nd_array_field.py | 139 ++++++++++++------ .../embedded_tests/test_nd_array_field.py | 131 ++++++++++++++++- 4 files changed, 238 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f4e35b5533..87aadcd8ea 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -413,6 +413,10 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: # classmethod since TypeGuards requires the guarded obj as separate argument return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @property + def is_empty(self) -> bool: + return any(rng == UnitRange(0, 0) for rng in self.ranges) + @overload def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 393123db36..bddea25712 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -36,3 +36,11 @@ def __init__( self.indices = indices self.index = index self.dim = dim + + +class NonContiguousDomain(gt4py_exceptions.GT4PyError): + msg: str + + def __init__(self, msg: str): + super().__init__(f"Operation would result in a non-contiguous domain: `{msg}`.") + self.msg = msg diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 5f72470a6b..a357d9e5ef 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -26,7 +26,11 @@ from gt4py._core import definitions as core_defs from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common -from gt4py.next.embedded import common as embedded_common, context as embedded_context +from gt4py.next.embedded import ( + common as embedded_common, + context as embedded_context, + exceptions as embedded_exceptions, +) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import embedded as itir_embedded @@ -42,17 +46,19 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] +def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArrayField]: + for f in fields: + if isinstance(f, NdArrayField): + return f.__class__ + raise AssertionError("No 'NdArrayField' found in the arguments.") + + def _make_builtin( builtin_name: str, array_builtin_name: str, reverse=False ) -> Callable[..., NdArrayField]: def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: - first = None - for f in fields: - if isinstance(f, NdArrayField): - first = f - break - assert first is not None - xp = first.__class__.array_ns + cls_ = _get_nd_array_class(*fields) + xp = cls_.array_ns op = getattr(xp, array_builtin_name) domain_intersection = embedded_common.intersect_domains( @@ -80,7 +86,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if reverse: transformed.reverse() new_data = op(*transformed) - return first.__class__.from_array(new_data, domain=domain_intersection) + return cls_.from_array(new_data, domain=domain_intersection) _builtin_op.__name__ = builtin_name return _builtin_op @@ -525,13 +531,14 @@ def _hypercube( def _mask_ranges( mask: core_defs.NDArrayObject, ) -> list[tuple[bool, common.UnitRange]]: + """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" # TODO: does it make sense to upgrade this naive algorithm to numpy? assert mask.ndim == 1 cur = bool(mask[0].item()) ind = 0 res = [] for i in range(1, mask.shape[0]): - if (mask_i := bool(mask[i].item)) != cur: + if (mask_i := bool(mask[i].item())) != cur: res.append((cur, common.UnitRange(ind, i))) cur = mask_i ind = i @@ -539,68 +546,114 @@ def _mask_ranges( return res -def _trim(lst: list[tuple[bool, common.Domain]]) -> list[tuple[bool, common.Domain]]: - res = [] - found = False - for v, d in lst: - if d != common.Domain(): - res.append((v, d)) - found = True - else: - if found: - raise ValueError("Out of bounds.") - return res +def _trim_empty_domains(lst: list[tuple[bool, common.Domain]]) -> list[tuple[bool, common.Domain]]: + """Remove empty domains from beginning and end of the list.""" + if not lst: + return lst + if lst[0][1].is_empty: + return _trim_empty_domains(lst[1:]) + if lst[-1][1].is_empty: + return _trim_empty_domains(lst[:-1]) + return lst + + +def _intersect_domains_orthogonal_to( + dim: common.Dimension, *domains: common.Domain +) -> tuple[common.Domain, ...]: + intersection_orthogonal_to_dim = embedded_common.intersect_domains( + *[d.replace(dim) for d in domains] + ) + return tuple( + common.Domain( + *[(d, r if d == dim else intersection_orthogonal_to_dim[d][1]) for d, r in domain] + ) + for domain in domains + ) -def _concat_where(m, t, f): - xp = t.__class__.array_ns +def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.Field: + cls_ = _get_nd_array_class(m, t, f) + xp = cls_.array_ns if m.domain.ndim != 1: - raise ValueError("Can only concatenate fields with a 1-dimensional mask.") + raise NotImplementedError( + "'concat_where': Can only concatenate fields with a 1-dimensional mask." + ) mask_dim = m.domain.dims[0] + promoted_dims = common.promote_dims(m.domain.dims, t.domain.dims, f.domain.dims) + t_broadcasted = _broadcast(t, promoted_dims) + f_broadcasted = _broadcast(f, promoted_dims) relative_mask_ranges = _mask_ranges(m.ndarray) mask_domains = [ (v, _relative_ranges_to_domain((mr,), m.domain)) for v, mr in relative_mask_ranges ] - t_domain = t.domain if common.is_field(t) else common.Domain() - f_domain = f.domain if common.is_field(f) else common.Domain() + t_domain = t_broadcasted.domain if common.is_field(t) else common.Domain() + f_domain = f_broadcasted.domain if common.is_field(f) else common.Domain() + intersected_domains = [ (v, embedded_common.intersect_domains(t_domain if v else f_domain, d)) for v, d in mask_domains ] - assert t.domain.dims == f.domain.dims # TODO implement broadcasting + intersected_domains = _trim_empty_domains(intersected_domains) + if any(d.is_empty for _, d in intersected_domains): + raise embedded_exceptions.NonContiguousDomain( + f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in intersected_domains]}." + ) + + # TODO now intersect them in the dimensions orthogonal to the mask + + foo = _intersect_domains_orthogonal_to(mask_dim, *[domain for _, domain in intersected_domains]) + + intersected_domains = [(v, f) for (v, _), f in zip(intersected_domains, foo)] transformed = [] - for v, d in _trim(intersected_domains): + for v, d in intersected_domains: if v: if common.is_field(t): - slices = _get_slices_from_domain_slice(t.domain, d) - transformed.append(xp.asarray(t.ndarray[slices])) + slices = _get_slices_from_domain_slice(t_broadcasted.domain, d) + transformed.append( + xp.asarray(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape)) + ) else: transformed.append(t) else: if common.is_field(f): - slices = _get_slices_from_domain_slice(f.domain, d) - transformed.append(xp.asarray(f.ndarray[slices])) + slices = _get_slices_from_domain_slice(f_broadcasted.domain, d) + transformed.append( + xp.asarray(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape)) + ) else: transformed.append(f) mask_index = t_domain.dim_index(mask_dim) - result = xp.concatenate(transformed, axis=mask_index) - result_domain = t_domain.replace( - mask_dim, - ( + assert mask_index is not None # for mypy + result_domain = ( + intersected_domains[0][1].replace( mask_dim, - common.UnitRange( - intersected_domains[0][1][mask_index][1].start, - intersected_domains[-1][1][mask_index][1].stop, + ( + ( + mask_dim, + ( + common.UnitRange( + intersected_domains[0][1][mask_index][1].start, + intersected_domains[-1][1][mask_index][1].stop, + ) + ), + ) ), - ), + ) + if intersected_domains + else common.Domain((mask_dim, common.UnitRange(0, 0))) + ) + result = ( + xp.concatenate(transformed, axis=mask_index) + if transformed + else xp.empty(result_domain.shape) ) - return t.__class__.from_array(result, domain=result_domain) + return cls_.from_array(result, domain=result_domain) -NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) +NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation def _make_reduction( @@ -719,7 +772,7 @@ def __setitem__( common._field.register(jnp.ndarray, JaxArrayField.from_array) -def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: +def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) -> common.Field: domain_slice: list[slice | None] = [] named_ranges = [] for dim in new_dimensions: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 5d7ec3574e..24dfa5e30f 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -14,7 +14,7 @@ import itertools import math import operator -from typing import Callable, Iterable +from typing import Callable, Iterable, Optional import numpy as np import pytest @@ -859,3 +859,132 @@ def test_hypercube(index_array, expected): result = nd_array_field._hypercube(index_array, image_range, np, skip_value) assert result == expected + + +@pytest.mark.parametrize( + "mask_data, true_data, false_data, expected", + [ + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], None), + ([6, 7, 8, 9, 10], None), + ([1, 7, 3, 9, 5], None), + ), + ( + ([True, False, True, False], None), + ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), + ([6, 7, 8, 9], common.UnitRange(1, 5)), + ([3, 6, 5, 8], common.UnitRange(0, 4)), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), + ([6, 7, 8, 9, 10], common.UnitRange(1, 6)), + ([3, 6, 5, 8], common.UnitRange(0, 4)), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), + ([6, 7, 8, 9, 10], common.UnitRange(2, 7)), + None, + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], common.UnitRange(-5, 0)), + ([6, 7, 8, 9, 10], common.UnitRange(5, 10)), + ([], common.UnitRange(0, 0)), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], common.UnitRange(-4, 1)), + ([6, 7, 8, 9, 10], common.UnitRange(5, 10)), + ([5], common.UnitRange(0, 1)), + ), + ], +) +def test_concat_where_1D( + nd_array_implementation, + mask_data: tuple[list[bool], Optional[common.UnitRange]], + true_data: tuple[list[int], Optional[common.UnitRange]], + false_data: tuple[list[int], Optional[common.UnitRange]], + expected: Optional[tuple[list[int], Optional[common.UnitRange]]], +): + mask_lst, mask_range = mask_data + true_lst, true_range = true_data + false_lst, false_range = false_data + + mask_field = _make_field_or_scalar( + mask_lst, + nd_array_implementation=nd_array_implementation, + domain=common.Domain( + (common.Dimension("D"), mask_range or common.unit_range(len(mask_lst))) + ), + dtype=bool, + ) + true_field = _make_field_or_scalar( + true_lst, + nd_array_implementation=nd_array_implementation, + domain=common.Domain( + (common.Dimension("D"), true_range or common.unit_range(len(true_lst))) + ), + dtype=np.int32, + ) + false_field = _make_field_or_scalar( + false_lst, + nd_array_implementation=nd_array_implementation, + domain=common.Domain( + (common.Dimension("D"), false_range or common.unit_range(len(false_lst))) + ), + dtype=np.int32, + ) + + if expected is None: + with pytest.raises(embedded_exceptions.NonContiguousDomain): + nd_array_field._concat_where(mask_field, true_field, false_field) + else: + expected_lst, expected_range = expected + expected_array = np.asarray(expected_lst) + expected_domain = common.Domain( + (common.Dimension("D"), expected_range or common.unit_range(len(expected_lst))) + ) + + result = nd_array_field._concat_where(mask_field, true_field, false_field) + + assert expected_domain == result.domain + np.testing.assert_allclose(result.asnumpy(), expected_array) + + +def test_concat_where_broadcasting(nd_array_implementation): + mask_field = _make_field_or_scalar( + [True, False, True, False, True], + nd_array_implementation=nd_array_implementation, + domain=common.Domain((common.Dimension("D"), common.unit_range(5))), + dtype=bool, + ) + + true_field = _make_field_or_scalar( + [1, 2, 3, 4, 5], + nd_array_implementation=nd_array_implementation, + domain=common.Domain((common.Dimension("D"), common.unit_range(5))), + dtype=np.int32, + ) + false_field = _make_field_or_scalar( + [[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], + nd_array_implementation=nd_array_implementation, + domain=common.Domain( + (common.Dimension("D"), common.unit_range(5)), + (common.Dimension("DExtra"), common.unit_range(2)), + ), + dtype=np.int32, + ) + + expected_array = np.asarray([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]]) + expected_domain = common.Domain( + (common.Dimension("D"), common.unit_range(5)), + (common.Dimension("DExtra"), common.unit_range(2)), + ) + + result = nd_array_field._concat_where(mask_field, true_field, false_field) + + assert expected_domain == result.domain + np.testing.assert_allclose(result.asnumpy(), expected_array) From 05ae764eef9c57d39742e7ce4de667e6f7632b7a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 23 Feb 2024 23:06:10 +0100 Subject: [PATCH 28/50] add TODOs --- .../unit_tests/embedded_tests/test_nd_array_field.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 24dfa5e30f..d1df04d8f5 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -988,3 +988,9 @@ def test_concat_where_broadcasting(nd_array_implementation): assert expected_domain == result.domain np.testing.assert_allclose(result.asnumpy(), expected_array) + + +# TODO test +# - where one of the field doesn't have the mask dimension +# - (where the mask dimension is not the first one) +# - a scalar is involved From 19a176f335cb80841eb0d93317abfea8dd629461 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 24 Feb 2024 15:42:11 +0100 Subject: [PATCH 29/50] cleanup --- src/gt4py/next/embedded/nd_array_field.py | 54 ++-- .../embedded_tests/test_nd_array_field.py | 267 ++++++++---------- 2 files changed, 148 insertions(+), 173 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a357d9e5ef..c42caf3a99 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -571,6 +571,19 @@ def _intersect_domains_orthogonal_to( ) +def _to_field( + value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField] +) -> common.Field: + # TODO(havogt): this function is only to workaround broadcasting of scalars, once we have a ConstantField, we can broadcast to that directly + return ( + value + if common.is_field(value) + else nd_array_field_type.from_array( + nd_array_field_type.array_ns.asarray(value), domain=common.Domain() + ) + ) + + def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.Field: cls_ = _get_nd_array_class(m, t, f) xp = cls_.array_ns @@ -579,19 +592,22 @@ def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.F "'concat_where': Can only concatenate fields with a 1-dimensional mask." ) mask_dim = m.domain.dims[0] - promoted_dims = common.promote_dims(m.domain.dims, t.domain.dims, f.domain.dims) - t_broadcasted = _broadcast(t, promoted_dims) - f_broadcasted = _broadcast(f, promoted_dims) + + promoted_dims = common.promote_dims(*[f.domain.dims for f in [m, t, f] if common.is_field(f)]) + t_broadcasted = _broadcast(_to_field(t, cls_), promoted_dims) + f_broadcasted = _broadcast(_to_field(f, cls_), promoted_dims) relative_mask_ranges = _mask_ranges(m.ndarray) mask_domains = [ (v, _relative_ranges_to_domain((mr,), m.domain)) for v, mr in relative_mask_ranges ] - t_domain = t_broadcasted.domain if common.is_field(t) else common.Domain() - f_domain = f_broadcasted.domain if common.is_field(f) else common.Domain() - intersected_domains = [ - (v, embedded_common.intersect_domains(t_domain if v else f_domain, d)) + ( + v, + embedded_common.intersect_domains( + t_broadcasted.domain if v else f_broadcasted.domain, d + ), + ) for v, d in mask_domains ] @@ -601,31 +617,19 @@ def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.F f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in intersected_domains]}." ) - # TODO now intersect them in the dimensions orthogonal to the mask - + # TODO cleanup foo = _intersect_domains_orthogonal_to(mask_dim, *[domain for _, domain in intersected_domains]) - intersected_domains = [(v, f) for (v, _), f in zip(intersected_domains, foo)] transformed = [] for v, d in intersected_domains: if v: - if common.is_field(t): - slices = _get_slices_from_domain_slice(t_broadcasted.domain, d) - transformed.append( - xp.asarray(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape)) - ) - else: - transformed.append(t) + slices = _get_slices_from_domain_slice(t_broadcasted.domain, d) + transformed.append(xp.asarray(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape))) else: - if common.is_field(f): - slices = _get_slices_from_domain_slice(f_broadcasted.domain, d) - transformed.append( - xp.asarray(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape)) - ) - else: - transformed.append(f) - mask_index = t_domain.dim_index(mask_dim) + slices = _get_slices_from_domain_slice(f_broadcasted.domain, d) + transformed.append(xp.asarray(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape))) + mask_index = t_broadcasted.domain.dim_index(mask_dim) assert mask_index is not None # for mypy result_domain = ( intersected_domains[0][1].replace( diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index d1df04d8f5..f739105d84 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -29,9 +29,9 @@ from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data -IDim = Dimension("IDim") -JDim = Dimension("JDim") -KDim = Dimension("KDim") +D0 = Dimension("D0") +D1 = Dimension("D1") +D2 = Dimension("D2") @pytest.fixture(params=nd_array_field._nd_array_implementations) @@ -70,6 +70,13 @@ def unary_logical_op(request): yield request.param +def _make_default_domain(shape: tuple[int, ...]) -> Domain: + return common.Domain( + dims=tuple(Dimension(f"D{i}") for i in range(len(shape))), + ranges=tuple(UnitRange(0, s) for s in shape), + ) + + def _make_field_or_scalar( lst: Iterable | core_defs.Scalar, nd_array_implementation, *, domain=None, dtype=None ): @@ -80,9 +87,7 @@ def _make_field_or_scalar( return dtype(lst) buffer = nd_array_implementation.asarray(lst, dtype=dtype) if domain is None: - domain = tuple( - (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) - ) + domain = _make_default_domain(buffer.shape) return common._field( buffer, domain=domain, @@ -139,16 +144,14 @@ def test_where_builtin_different_domain(nd_array_implementation): true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) - cond_field = common._field( - nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) - ) + cond_field = common._field(nd_array_implementation.asarray(cond), domain=common.domain({D1: 2})) true_field = common._field( nd_array_implementation.asarray(true_), - domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), + domain=common.domain({D0: common.UnitRange(0, 2), D1: common.UnitRange(-1, 2)}), ) false_field = common._field( nd_array_implementation.asarray(false_), - domain=common.domain({JDim: common.UnitRange(-1, 3)}), + domain=common.domain({D1: common.UnitRange(-1, 3)}), ) expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) @@ -245,8 +248,8 @@ def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): @pytest.mark.parametrize( "dims,expected_indices", [ - ((IDim,), (slice(5, 10), None)), - ((JDim,), (None, slice(5, 10))), + ((D0,), (slice(5, 10), None)), + ((D1,), (None, slice(5, 10))), ], ) def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): @@ -254,7 +257,7 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) arr2 = np.ones((5, 5)) - arr2_domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 10), UnitRange(5, 10))) + arr2_domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 10), UnitRange(5, 10))) field1 = common._field(arr1, domain=arr1_domain) field2 = common._field(arr2, domain=arr2_domain) @@ -370,39 +373,39 @@ def test_cartesian_remap_implementation(): [ ( ( - (IDim,), + (D0,), common._field( - np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), ) ), ( ( - (IDim, JDim), + (D0, D1), common._field( - np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), + Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange.infinite())), ) ), ( ( - (IDim, JDim), + (D0, D1), common._field( - np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), + Domain(dims=(D0, D1), ranges=(UnitRange.infinite(), UnitRange(0, 10))), ) ), ( ( - (IDim, JDim, KDim), + (D0, D1, D2), common._field( - np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) ), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), ) @@ -417,13 +420,13 @@ def test_field_broadcast(new_dims, field, expected_domain): @pytest.mark.parametrize( "domain_slice", [ - ((IDim, UnitRange(0, 10)),), - common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), + ((D0, UnitRange(0, 10)),), + common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), ], ) def test_get_slices_with_named_indices_3d_to_1d(domain_slice): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) slices = _get_slices_from_domain_slice(field_domain, domain_slice) assert slices == (slice(0, 10, None), slice(None), slice(None)) @@ -431,18 +434,18 @@ def test_get_slices_with_named_indices_3d_to_1d(domain_slice): def test_get_slices_with_named_index(): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - named_index = ((IDim, UnitRange(0, 10)), (JDim, 2), (KDim, 3)) + named_index = ((D0, UnitRange(0, 10)), (D1, 2), (D2, 3)) slices = _get_slices_from_domain_slice(field_domain, named_index) assert slices == (slice(0, 10, None), 2, 3) def test_get_slices_invalid_type(): field_domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) + dims=(D0, D1, D2), ranges=(UnitRange(0, 10), UnitRange(0, 10), UnitRange(0, 10)) ) - new_domain = ((IDim, "1"),) + new_domain = ((D0, "1"),) with pytest.raises(ValueError): _get_slices_from_domain_slice(field_domain, new_domain) @@ -452,39 +455,39 @@ def test_get_slices_invalid_type(): [ ( ( - (IDim, UnitRange(7, 9)), - (JDim, UnitRange(8, 10)), + (D0, UnitRange(7, 9)), + (D1, UnitRange(8, 10)), ), - (IDim, JDim, KDim), + (D0, D1, D2), (2, 2, 15), ), ( ( - (IDim, UnitRange(7, 9)), - (KDim, UnitRange(12, 20)), + (D0, UnitRange(7, 9)), + (D2, UnitRange(12, 20)), ), - (IDim, JDim, KDim), + (D0, D1, D2), (2, 10, 8), ), - (common.Domain(dims=(IDim,), ranges=(UnitRange(7, 9),)), (IDim, JDim, KDim), (2, 10, 15)), - (((IDim, 8),), (JDim, KDim), (10, 15)), - (((JDim, 9),), (IDim, KDim), (5, 15)), - (((KDim, 11),), (IDim, JDim), (5, 10)), + (common.Domain(dims=(D0,), ranges=(UnitRange(7, 9),)), (D0, D1, D2), (2, 10, 15)), + (((D0, 8),), (D1, D2), (10, 15)), + (((D1, 9),), (D0, D2), (5, 15)), + (((D2, 11),), (D0, D1), (5, 10)), ( ( - (IDim, 8), - (JDim, UnitRange(8, 10)), + (D0, 8), + (D1, UnitRange(8, 10)), ), - (JDim, KDim), + (D1, D2), (2, 15), ), - ((IDim, 5), (JDim, KDim), (10, 15)), - ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), + ((D0, 5), (D1, D2), (10, 15)), + ((D0, UnitRange(5, 7)), (D0, D1, D2), (2, 10, 15)), ], ) def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] @@ -495,10 +498,10 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): def test_absolute_indexing_value_return(): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(10, 20), UnitRange(5, 15))) field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 12), (JDim, 6)) + named_index = ((D0, 12), (D1, 6)) assert common.is_field(field) value = field[named_index] @@ -512,28 +515,28 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 4))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ((slice(None, 5),), (5, 10), Domain((D0, UnitRange(5, 10)), (D1, UnitRange(2, 12)))), ( (Ellipsis, 1), (10,), - Domain((IDim, UnitRange(5, 15))), + Domain((D0, UnitRange(5, 15))), ), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), + Domain((D0, UnitRange(7, 8)), (D1, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim, UnitRange(6, 7))), + Domain((D0, UnitRange(6, 7))), ), ], ) def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.ones((10, 10)), domain=domain) indexed_field = field[index] @@ -545,17 +548,17 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(D1,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -563,7 +566,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): (slice(None), slice(None), slice(None)), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -571,16 +574,16 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): (slice(None)), (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), - ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(D1,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), Domain( - dims=(IDim, JDim, KDim), + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), ), ), @@ -588,7 +591,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): ) def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ) field = common._field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] @@ -606,7 +609,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): ], ) def test_relative_indexing_value_return(index, expected_value): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) indexed_field = field[index] @@ -615,16 +618,16 @@ def test_relative_indexing_value_return(index, expected_value): @pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]]) def test_relative_indexing_out_of_bounds(lazy_slice): - domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) + domain = common.Domain(dims=(D0, D1), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common._field(np.ones((10, 10)), domain=domain) with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) -@pytest.mark.parametrize("index", [IDim, "1", (IDim, JDim)]) +@pytest.mark.parametrize("index", [D0, "1", (D0, D1)]) def test_field_unsupported_index(index): - domain = common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) + domain = common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) field = common._field(np.ones((10,)), domain=domain) with pytest.raises(IndexError, match="Unsupported index type"): field[index] @@ -637,14 +640,14 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common._field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + common._field(np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(0, 10)))), ), ], ) def test_setitem(index, value): field = common._field( np.arange(100).reshape(10, 10), - domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + domain=common.Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) expected = np.copy(field.asnumpy()) @@ -658,11 +661,11 @@ def test_setitem(index, value): def test_setitem_wrong_domain(): field = common._field( np.arange(100).reshape(10, 10), - domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + domain=common.Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) value_incompatible = common._field( - np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) + np.ones((10,)) * 42.0, domain=common.Domain((D1, UnitRange(-5, 5))) ) with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): @@ -872,69 +875,77 @@ def test_hypercube(index_array, expected): ), ( ([True, False, True, False], None), - ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), - ([6, 7, 8, 9], common.UnitRange(1, 5)), - ([3, 6, 5, 8], common.UnitRange(0, 4)), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9], {D0: (1, 5)}), + ([3, 6, 5, 8], {D0: (0, 4)}), ), ( ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), - ([6, 7, 8, 9, 10], common.UnitRange(1, 6)), - ([3, 6, 5, 8], common.UnitRange(0, 4)), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9, 10], {D0: (1, 6)}), + ([3, 6, 5, 8], {D0: (0, 4)}), ), ( ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], common.UnitRange(-2, 3)), - ([6, 7, 8, 9, 10], common.UnitRange(2, 7)), + ([1, 2, 3, 4, 5], {D0: (-2, 3)}), + ([6, 7, 8, 9, 10], {D0: (2, 7)}), None, ), ( + # empty result domain ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], common.UnitRange(-5, 0)), - ([6, 7, 8, 9, 10], common.UnitRange(5, 10)), - ([], common.UnitRange(0, 0)), + ([1, 2, 3, 4, 5], {D0: (-5, 0)}), + ([6, 7, 8, 9, 10], {D0: (5, 10)}), + ([], {D0: (0, 0)}), + ), + ( + ([True, False, True, False, True], None), + ([1, 2, 3, 4, 5], {D0: (-4, 1)}), + ([6, 7, 8, 9, 10], {D0: (5, 10)}), + ([5], {D0: (0, 1)}), + ), + ( + # broadcasting true_field + ([True, False, True, False, True], {D0: 5}), + ([1, 2, 3, 4, 5], {D0: 5}), + ([[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], {D0: 5, D1: 2}), + ([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]], {D0: 5, D1: 2}), ), ( ([True, False, True, False, True], None), - ([1, 2, 3, 4, 5], common.UnitRange(-4, 1)), - ([6, 7, 8, 9, 10], common.UnitRange(5, 10)), - ([5], common.UnitRange(0, 1)), + (42, None), + ([6, 7, 8, 9, 10], None), + ([42, 7, 42, 9, 42], None), ), ], ) def test_concat_where_1D( nd_array_implementation, - mask_data: tuple[list[bool], Optional[common.UnitRange]], - true_data: tuple[list[int], Optional[common.UnitRange]], - false_data: tuple[list[int], Optional[common.UnitRange]], - expected: Optional[tuple[list[int], Optional[common.UnitRange]]], + mask_data: tuple[list[bool], Optional[common.DomainLike]], + true_data: tuple[list[int], Optional[common.DomainLike]], + false_data: tuple[list[int], Optional[common.DomainLike]], + expected: Optional[tuple[list[int], Optional[common.DomainLike]]], ): - mask_lst, mask_range = mask_data - true_lst, true_range = true_data - false_lst, false_range = false_data + mask_lst, mask_domain = mask_data + true_lst, true_domain = true_data + false_lst, false_domain = false_data mask_field = _make_field_or_scalar( mask_lst, nd_array_implementation=nd_array_implementation, - domain=common.Domain( - (common.Dimension("D"), mask_range or common.unit_range(len(mask_lst))) - ), + domain=common.domain(mask_domain) if mask_domain is not None else None, dtype=bool, ) true_field = _make_field_or_scalar( true_lst, nd_array_implementation=nd_array_implementation, - domain=common.Domain( - (common.Dimension("D"), true_range or common.unit_range(len(true_lst))) - ), + domain=common.domain(true_domain) if true_domain is not None else None, dtype=np.int32, ) false_field = _make_field_or_scalar( false_lst, nd_array_implementation=nd_array_implementation, - domain=common.Domain( - (common.Dimension("D"), false_range or common.unit_range(len(false_lst))) - ), + domain=common.domain(false_domain) if false_domain is not None else None, dtype=np.int32, ) @@ -942,55 +953,15 @@ def test_concat_where_1D( with pytest.raises(embedded_exceptions.NonContiguousDomain): nd_array_field._concat_where(mask_field, true_field, false_field) else: - expected_lst, expected_range = expected + expected_lst, expected_domain_like = expected expected_array = np.asarray(expected_lst) - expected_domain = common.Domain( - (common.Dimension("D"), expected_range or common.unit_range(len(expected_lst))) + expected_domain = ( + common.domain(expected_domain_like) + if expected_domain_like is not None + else _make_default_domain(expected_array.shape) ) result = nd_array_field._concat_where(mask_field, true_field, false_field) assert expected_domain == result.domain np.testing.assert_allclose(result.asnumpy(), expected_array) - - -def test_concat_where_broadcasting(nd_array_implementation): - mask_field = _make_field_or_scalar( - [True, False, True, False, True], - nd_array_implementation=nd_array_implementation, - domain=common.Domain((common.Dimension("D"), common.unit_range(5))), - dtype=bool, - ) - - true_field = _make_field_or_scalar( - [1, 2, 3, 4, 5], - nd_array_implementation=nd_array_implementation, - domain=common.Domain((common.Dimension("D"), common.unit_range(5))), - dtype=np.int32, - ) - false_field = _make_field_or_scalar( - [[6, 11], [7, 12], [8, 13], [9, 14], [10, 15]], - nd_array_implementation=nd_array_implementation, - domain=common.Domain( - (common.Dimension("D"), common.unit_range(5)), - (common.Dimension("DExtra"), common.unit_range(2)), - ), - dtype=np.int32, - ) - - expected_array = np.asarray([[1, 1], [7, 12], [3, 3], [9, 14], [5, 5]]) - expected_domain = common.Domain( - (common.Dimension("D"), common.unit_range(5)), - (common.Dimension("DExtra"), common.unit_range(2)), - ) - - result = nd_array_field._concat_where(mask_field, true_field, false_field) - - assert expected_domain == result.domain - np.testing.assert_allclose(result.asnumpy(), expected_array) - - -# TODO test -# - where one of the field doesn't have the mask dimension -# - (where the mask dimension is not the first one) -# - a scalar is involved From 6dd79d8458e9536ce98ea8f70776a516b8f6a89d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 24 Feb 2024 17:25:10 +0100 Subject: [PATCH 30/50] refactoring --- src/gt4py/next/embedded/nd_array_field.py | 75 ++++++++----------- .../embedded_tests/test_nd_array_field.py | 2 +- 2 files changed, 33 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index c42caf3a99..906f810b5a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -528,7 +528,7 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _mask_ranges( +def _compute_mask_ranges( mask: core_defs.NDArrayObject, ) -> list[tuple[bool, common.UnitRange]]: """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" @@ -597,64 +597,53 @@ def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.F t_broadcasted = _broadcast(_to_field(t, cls_), promoted_dims) f_broadcasted = _broadcast(_to_field(f, cls_), promoted_dims) - relative_mask_ranges = _mask_ranges(m.ndarray) - mask_domains = [ - (v, _relative_ranges_to_domain((mr,), m.domain)) for v, mr in relative_mask_ranges - ] - intersected_domains = [ - ( - v, - embedded_common.intersect_domains( - t_broadcasted.domain if v else f_broadcasted.domain, d - ), + mask_values, mask_relative_ranges = zip(*_compute_mask_ranges(m.ndarray)) + mask_domains = ( + _relative_ranges_to_domain((relative_range,), m.domain) + for relative_range in mask_relative_ranges + ) + intersected_domains = ( + embedded_common.intersect_domains( + t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain ) - for v, d in mask_domains - ] + for mask_value, mask_domain in zip(mask_values, mask_domains) + ) - intersected_domains = _trim_empty_domains(intersected_domains) - if any(d.is_empty for _, d in intersected_domains): + mask_values, intersected_domains = tuple( + zip(*_trim_empty_domains(list(zip(mask_values, intersected_domains)))) + ) or ([], []) + if any(d.is_empty for d in intersected_domains): raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in intersected_domains]}." + f"In 'concat_where', cannot concatenate the following 'Domain's: {list(intersected_domains)}." ) - # TODO cleanup - foo = _intersect_domains_orthogonal_to(mask_dim, *[domain for _, domain in intersected_domains]) - intersected_domains = [(v, f) for (v, _), f in zip(intersected_domains, foo)] + intersected_domains = _intersect_domains_orthogonal_to(mask_dim, *intersected_domains) transformed = [] - for v, d in intersected_domains: + for v, d in zip(mask_values, intersected_domains): if v: slices = _get_slices_from_domain_slice(t_broadcasted.domain, d) - transformed.append(xp.asarray(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape))) + transformed.append(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape)) else: slices = _get_slices_from_domain_slice(f_broadcasted.domain, d) - transformed.append(xp.asarray(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape))) + transformed.append(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape)) mask_index = t_broadcasted.domain.dim_index(mask_dim) assert mask_index is not None # for mypy - result_domain = ( - intersected_domains[0][1].replace( + + if intersected_domains: + new_masked_dim_named_range = ( mask_dim, - ( - ( - mask_dim, - ( - common.UnitRange( - intersected_domains[0][1][mask_index][1].start, - intersected_domains[-1][1][mask_index][1].stop, - ) - ), - ) + common.UnitRange( + intersected_domains[0][mask_index][1].start, + intersected_domains[-1][mask_index][1].stop, ), ) - if intersected_domains - else common.Domain((mask_dim, common.UnitRange(0, 0))) - ) - result = ( - xp.concatenate(transformed, axis=mask_index) - if transformed - else xp.empty(result_domain.shape) - ) - return cls_.from_array(result, domain=result_domain) + result_domain = intersected_domains[0].replace(mask_dim, new_masked_dim_named_range) + result_array = xp.concatenate(transformed, axis=mask_index) + else: + result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) + result_array = xp.empty(result_domain.shape) + return cls_.from_array(result_array, domain=result_domain) NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index f739105d84..2eaf9be6bf 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -919,7 +919,7 @@ def test_hypercube(index_array, expected): ), ], ) -def test_concat_where_1D( +def test_concat_where( nd_array_implementation, mask_data: tuple[list[bool], Optional[common.DomainLike]], true_data: tuple[list[int], Optional[common.DomainLike]], From ab78dc488781ca2aee43eb758878730112effb3e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Sat, 24 Feb 2024 21:25:32 +0100 Subject: [PATCH 31/50] more cleanups --- src/gt4py/next/embedded/common.py | 45 +++++++- src/gt4py/next/embedded/nd_array_field.py | 121 +++++++++++++--------- src/gt4py/next/embedded/operators.py | 4 +- 3 files changed, 120 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 94efe4d61d..d517c0469b 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -97,7 +97,18 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) -def intersect_domains(*domains: common.Domain) -> common.Domain: +# TODO tests +def domain_intersection( + *domains: common.Domain, +) -> common.Domain: + """ + Return the intersection of the given domains. + + Example: + >>> I = common.Dimension("I") + >>> domain_intersection(common.domain({I:(0,5)}), common.domain({I:(1,3)})) # doctest: +ELLIPSIS + Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) + """ return functools.reduce( operator.and_, domains, @@ -105,6 +116,38 @@ def intersect_domains(*domains: common.Domain) -> common.Domain: ) +# TODO tests +def intersect_domains( + *domains: common.Domain, + ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, +) -> tuple[common.Domain, ...]: + """ + Return the with each other intersected domains, ignoring 'ignore_dims' dimensions for the intersection. + + Example: + >>> I = common.Dimension("I") + >>> J = common.Dimension("J") + >>> res = intersect_domains(common.domain({I:(0,5), J:(1,2)}), common.domain({I:(1,3), J:(0,3)}), ignore_dims=J) + >>> assert res == (common.domain({I:(1,3), J:(1,2)}), common.domain({I:(1,3), J:(0,3)})) + """ + ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) + intersection_without_ignore_dims = domain_intersection( + *[ + common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) + for domain in domains + ] + ) + return tuple( + common.Domain( + *[ + (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) + for d, r in domain + ] + ) + for domain in domains + ) + + def iterate_domain(domain: common.Domain): for i in itertools.product(*[list(r) for r in domain.ranges]): yield tuple(zip(domain.dims, i)) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 906f810b5a..ee0f9105d8 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -61,7 +61,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = cls_.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = embedded_common.intersect_domains( + domain_intersection = embedded_common.domain_intersection( *[f.domain for f in fields if common.is_field(f)] ) @@ -557,20 +557,6 @@ def _trim_empty_domains(lst: list[tuple[bool, common.Domain]]) -> list[tuple[boo return lst -def _intersect_domains_orthogonal_to( - dim: common.Dimension, *domains: common.Domain -) -> tuple[common.Domain, ...]: - intersection_orthogonal_to_dim = embedded_common.intersect_domains( - *[d.replace(dim) for d in domains] - ) - return tuple( - common.Domain( - *[(d, r if d == dim else intersection_orthogonal_to_dim[d][1]) for d, r in domain] - ) - for domain in domains - ) - - def _to_field( value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField] ) -> common.Field: @@ -584,26 +570,83 @@ def _to_field( ) -def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.Field: - cls_ = _get_nd_array_class(m, t, f) +# TODO move to common and test +def _intersect_fields( + *fields: common.Field | core_defs.Scalar, + ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, +) -> tuple[common.Field, ...]: + nd_array_class = _get_nd_array_class(*fields) + promoted_dims = common.promote_dims(*[f.domain.dims for f in fields if common.is_field(f)]) + broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields] + + intersected_domains = embedded_common.intersect_domains( + *[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims + ) + + return tuple( + nd_array_class.from_array( + f.ndarray[_get_slices_from_domain_slice(f.domain, intersected_domain)], + domain=intersected_domain, + ) + for f, intersected_domain in zip(broadcasted_fields, intersected_domains, strict=True) + ) + + +def _concat_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: + if not domains: + return common.Domain() + dim_start = domains[0][dim][1].start + dim_stop = dim_start + for domain in domains: + if not domain[dim][1].start == dim_stop: + return None + else: + dim_stop = domain[dim][1].stop + return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop))) + + +def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: + # TODO(havogt): this function could be extended to a general concat + # currently only concatenate along the given dimension and requires the fields to be ordered + + if ( + len(fields) > 1 + and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty + ): + raise ValueError("Fields to concatenate must not overlap.") + new_domain = _concat_domains(*[f.domain for f in fields], dim=dim) + if new_domain is None: + raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") + nd_array_class = _get_nd_array_class(*fields) + return nd_array_class.from_array( + nd_array_class.array_ns.concatenate( + [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], + axis=new_domain.dim_index(dim), + ), + domain=new_domain, + ) + + +def _concat_where( + mask_field: common.Field, true_field: common.Field, false_field: common.Field +) -> common.Field: + cls_ = _get_nd_array_class(mask_field, true_field, false_field) xp = cls_.array_ns - if m.domain.ndim != 1: + if mask_field.domain.ndim != 1: raise NotImplementedError( "'concat_where': Can only concatenate fields with a 1-dimensional mask." ) - mask_dim = m.domain.dims[0] + mask_dim = mask_field.domain.dims[0] - promoted_dims = common.promote_dims(*[f.domain.dims for f in [m, t, f] if common.is_field(f)]) - t_broadcasted = _broadcast(_to_field(t, cls_), promoted_dims) - f_broadcasted = _broadcast(_to_field(f, cls_), promoted_dims) + t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) - mask_values, mask_relative_ranges = zip(*_compute_mask_ranges(m.ndarray)) + mask_values, mask_relative_ranges = zip(*_compute_mask_ranges(mask_field.ndarray)) mask_domains = ( - _relative_ranges_to_domain((relative_range,), m.domain) + _relative_ranges_to_domain((relative_range,), mask_field.domain) for relative_range in mask_relative_ranges ) intersected_domains = ( - embedded_common.intersect_domains( + embedded_common.domain_intersection( t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain ) for mask_value, mask_domain in zip(mask_values, mask_domains) @@ -617,29 +660,13 @@ def _concat_where(m: common.Field, t: common.Field, f: common.Field) -> common.F f"In 'concat_where', cannot concatenate the following 'Domain's: {list(intersected_domains)}." ) - intersected_domains = _intersect_domains_orthogonal_to(mask_dim, *intersected_domains) + transformed = [ + t_broadcasted[d] if v else f_broadcasted[d] + for v, d in zip(mask_values, intersected_domains) + ] - transformed = [] - for v, d in zip(mask_values, intersected_domains): - if v: - slices = _get_slices_from_domain_slice(t_broadcasted.domain, d) - transformed.append(xp.broadcast_to(t_broadcasted.ndarray[slices], d.shape)) - else: - slices = _get_slices_from_domain_slice(f_broadcasted.domain, d) - transformed.append(xp.broadcast_to(f_broadcasted.ndarray[slices], d.shape)) - mask_index = t_broadcasted.domain.dim_index(mask_dim) - assert mask_index is not None # for mypy - - if intersected_domains: - new_masked_dim_named_range = ( - mask_dim, - common.UnitRange( - intersected_domains[0][mask_index][1].start, - intersected_domains[-1][mask_index][1].stop, - ), - ) - result_domain = intersected_domains[0].replace(mask_dim, new_masked_dim_named_range) - result_array = xp.concatenate(transformed, axis=mask_index) + if transformed: + return _concat(*transformed, dim=mask_dim) else: result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) result_array = xp.empty(result_domain.shape) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 62b683aac6..47695d7c4b 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -82,7 +82,7 @@ def scan_loop(hpos): def _get_out_domain( out: common.MutableField | tuple[common.MutableField | tuple, ...] ) -> common.Domain: - return embedded_common.intersect_domains( + return embedded_common.domain_intersection( *[f.domain for f in utils.flatten_nested_tuple((out,))] ) @@ -150,7 +150,7 @@ def impl(target: common.MutableField, source: common.Field): def _intersect_scan_args( *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] ) -> common.Domain: - return embedded_common.intersect_domains( + return embedded_common.domain_intersection( *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] ) From f891e371b4b7d71afbbc6b82166fc2ef4c83f139 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 26 Feb 2024 15:40:30 +0000 Subject: [PATCH 32/50] address review comments --- src/gt4py/next/embedded/nd_array_field.py | 8 +++----- .../unit_tests/embedded_tests/test_nd_array_field.py | 8 +++++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index b00f3f25d2..e44648d27c 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -65,11 +65,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: if f.domain == domain_intersection: transformed.append(xp.asarray(f.ndarray)) else: - f_broadcasted = ( - _broadcast(f, domain_intersection.dims) - if f.domain.dims != domain_intersection.dims - else f - ) + f_broadcasted = _broadcast(f, domain_intersection.dims) f_slices = _get_slices_from_domain_slice( f_broadcasted.domain, domain_intersection ) @@ -634,6 +630,8 @@ def __setitem__( def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: + if field.domain.dims == new_dimensions: + return field domain_slice: list[slice | None] = [] named_ranges = [] for dim in new_dimensions: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 5d7ec3574e..adbe80fc1b 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -180,7 +180,13 @@ def test_where_builtin_with_tuple(nd_array_implementation): assert np.allclose(result[1].ndarray, expected1) -@pytest.mark.parametrize("lhs, rhs", [([-1.0, 4.2, 42], [2.0, 3.0, -3.0]), (1.0, [2.0, 3.0, -3.0])]) +@pytest.mark.parametrize( + "lhs, rhs", + [ + ([-1.0, 4.2, 42], [2.0, 3.0, -3.0]), + (1.0, [2.0, 3.0, -3.0]), # scalar with field, tests reverse operators + ], +) def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation, lhs, rhs): inputs = [lhs, rhs] From ddcc2729ac3f4eb8edddddd3ed51170196c5427e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 26 Feb 2024 15:42:15 +0000 Subject: [PATCH 33/50] change scalar value --- .../next_tests/unit_tests/embedded_tests/test_nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index adbe80fc1b..8e4af6d0b6 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -184,7 +184,7 @@ def test_where_builtin_with_tuple(nd_array_implementation): "lhs, rhs", [ ([-1.0, 4.2, 42], [2.0, 3.0, -3.0]), - (1.0, [2.0, 3.0, -3.0]), # scalar with field, tests reverse operators + (2.0, [2.0, 3.0, -3.0]), # scalar with field, tests reverse operators ], ) def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation, lhs, rhs): From 05fd105c4e724f0416243c39627c469188c6dd25 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Mar 2024 14:30:29 +0000 Subject: [PATCH 34/50] default format --- src/gt4py/__init__.py | 2 +- src/gt4py/_core/definitions.py | 12 +- src/gt4py/cartesian/backend/base.py | 7 +- src/gt4py/cartesian/backend/dace_backend.py | 78 ++++----- src/gt4py/cartesian/backend/numpy_backend.py | 14 +- src/gt4py/cartesian/caching.py | 6 +- src/gt4py/cartesian/cli.py | 5 +- .../cartesian/frontend/gtscript_frontend.py | 34 ++-- src/gt4py/cartesian/frontend/node_util.py | 3 - src/gt4py/cartesian/frontend/nodes.py | 8 +- src/gt4py/cartesian/gtc/common.py | 4 +- .../gtc/dace/expansion/daceir_builder.py | 6 +- .../gtc/dace/expansion/tasklet_codegen.py | 4 +- src/gt4py/cartesian/gtc/dace/nodes.py | 10 +- src/gt4py/cartesian/gtc/dace/utils.py | 22 +-- src/gt4py/cartesian/gtc/definitions.py | 6 +- src/gt4py/cartesian/gtc/gtcpp/gtcpp.py | 2 +- .../cartesian/gtc/gtcpp/gtcpp_codegen.py | 10 +- .../cartesian/gtc/passes/gtir_k_boundary.py | 2 +- .../passes/oir_optimizations/temporaries.py | 6 +- .../gtc/passes/oir_optimizations/utils.py | 22 ++- src/gt4py/cartesian/gtscript.py | 4 +- src/gt4py/cartesian/gtscript_imports.py | 1 + src/gt4py/cartesian/lazy_stencil.py | 1 + src/gt4py/cartesian/stencil_object.py | 6 +- .../cartesian/testing/input_strategies.py | 6 +- src/gt4py/cartesian/testing/suites.py | 13 +- src/gt4py/cartesian/utils/attrib.py | 2 +- src/gt4py/cartesian/utils/base.py | 10 +- src/gt4py/cartesian/utils/meta.py | 14 +- src/gt4py/eve/codegen.py | 33 ++-- src/gt4py/eve/concepts.py | 1 - src/gt4py/eve/datamodels/__init__.py | 2 +- src/gt4py/eve/datamodels/core.py | 28 +-- src/gt4py/eve/exceptions.py | 1 - src/gt4py/eve/extended_typing.py | 16 +- src/gt4py/eve/pattern_matching.py | 15 +- src/gt4py/eve/traits.py | 1 - src/gt4py/eve/trees.py | 1 - src/gt4py/eve/type_definitions.py | 12 +- src/gt4py/eve/type_validation.py | 9 +- src/gt4py/eve/utils.py | 160 +++++++++--------- src/gt4py/eve/visitors.py | 1 - src/gt4py/next/common.py | 20 +-- src/gt4py/next/constructors.py | 10 +- src/gt4py/next/embedded/common.py | 39 +++-- src/gt4py/next/embedded/nd_array_field.py | 21 ++- src/gt4py/next/embedded/operators.py | 28 +-- .../ffront/ast_passes/remove_docstrings.py | 14 +- .../next/ffront/ast_passes/simple_assign.py | 6 +- .../ffront/ast_passes/single_static_assign.py | 6 +- src/gt4py/next/ffront/decorator.py | 25 +-- src/gt4py/next/ffront/fbuiltins.py | 10 +- src/gt4py/next/ffront/field_operator_ast.py | 24 ++- src/gt4py/next/ffront/foast_introspection.py | 12 +- .../ffront/foast_passes/type_deduction.py | 46 +++-- src/gt4py/next/ffront/foast_to_itir.py | 4 +- src/gt4py/next/ffront/func_to_foast.py | 8 +- src/gt4py/next/ffront/lowering_utils.py | 43 +++-- src/gt4py/next/ffront/past_to_itir.py | 13 +- src/gt4py/next/ffront/program_ast.py | 12 +- src/gt4py/next/ffront/source_utils.py | 8 +- src/gt4py/next/ffront/type_info.py | 6 +- src/gt4py/next/ffront/type_translation.py | 2 +- src/gt4py/next/iterator/atlas_utils.py | 2 +- src/gt4py/next/iterator/embedded.py | 2 +- src/gt4py/next/iterator/ir.py | 10 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +- src/gt4py/next/iterator/pretty_printer.py | 1 + .../iterator/transforms/collapse_tuple.py | 2 +- src/gt4py/next/iterator/transforms/cse.py | 33 ++-- .../next/iterator/transforms/global_tmps.py | 16 +- .../iterator/transforms/inline_lambdas.py | 6 +- src/gt4py/next/iterator/type_inference.py | 151 ++++++++--------- src/gt4py/next/otf/binding/nanobind.py | 1 - .../compilation/build_systems/cmake_lists.py | 12 +- src/gt4py/next/otf/compilation/cache.py | 1 - src/gt4py/next/otf/compilation/common.py | 1 - src/gt4py/next/otf/stages.py | 10 +- src/gt4py/next/otf/workflow.py | 38 ++--- .../codegens/gtfn/gtfn_im_ir.py | 2 +- .../codegens/gtfn/gtfn_ir.py | 10 +- .../codegens/gtfn/gtfn_ir_common.py | 4 +- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 6 +- .../codegens/gtfn/gtfn_module.py | 4 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../program_processors/processor_interface.py | 7 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/itir_to_tasklet.py | 6 +- .../runners/dace_iterator/utility.py | 8 +- .../next/program_processors/runners/gtfn.py | 18 +- src/gt4py/next/type_system/type_info.py | 71 ++++---- src/gt4py/next/utils.py | 8 +- src/gt4py/storage/__init__.py | 2 +- src/gt4py/storage/cartesian/layout.py | 2 +- .../feature_tests/test_exec_info.py | 5 +- .../stencil_definitions.py | 6 +- .../multi_feature_tests/test_suites.py | 33 ++-- .../frontend_tests/test_defir_to_gtir.py | 3 +- .../frontend_tests/test_gtscript_frontend.py | 42 +++-- .../unit_tests/test_gtc/test_common.py | 16 +- tests/conftest.py | 1 - tests/eve_tests/conftest.py | 1 - .../unit_tests/test_extended_typing.py | 14 +- .../unit_tests/test_type_validation.py | 16 +- .../ffront_tests/ffront_test_utils.py | 13 +- .../ffront_tests/test_execution.py | 58 +++---- .../ffront_tests/test_gt4py_builtins.py | 6 +- .../ffront_tests/test_program.py | 16 +- .../multi_feature_tests/fvm_nabla_setup.py | 6 +- tests/next_tests/past_common_fixtures.py | 2 +- .../embedded_tests/test_nd_array_field.py | 90 +++++----- .../errors_tests/test_exceptions.py | 14 +- .../ffront_tests/test_foast_to_itir.py | 36 ++-- .../ffront_tests/test_func_to_foast.py | 1 + .../iterator_tests/test_pretty_printer.py | 27 +-- tests/storage_tests/conftest.py | 1 - 117 files changed, 901 insertions(+), 918 deletions(-) diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index c28c5cf2d6..7d255de142 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next # noqa: A004 + from . import next __all__ += ["next"] diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index a550db4f2e..440dba9455 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -74,21 +74,24 @@ BoolScalar: TypeAlias = Union[bool_, bool] BoolT = TypeVar("BoolT", bound=BoolScalar) BOOL_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + BoolScalar.__args__, # type: ignore[attr-defined] ) IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] IntT = TypeVar("IntT", bound=IntScalar) INT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + IntScalar.__args__, # type: ignore[attr-defined] ) UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) UINT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + UnsignedIntScalar.__args__, # type: ignore[attr-defined] ) @@ -100,7 +103,8 @@ FloatingScalar: TypeAlias = Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) FLOAT_TYPES: Final[Tuple[type, ...]] = cast( - Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined] + Tuple[type, ...], + FloatingScalar.__args__, # type: ignore[attr-defined] ) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 669110161e..259e94dcd8 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -409,8 +409,9 @@ def build_extension_module( assert module_name == qualified_pyext_name - self.builder.with_backend_data( - {"pyext_module_name": module_name, "pyext_file_path": file_path} - ) + self.builder.with_backend_data({ + "pyext_module_name": module_name, + "pyext_file_path": file_path, + }) return module_name, file_path diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 0bfdec791f..215ed8af96 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -86,14 +86,12 @@ def _get_expansion_priority_cpu(node: StencilComputation): expansion_priority = [] if node.has_splittable_regions(): expansion_priority.append(["Sections", "Stages", "I", "J", "K"]) - expansion_priority.extend( - [ - ["TileJ", "TileI", "IMap", "JMap", "Sections", "K", "Stages"], - ["TileJ", "TileI", "IMap", "JMap", "Sections", "Stages", "K"], - ["TileJ", "TileI", "Sections", "Stages", "IMap", "JMap", "K"], - ["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"], - ] - ) + expansion_priority.extend([ + ["TileJ", "TileI", "IMap", "JMap", "Sections", "K", "Stages"], + ["TileJ", "TileI", "IMap", "JMap", "Sections", "Stages", "K"], + ["TileJ", "TileI", "Sections", "Stages", "IMap", "JMap", "K"], + ["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"], + ]) return expansion_priority @@ -489,18 +487,16 @@ def generate_tmp_allocs(self, sdfg): threadlocal_fmt, "}}", ] - res.extend( - [ - fmt.format( - name=name, - sdfg_id=array_sdfg.sdfg_id, - dtype=array.dtype.ctype, - size=f"omp_max_threads * ({array.total_size})", - local_size=array.total_size, - ) - for fmt in fmts - ] - ) + res.extend([ + fmt.format( + name=name, + sdfg_id=array_sdfg.sdfg_id, + dtype=array.dtype.ctype, + size=f"omp_max_threads * ({array.total_size})", + local_size=array.total_size, + ) + for fmt in fmts + ]) return res @staticmethod @@ -617,22 +613,18 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> List[ # api field strides fmt = "gt::sid::get_stride<{dim}>(gt::sid::get_strides(__{name}_sid))" - symbols.update( - { - f"__{name}_{dim}_stride": fmt.format( - dim=f"gt::stencil::dim::{dim.lower()}", name=name - ) - for dim in dims - } - ) - symbols.update( - { - f"__{name}_d{dim}_stride": fmt.format( - dim=f"gt::integral_constant", name=name - ) - for dim in range(data_ndim) - } - ) + symbols.update({ + f"__{name}_{dim}_stride": fmt.format( + dim=f"gt::stencil::dim::{dim.lower()}", name=name + ) + for dim in dims + }) + symbols.update({ + f"__{name}_d{dim}_stride": fmt.format( + dim=f"gt::integral_constant", name=name + ) + for dim in range(data_ndim) + }) # api field pointers fmt = """gt::sid::multi_shifted( @@ -746,14 +738,12 @@ def apply(cls, stencil_ir: gtir.Stencil, sdfg: dace.SDFG, module_name: str, *, b class DaCePyExtModuleGenerator(PyExtModuleGenerator): def generate_imports(self): - return "\n".join( - [ - *super().generate_imports().splitlines(), - "import dace", - "import copy", - "from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject", - ] - ) + return "\n".join([ + *super().generate_imports().splitlines(), + "import dace", + "import copy", + "from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject", + ]) def generate_base_class_name(self): return "DaCeStencilObject" diff --git a/src/gt4py/cartesian/backend/numpy_backend.py b/src/gt4py/cartesian/backend/numpy_backend.py index 6f1aab52cf..b43e4d979a 100644 --- a/src/gt4py/cartesian/backend/numpy_backend.py +++ b/src/gt4py/cartesian/backend/numpy_backend.py @@ -42,14 +42,12 @@ def generate_imports(self) -> str: comp_pkg = ( self.builder.caching.module_prefix + "computation" + self.builder.caching.module_postfix ) - return "\n".join( - [ - *super().generate_imports().splitlines(), - "import pathlib", - "from gt4py.cartesian.utils import make_module_from_file", - f'computation = make_module_from_file("{comp_pkg}", pathlib.Path(__file__).parent / "{comp_pkg}.py")', - ] - ) + return "\n".join([ + *super().generate_imports().splitlines(), + "import pathlib", + "from gt4py.cartesian.utils import make_module_from_file", + f'computation = make_module_from_file("{comp_pkg}", pathlib.Path(__file__).parent / "{comp_pkg}.py")', + ]) def generate_implementation(self) -> str: params = [f"{p.name}={p.name}" for p in self.builder.gtir.params] diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index 4d716a6c79..1b78973b6d 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -259,9 +259,9 @@ def is_cache_info_available_and_consistent( and cache_info_ns.module_shash == module_shash ) if validate_extra: - result &= all( - [cache_info[key] == validate_extra[key] for key in validate_extra] - ) + result &= all([ + cache_info[key] == validate_extra[key] for key in validate_extra + ]) except Exception as err: if not catch_exceptions: raise err diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index a4bd209c56..4dcb8f1ee4 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Command line interface.""" + import functools import importlib import pathlib @@ -152,9 +153,7 @@ def convert( try: value = self._convert_value(backend.options[name]["type"], value, param, ctx) except click.BadParameter as conversion_error: - self.fail( - f'Invalid value for backend option "{name}": {conversion_error.message}' # noqa: B306 - ) + self.fail(f'Invalid value for backend option "{name}": {conversion_error.message}') return (name, value) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 2df8c106ce..d08fd8d7ea 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -417,7 +417,7 @@ def visit_Assign(self, node: ast.Assign): else: return self.generic_visit(node) - def visit_Call( # noqa: C901 # Cyclomatic complexity too high + def visit_Call( # Cyclomatic complexity too high self, node: ast.Call, *, target_node=None ): call_name = gt_meta.get_qualified_name_from_node(node.func) @@ -618,7 +618,7 @@ def visit_If(self, node: ast.If): def _make_temp_decls( - descriptors: Dict[str, gtscript._FieldDescriptor] + descriptors: Dict[str, gtscript._FieldDescriptor], ) -> Dict[str, nodes.FieldDecl]: return { name: nodes.FieldDecl( @@ -1495,9 +1495,9 @@ def visit_With(self, node: ast.With): self.parsing_horizontal_region = True intervals_dicts = self._visit_with_horizontal(node.items[0], loc) - all_stmts = gt_utils.flatten( - [gtc_utils.listify(self.visit(stmt)) for stmt in node.body] - ) + all_stmts = gt_utils.flatten([ + gtc_utils.listify(self.visit(stmt)) for stmt in node.body + ]) self.parsing_horizontal_region = False stmts = list(filter(lambda stmt: isinstance(stmt, nodes.Decl), all_stmts)) body_block = nodes.BlockStmt( @@ -1511,12 +1511,10 @@ def visit_With(self, node: ast.With): "The following variables are" f"written before being referenced with an offset in a horizontal region: {', '.join(written_then_offset)}" ) - stmts.extend( - [ - nodes.HorizontalIf(intervals=intervals_dict, body=body_block) - for intervals_dict in intervals_dicts - ] - ) + stmts.extend([ + nodes.HorizontalIf(intervals=intervals_dict, body=body_block) + for intervals_dict in intervals_dicts + ]) return stmts else: # If we find nested `with` blocks flatten them, i.e. transform @@ -1874,14 +1872,12 @@ def resolve_external_symbols( for name, accesses in resolved_imports.items(): if accesses: for attr_name, attr_nodes in accesses.items(): - resolved_values_list.append( - ( - attr_name, - GTScriptParser.eval_external( - attr_name, context, nodes.Location.from_ast_node(attr_nodes[0]) - ), - ) - ) + resolved_values_list.append(( + attr_name, + GTScriptParser.eval_external( + attr_name, context, nodes.Location.from_ast_node(attr_nodes[0]) + ), + )) elif not exhaustive: resolved_values_list.append((name, GTScriptParser.eval_external(name, context))) diff --git a/src/gt4py/cartesian/frontend/node_util.py b/src/gt4py/cartesian/frontend/node_util.py index 9595d5f76a..6436bf2983 100644 --- a/src/gt4py/cartesian/frontend/node_util.py +++ b/src/gt4py/cartesian/frontend/node_util.py @@ -13,7 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import collections -import copy import operator from typing import Generator, Optional, Type @@ -21,8 +20,6 @@ import gt4py.cartesian.gtc.utils as gtc_utils from gt4py import eve -from gt4py.cartesian import utils as gt_utils -from gt4py.cartesian.gtc import common from .nodes import Location, Node diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index 848f78852c..3f3207e9fe 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -142,22 +142,18 @@ import enum import operator import sys -from typing import Generator, List, Optional, Sequence, Type +from typing import List, Optional, Sequence import numpy as np -from gt4py.cartesian.definitions import AccessKind, CartesianSpace -from gt4py.cartesian.gtc.definitions import Extent, Index +from gt4py.cartesian.definitions import CartesianSpace from gt4py.cartesian.utils.attrib import ( Any as Any, Dict as DictOf, List as ListOf, - Optional as OptionalOf, - Tuple as TupleOf, Union as UnionOf, attribkwclass as attribclass, attribute, - attributes_of, ) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index e1449f11b1..1e0364d721 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -589,7 +589,7 @@ def _impl(cls: Type[ExprT], instance: ExprT) -> None: class _LvalueDimsValidator(eve.VisitorWithSymbolTableTrait): def __init__(self, vertical_loop_type: Type[eve.Node], decl_type: Type[eve.Node]) -> None: - if not vertical_loop_type.__annotations__.get("loop_order") is LoopOrder: + if vertical_loop_type.__annotations__.get("loop_order") is not LoopOrder: raise ValueError( f"Vertical loop type {vertical_loop_type} has no `loop_order` attribute" ) @@ -906,7 +906,7 @@ def data_type_to_typestr(dtype: DataType) -> str: def op_to_ufunc( op: Union[ UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction - ] + ], ) -> np.ufunc: if not isinstance( op, (UnaryOperator, ArithmeticOperator, ComparisonOperator, LogicalOperator, NativeFunction) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 9a214441ad..c882c1bb96 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -437,9 +437,9 @@ def visit_HorizontalExecution( ) expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] - iteration_ctx = iteration_ctx.push_axes_extents( - {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} - ) + iteration_ctx = iteration_ctx.push_axes_extents({ + k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent) + }) iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 779fca0c8d..e2ce48ec74 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -58,7 +58,9 @@ def _visit_offset( context_info = copy.deepcopy(access_info) context_info.variable_offset_axes = [] ranges = make_dace_subset( - access_info, context_info, data_dims=() # data_index added in visit_IndexAccess + access_info, + context_info, + data_dims=(), # data_index added in visit_IndexAccess ) ranges.offset(sym_offsets, negative=False) res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index bd8c08034c..7a0db46db5 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -144,12 +144,10 @@ def __init__( for decl in declarations.values() if isinstance(decl, oir.ScalarDecl) } - self.symbol_mapping.update( - { - axis.domain_symbol(): dace.symbol(axis.domain_symbol(), dtype=dace.int32) - for axis in dcir.Axis.dims_horizontal() - } - ) + self.symbol_mapping.update({ + axis.domain_symbol(): dace.symbol(axis.domain_symbol(), dtype=dace.int32) + for axis in dcir.Axis.dims_horizontal() + }) self.access_infos = compute_dcir_access_infos( oir_node, oir_decls=declarations, diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 8d8a0c90f7..dac0c8acc5 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -129,12 +129,9 @@ def visit_VerticalLoopSection( k_grid = dcir.GridSubset.from_interval(grid_subset.intervals[dcir.Axis.K], dcir.Axis.K) inner_infos = {name: info.apply_iteration(k_grid) for name, info in inner_infos.items()} - ctx.access_infos.update( - { - name: info.union(ctx.access_infos.get(name, info)) - for name, info in inner_infos.items() - } - ) + ctx.access_infos.update({ + name: info.union(ctx.access_infos.get(name, info)) for name, info in inner_infos.items() + }) return ctx.access_infos @@ -170,12 +167,9 @@ def visit_HorizontalExecution( inner_infos = {name: info.apply_iteration(ij_grid) for name, info in inner_infos.items()} - ctx.access_infos.update( - { - name: info.union(ctx.access_infos.get(name, info)) - for name, info in inner_infos.items() - } - ) + ctx.access_infos.update({ + name: info.union(ctx.access_infos.get(name, info)) for name, info in inner_infos.items() + }) return ctx.access_infos @@ -419,7 +413,7 @@ def flatten_list(list_or_node: Union[List[Any], eve.Node]): def collect_toplevel_computation_nodes( - list_or_node: Union[List[Any], eve.Node] + list_or_node: Union[List[Any], eve.Node], ) -> List["dcir.ComputationNode"]: class ComputationNodeCollector(eve.NodeVisitor): def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List): @@ -431,7 +425,7 @@ def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List) def collect_toplevel_iteration_nodes( - list_or_node: Union[List[Any], eve.Node] + list_or_node: Union[List[Any], eve.Node], ) -> List["dcir.IterationNode"]: class IterationNodeCollector(eve.NodeVisitor): def visit_IterationNode(self, node: dcir.IterationNode, *, collection: List): diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index 4925d7cc2b..4b41aabee8 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -436,9 +436,9 @@ def _apply(self, other, left_func, right_func=None): raise ValueError("Incompatible instance '{obj}'".format(obj=other)) right_func = right_func or left_func - return type(self)( - [tuple([left_func(a[0], b[0]), right_func(a[1], b[1])]) for a, b in zip(self, other)] - ) + return type(self)([ + tuple([left_func(a[0], b[0]), right_func(a[1], b[1])]) for a, b in zip(self, other) + ]) def _reduce(self, reduce_func, out_type=tuple): return out_type([reduce_func(d[0], d[1]) for d in self]) diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py index 77d1a14a4d..045dc09377 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py @@ -159,7 +159,7 @@ def __add__(self, offset: Union[common.CartesianOffset, VariableKOffset]) -> "GT class GTAccessor(LocNode): name: eve.Coerced[eve.SymbolName] - id: int # noqa: A003 # shadowing python builtin + id: int # shadowing python builtin intent: Intent extent: GTExtent ndim: int = 3 diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py index 4e56b159d9..7795472b41 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py @@ -104,12 +104,10 @@ def visit_AccessorRef( if accessor_ref.name in temp_decls and accessor_ref.data_index: # Cannot use symtable. See https://github.com/GridTools/gt4py/issues/808 temp = temp_decls[accessor_ref.name] - data_index = "+".join( - [ - f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}" - for i, index in enumerate(accessor_ref.data_index) - ] - ) + data_index = "+".join([ + f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}" + for i, index in enumerate(accessor_ref.data_index) + ]) return f"eval({accessor_ref.name}({i_offset}, {j_offset}, {k_offset}))[{data_index}]" else: data_index = "".join( diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py index 43ab047c6e..24ea38b36a 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -22,7 +22,7 @@ def _iter_field_names( - node: Union[gtir.Stencil, gtir.ParAssignStmt] + node: Union[gtir.Stencil, gtir.ParAssignStmt], ) -> eve.utils.XIterable[gtir.FieldAccess]: return node.walk_values().if_isinstance(gtir.FieldDecl).getattr("name").unique() diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py index c97b478f77..a44a500c5a 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py @@ -106,9 +106,9 @@ class LocalTemporariesToScalars(TemporariesToScalarsBase): def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> oir.Stencil: horizontal_executions = node.walk_values().if_isinstance(oir.HorizontalExecution) - temps_without_data_dims = set( - [decl.name for decl in node.declarations if not decl.data_dims] - ) + temps_without_data_dims = set([ + decl.name for decl in node.declarations if not decl.data_dims + ]) counts: collections.Counter = sum( ( collections.Counter( diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py index ddf4713757..a0b335d70a 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py @@ -173,18 +173,16 @@ class CartesianAccessCollection(GenericAccessCollection[CartesianAccess, Tuple[i class GeneralAccessCollection(GenericAccessCollection[GeneralAccess, GeneralOffsetTuple]): def cartesian_accesses(self) -> "AccessCollector.CartesianAccessCollection": - return AccessCollector.CartesianAccessCollection( - [ - CartesianAccess( - field=acc.field, - offset=cast(Tuple[int, int, int], acc.offset), - data_index=acc.data_index, - is_write=acc.is_write, - ) - for acc in self._ordered_accesses - if acc.offset[2] is not None - ] - ) + return AccessCollector.CartesianAccessCollection([ + CartesianAccess( + field=acc.field, + offset=cast(Tuple[int, int, int], acc.offset), + data_index=acc.data_index, + is_write=acc.is_write, + ) + for acc in self._ordered_accesses + if acc.offset[2] is not None + ]) def has_variable_access(self) -> bool: return any(acc.offset[2] is None for acc in self._ordered_accesses) diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 09e6f0f4a5..418fefc292 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -22,7 +22,7 @@ import inspect import numbers import types -from typing import Callable, Dict, Optional, Type +from typing import Callable, Dict, Type import numpy as np @@ -651,7 +651,7 @@ def __descriptor__(self): return None def __repr__(self): - args = f"dtype={repr(self.dtype)}, axes={repr(self.axes)}, data_dims={repr(self.data_dims)}" + args = f"dtype={self.dtype!r}, axes={self.axes!r}, data_dims={self.data_dims!r}" return f"_FieldDescriptor({args})" def __str__(self): diff --git a/src/gt4py/cartesian/gtscript_imports.py b/src/gt4py/cartesian/gtscript_imports.py index 82f05e968b..206d190f36 100644 --- a/src/gt4py/cartesian/gtscript_imports.py +++ b/src/gt4py/cartesian/gtscript_imports.py @@ -37,6 +37,7 @@ import ... """ + import importlib import importlib.abc import pathlib diff --git a/src/gt4py/cartesian/lazy_stencil.py b/src/gt4py/cartesian/lazy_stencil.py index 3f8d54ea4f..84b7010cce 100644 --- a/src/gt4py/cartesian/lazy_stencil.py +++ b/src/gt4py/cartesian/lazy_stencil.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Stencil Object that allows for deferred building.""" + from typing import TYPE_CHECKING, Any, Dict from cached_property import cached_property diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 69ce980bda..c1fe858d62 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -92,7 +92,7 @@ def _extract_array_infos( def _extract_stencil_arrays( - array_infos: Dict[str, Optional[ArgsInfo]] + array_infos: Dict[str, Optional[ArgsInfo]], ) -> Dict[str, Optional[FieldType]]: return {name: info.array if info is not None else None for name, info in array_infos.items()} @@ -283,7 +283,7 @@ def __call__(self, *args, **kwargs) -> None: @staticmethod def _make_origin_dict( - origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None] + origin: Union[Dict[str, Tuple[int, ...]], Tuple[int, ...], int, None], ) -> Dict[str, Tuple[int, ...]]: try: if isinstance(origin, dict): @@ -349,7 +349,7 @@ def _get_max_domain( else: return max_domain - def _validate_args( # noqa: C901 # Function is too complex + def _validate_args( # Function is too complex self, arg_infos: Dict[str, Optional[ArgsInfo]], param_args: Dict[str, Any], diff --git a/src/gt4py/cartesian/testing/input_strategies.py b/src/gt4py/cartesian/testing/input_strategies.py index 008b859929..37646f1af5 100644 --- a/src/gt4py/cartesian/testing/input_strategies.py +++ b/src/gt4py/cartesian/testing/input_strategies.py @@ -178,9 +178,9 @@ def derived_shape_st(shape_st, extra: Sequence[Optional[int]]): both shape and extra elements are summed together. """ return hyp_st.builds( - lambda shape: tuple( - [d + e for d, e in itertools.zip_longest(shape, extra, fillvalue=0) if e is not None] - ), + lambda shape: tuple([ + d + e for d, e in itertools.zip_longest(shape, extra, fillvalue=0) if e is not None + ]), shape_st, ) diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 99ad14f87c..735a314b63 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -392,10 +392,7 @@ class StencilTestSuite(metaclass=SuiteMeta): .. code-block:: python - { - 'float_symbols' : (np.float32, np.float64), - 'int_symbols' : (int, np.int_, np.int64) - } + {"float_symbols": (np.float32, np.float64), "int_symbols": (int, np.int_, np.int64)} domain_range : `Sequence` of pairs like `((int, int), (int, int) ... )` Required class attribute. @@ -473,7 +470,7 @@ def _test_generation(cls, test, externals_dict): test["implementations"].append(implementation) @classmethod - def _run_test_implementation(cls, parameters_dict, implementation): # noqa: C901 # too complex + def _run_test_implementation(cls, parameters_dict, implementation): # too complex input_data, exec_info = parameters_dict origin = cls.origin @@ -504,9 +501,9 @@ def _run_test_implementation(cls, parameters_dict, implementation): # noqa: C90 referenced_inputs = { name: info for name, info in implementation.field_info.items() if info is not None } - referenced_inputs.update( - {name: info for name, info in implementation.parameter_info.items() if info is not None} - ) + referenced_inputs.update({ + name: info for name, info in implementation.parameter_info.items() if info is not None + }) # set externals for validation method for k, v in implementation.constants.items(): diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index da53e5c128..46bbf3dcfd 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -172,7 +172,7 @@ def _is_union_of_validator(instance, attribute, value): validator(instance, attribute, value) break - except Exception as e: + except Exception: pass else: passed = False diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index 591c44fb95..dca8a1f420 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -206,12 +206,10 @@ def classmethod_to_function(class_method, instance=None, owner=type(None), remov def namespace_from_nested_dict(nested_dict): assert isinstance(nested_dict, dict) - return types.SimpleNamespace( - **{ - key: namespace_from_nested_dict(value) if isinstance(value, dict) else value - for key, value in nested_dict.items() - } - ) + return types.SimpleNamespace(**{ + key: namespace_from_nested_dict(value) if isinstance(value, dict) else value + for key, value in nested_dict.items() + }) def make_local_dir(dir_name, base_dir=None, *, mode=0o777, is_package=False, is_cache=False): diff --git a/src/gt4py/cartesian/utils/meta.py b/src/gt4py/cartesian/utils/meta.py index 0f21015d1a..769b097f54 100644 --- a/src/gt4py/cartesian/utils/meta.py +++ b/src/gt4py/cartesian/utils/meta.py @@ -96,14 +96,12 @@ def _dump(node: ast.AST, excluded_names): for name, value in sorted(ast.iter_fields(node)) ] - return "".join( - [ - node.__class__.__name__, - "({content})".format( - content=", ".join("{}={}".format(name, value) for name, value in fields) - ), - ] - ) + return "".join([ + node.__class__.__name__, + "({content})".format( + content=", ".join("{}={}".format(name, value) for name, value in fields) + ), + ]) elif isinstance(node, list): lines = ["[", *[_dump(i, excluded_names) + "," for i in node], "]"] diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 72f0e8858f..dd8225ed91 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -14,7 +14,6 @@ """Tools for source code generation.""" - from __future__ import annotations import abc @@ -305,13 +304,13 @@ def indented(self, steps: int = 1) -> Iterator[TextBlock]: common `indent - append - dedent` workflows. Examples: - >>> block = TextBlock(); - >>> block.append('first line') # doctest: +ELLIPSIS + >>> block = TextBlock() + >>> block.append("first line") # doctest: +ELLIPSIS <...> >>> with block.indented(): - ... block.append('second line'); # doctest: +ELLIPSIS + ... block.append("second line") # doctest: +ELLIPSIS <...> - >>> block.append('third line') # doctest: +ELLIPSIS + >>> block.append("third line") # doctest: +ELLIPSIS <...> >>> print(block.text) first line @@ -476,7 +475,9 @@ def render_values(self, **kwargs: Any) -> str: message += f" (created at {self.definition_loc[0]}:{self.definition_loc[1]})" try: loc_info = re.search(r"line (\d+), col (\d+)", str(e)) - message += f" rendering error at template line: {loc_info[1]}, column {loc_info[2]}." # type: ignore + message += ( + f" rendering error at template line: {loc_info[1]}, column {loc_info[2]}." # type: ignore + ) except Exception: message += " rendering error." @@ -541,7 +542,9 @@ def __init__(self, definition: mako_tpl.Template, **kwargs: Any) -> None: if self.definition_loc: message += f" created at {self.definition_loc[0]}:{self.definition_loc[1]}" try: - message += f" (error likely around line {e.lineno}, column: {getattr(e, 'pos', '?')})" # type: ignore # assume Mako exception + message += ( + f" (error likely around line {e.lineno}, column: {getattr(e, 'pos', '?')})" # type: ignore # assume Mako exception + ) except Exception: message = f"{message}:\n---\n{definition}\n---\n" @@ -629,13 +632,11 @@ def __init_subclass__(cls, *, inherit_templates: bool = True, **kwargs: Any) -> ): templates.update(templated_gen_class.__templates__) - templates.update( - { - key: value - for key, value in cls.__dict__.items() - if isinstance(value, Template) and not key.startswith("_") and not key.endswith("_") - } - ) + templates.update({ + key: value + for key, value in cls.__dict__.items() + if isinstance(value, Template) and not key.startswith("_") and not key.endswith("_") + }) cls.__templates__ = types.MappingProxyType(templates) @@ -645,12 +646,12 @@ def apply(cls, root: LeafNode, **kwargs: Any) -> str: ... @overload @classmethod - def apply( # noqa: F811 # redefinition of symbol + def apply( # redefinition of symbol cls, root: CollectionNode, **kwargs: Any ) -> Collection[str]: ... @classmethod - def apply( # noqa: F811 # redefinition of symbol + def apply( # redefinition of symbol cls, root: RootNode, **kwargs: Any ) -> Union[str, Collection[str]]: """Public method to build a class instance and visit an IR node. diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index 67991f6db0..c073033cc3 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -14,7 +14,6 @@ """Definitions of basic Eve concepts.""" - from __future__ import annotations import copy diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 3b9565bc34..c57ab7c8c2 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -113,5 +113,5 @@ """ -from . import core as core, validators as validators # noqa: F401 # imported but unused +from . import core as core, validators as validators # imported but unused from .core import * # noqa: # star unused import diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 11ad824aab..ba13bdf166 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -34,7 +34,7 @@ import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz - import toolz # noqa: F401 + import toolz from .. import exceptions, extended_typing as xtyping, type_validation as type_val, utils from ..extended_typing import ( @@ -282,7 +282,7 @@ def datamodel( @overload -def datamodel( # noqa: F811 # redefinion of unused symbol +def datamodel( # redefinion of unused symbol cls: Type[_T], /, *, @@ -301,7 +301,7 @@ def datamodel( # noqa: F811 # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # noqa: F811 # redefinion of unused symbol +def datamodel( # redefinion of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -559,7 +559,7 @@ def field( >>> from typing import List >>> @datamodel ... class C: - ... mylist: List[int] = field(default_factory=lambda : [1, 2, 3]) + ... mylist: List[int] = field(default_factory=lambda: [1, 2, 3]) >>> c = C() >>> c.mylist [1, 2, 3] @@ -660,7 +660,7 @@ def get_fields(model: Union[DataModel, Type[DataModel]]) -> utils.FrozenNamespac >>> fields(Model) # doctest:+ELLIPSIS FrozenNamespace(...name=Attribute(name='name', default=NOTHING, ... - """ # noqa: RST201 # doctest conventions confuse RST validator + """ # doctest conventions confuse RST validator if not is_datamodel(model): raise TypeError(f"Invalid datamodel instance or class: '{model}'.") if not isinstance(model, type): @@ -694,8 +694,8 @@ def asdict( ... x: int ... y: int >>> c = C(x=1, y=2) - >>> assert asdict(c) == {'x': 1, 'y': 2} - """ # noqa: RST301 # sphinx.napoleon conventions confuse RST validator + >>> assert asdict(c) == {"x": 1, "y": 2} + """ # sphinx.napoleon conventions confuse RST validator if not is_datamodel(instance) or isinstance(instance, type): raise TypeError(f"Invalid datamodel instance: '{instance}'.") return attrs.asdict(instance, value_serializer=value_serializer) @@ -784,7 +784,7 @@ def concretize( *type_args: Type, class_name: Optional[str] = None, module: Optional[str] = None, - support_pickling: bool = True, # noqa + support_pickling: bool = True, overwrite_definition: bool = True, ) -> Type[DataModelT]: """Generate a new concrete subclass of a generic Data Model. @@ -805,9 +805,12 @@ def concretize( overwrite_definition: If ``True``, a previous definition of the class in the target module will be overwritten. - """ # noqa: RST301 # doctest conventions confuse RST validator + """ # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] + datamodel_cls, + *type_args, + class_name=class_name, + module=module, # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) @@ -1018,7 +1021,7 @@ def _type_converter(value: Any) -> _T: _KNOWN_MUTABLE_TYPES: Final = (list, dict, set) -def _make_datamodel( # noqa: C901 # too complex but still readable and documented +def _make_datamodel( # too complex but still readable and documented cls: Type[_T], *, repr: bool, # noqa: A002 # shadowing 'repr' python builtin @@ -1180,7 +1183,8 @@ def _make_datamodel( # noqa: C901 # too complex but still readable and documen cls.__attrs_pre_init__ = cls.__pre_init__ # type: ignore[attr-defined] # adding new attribute if "__attrs_post_init__" in cls.__dict__ and not hasattr( - cls.__attrs_post_init__, _DATAMODEL_TAG # type: ignore[attr-defined] # mypy doesn't know about __attr_post_init__ + cls.__attrs_post_init__, + _DATAMODEL_TAG, # type: ignore[attr-defined] # mypy doesn't know about __attr_post_init__ ): raise TypeError(f"'{cls.__name__}' class contains forbidden custom '__attrs_post_init__'.") cls.__attrs_post_init__ = _make_post_init(has_post_init="__post_init__" in cls.__dict__) # type: ignore[attr-defined] # adding new attribute diff --git a/src/gt4py/eve/exceptions.py b/src/gt4py/eve/exceptions.py index 02379d8fc3..258c887c48 100644 --- a/src/gt4py/eve/exceptions.py +++ b/src/gt4py/eve/exceptions.py @@ -14,7 +14,6 @@ """Definitions of specific Eve exceptions.""" - from __future__ import annotations from .extended_typing import Any, Dict, Optional diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 82076d1a9c..4fa56cc264 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -180,11 +180,11 @@ def __get__( ) -> NonDataDescriptor[_C, _V]: ... @overload - def __get__( # noqa: F811 # redefinion of unused member + def __get__( # redefinion of unused member self, _instance: _C, _owner_type: Optional[Type[_C]] = None ) -> _V: ... - def __get__( # noqa: F811 # redefinion of unused member + def __get__( # redefinion of unused member self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None ) -> _V | NonDataDescriptor[_C, _V]: ... @@ -351,7 +351,7 @@ def extended_runtime_checkable( ) -> _ProtoT: ... -def extended_runtime_checkable( # noqa: C901 # too complex but unavoidable +def extended_runtime_checkable( # too complex but unavoidable maybe_cls: Optional[_ProtoT] = None, *, instance_check_shortcut: bool = True, @@ -660,7 +660,7 @@ def eval_forward_ref( Examples: >>> from typing import Dict, Tuple - >>> print("Result:", eval_forward_ref('Dict[str, Tuple[int, float]]')) + >>> print("Result:", eval_forward_ref("Dict[str, Tuple[int, float]]")) Result: ...ict[str, ...uple[int, float]] """ @@ -697,7 +697,7 @@ class CallableKwargsInfo: data: Dict[str, Any] -def infer_type( # noqa: C901 # function is complex but well organized in independent cases +def infer_type( # function is complex but well organized in independent cases value: Any, *, annotate_callable_kwargs: bool = False, @@ -724,10 +724,10 @@ def infer_type( # noqa: C901 # function is complex but well organized in indep >>> infer_type(frozenset([1, 2, 3])) frozenset[int] - >>> infer_type({'a': 0, 'b': 1}) + >>> infer_type({"a": 0, "b": 1}) dict[str, int] - >>> infer_type({'a': 0, 'b': 'B'}) + >>> infer_type({"a": 0, "b": "B"}) dict[str, ...Any] >>> print("Result:", infer_type(lambda a, b: a + b)) @@ -755,7 +755,7 @@ def infer_type( # noqa: C901 # function is complex but well organized in indep ... @extended_infer_type.register(float) ... @extended_infer_type.register(complex) ... def _infer_type_number(value, *, annotate_callable_kwargs: bool = False): - ... return numbers.Number + ... return numbers.Number >>> extended_infer_type(3.4) >>> infer_type(3.4) diff --git a/src/gt4py/eve/pattern_matching.py b/src/gt4py/eve/pattern_matching.py index e239dddd70..f1c28b096b 100644 --- a/src/gt4py/eve/pattern_matching.py +++ b/src/gt4py/eve/pattern_matching.py @@ -14,7 +14,6 @@ """Basic pattern matching utilities.""" - from __future__ import annotations from functools import singledispatch @@ -31,9 +30,9 @@ class and all attributes of the pattern (recursively) match the Examples: >>> class Foo: - ... def __init__(self, bar, baz): - ... self.bar = bar - ... self.baz = baz + ... def __init__(self, bar, baz): + ... self.bar = bar + ... self.baz = baz >>> assert ObjectPattern(Foo, bar=1).match(Foo(1, 2)) """ @@ -53,16 +52,16 @@ def match(self, other: Any, *, raise_exception: bool = False) -> bool: if raise_exception: diffs = [*get_differences(self, other)] if len(diffs) > 0: - diffs_str = "\n ".join( - [f" {self.cls.__name__}{path}: {msg}" for path, msg in diffs] - ) + diffs_str = "\n ".join([ + f" {self.cls.__name__}{path}: {msg}" for path, msg in diffs + ]) raise ValueError(f"Object and pattern don't match:\n {diffs_str}") return True return next(get_differences(self, other), None) is None def __str__(self) -> str: - attrs_str = ", ".join([f"{str(k)}={str(v)}" for k, v in self.fields.items()]) + attrs_str = ", ".join([f"{k!s}={v!s}" for k, v in self.fields.items()]) return f"{self.cls.__name__}({attrs_str})" diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index aacae804d8..ff8049c91d 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -14,7 +14,6 @@ """Definitions of node and visitor trait classes.""" - from __future__ import annotations import collections diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 7bfd22cdf7..dabb48105b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -14,7 +14,6 @@ """Iterator utils.""" - from __future__ import annotations import abc diff --git a/src/gt4py/eve/type_definitions.py b/src/gt4py/eve/type_definitions.py index 1ee981f548..d2cf87c644 100644 --- a/src/gt4py/eve/type_definitions.py +++ b/src/gt4py/eve/type_definitions.py @@ -14,16 +14,15 @@ """Definitions of useful field and general types.""" - from __future__ import annotations import abc import re import sys -from enum import Enum as Enum, IntEnum as IntEnum # noqa: F401 # imported but unused +from enum import Enum as Enum, IntEnum as IntEnum # imported but unused -from boltons.typeutils import classproperty as classproperty # noqa: F401 -from frozendict import frozendict as _frozendict # noqa: F401 +from boltons.typeutils import classproperty as classproperty +from frozendict import frozendict as _frozendict from .extended_typing import ( Any, @@ -43,7 +42,7 @@ _Tc = TypeVar("_Tc", covariant=True) -class FrozenList(Tuple[_Tc, ...], metaclass=abc.ABCMeta): # noqa: B024 # no abstract methods +class FrozenList(Tuple[_Tc, ...], metaclass=abc.ABCMeta): # no abstract methods """Tuple subtype which works as an alias of ``Tuple[_Tc, ...]``.""" __slots__ = () @@ -98,7 +97,8 @@ class ConstrainedStr(str): class keyword argument or as class variable. Examples: - >>> class OnlyLetters(ConstrainedStr, regex=re.compile(r"^[a-zA-Z]*$")): pass + >>> class OnlyLetters(ConstrainedStr, regex=re.compile(r"^[a-zA-Z]*$")): + ... pass >>> OnlyLetters("aabbCC") OnlyLetters('aabbCC') diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 124957fa20..f67741c658 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,7 +14,6 @@ """Generic interface and implementations of run-time type validation for arbitrary values.""" - from __future__ import annotations import abc @@ -113,7 +112,7 @@ def __call__( ) -> FixedTypeValidator: ... @overload - def __call__( # noqa: F811 # redefinion of unused member + def __call__( # redefinion of unused member self, type_annotation: TypeAnnotation, name: Optional[str] = None, @@ -125,7 +124,7 @@ def __call__( # noqa: F811 # redefinion of unused member ) -> Optional[FixedTypeValidator]: ... @abc.abstractmethod - def __call__( # noqa: F811 # redefinion of unused member + def __call__( # redefinion of unused member self, type_annotation: TypeAnnotation, name: Optional[str] = None, @@ -170,7 +169,7 @@ def __call__( ) -> FixedTypeValidator: ... @overload - def __call__( # noqa: F811 # redefinion of unused member + def __call__( # redefinion of unused member self, type_annotation: TypeAnnotation, name: Optional[str] = None, @@ -181,7 +180,7 @@ def __call__( # noqa: F811 # redefinion of unused member **kwargs: Any, ) -> Optional[FixedTypeValidator]: ... - def __call__( # noqa: F811,C901 # redefinion of unused member / complex but well organized in cases + def __call__( # redefinion of unused member / complex but well organized in cases self, type_annotation: TypeAnnotation, name: Optional[str] = None, diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8e634c4b11..5a89e83b74 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -14,7 +14,6 @@ """General utility functions. Some functionalities are directly imported from dependencies.""" - from __future__ import annotations import collections.abc @@ -35,12 +34,12 @@ import deepdiff import xxhash -from boltons.iterutils import ( # noqa: F401 +from boltons.iterutils import ( flatten as flatten, flatten_iter as flatten_iter, is_collection as is_collection, ) -from boltons.strutils import ( # noqa: F401 +from boltons.strutils import ( a10n as a10n, asciify as asciify, format_int_list as format_int_list, @@ -79,7 +78,7 @@ import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz - import toolz # noqa: F401 + import toolz T = TypeVar("T") @@ -92,7 +91,7 @@ def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], >>> checker = isinstancechecker((int, str)) >>> checker(3) True - >>> checker('3') + >>> checker("3") True >>> checker(3.3) False @@ -117,17 +116,17 @@ def attrchecker(*names: str) -> Callable[[Any], bool]: Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> point = Point(1.0, 2.0) - >>> checker = attrchecker('x') + >>> checker = attrchecker("x") >>> checker(point) True - >>> checker = attrchecker('x', 'y') + >>> checker = attrchecker("x", "y") >>> checker(point) True - >>> checker = attrchecker('z') + >>> checker = attrchecker("z") >>> checker(point) False @@ -144,19 +143,19 @@ def attrgetter_(*names: str, default: Any = NOTHING) -> Callable[[Any], Any]: Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> point = Point(1.0, 2.0) - >>> getter = attrgetter_('x') + >>> getter = attrgetter_("x") >>> getter(point) 1.0 >>> import math - >>> getter = attrgetter_('z', default=math.nan) + >>> getter = attrgetter_("z", default=math.nan) >>> getter(point) nan >>> import math - >>> getter = attrgetter_('x', 'y', 'z', default=math.nan) + >>> getter = attrgetter_("x", "y", "z", default=math.nan) >>> getter(point) (1.0, 2.0, nan) @@ -187,12 +186,12 @@ def getitem_(obj: Any, key: Any, default: Any = NOTHING) -> Any: Similar to :func:`operator.getitem()` but accepts a default value. Examples: - >>> d = {'a': 1} - >>> getitem_(d, 'a') + >>> d = {"a": 1} + >>> getitem_(d, "a") 1 - >>> d = {'a': 1} - >>> getitem_(d, 'b', 'default') + >>> d = {"a": 1} + >>> getitem_(d, "b", "default") 'default' """ @@ -213,13 +212,13 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]: Similar to :func:`operator.itemgetter()` but accepts a default value. Examples: - >>> d = {'a': 1} - >>> getter = itemgetter_('a') + >>> d = {"a": 1} + >>> getter = itemgetter_("a") >>> getter(d) 1 - >>> d = {'a': 1} - >>> getter = itemgetter_('b', 'default') + >>> d = {"a": 1} + >>> getter = itemgetter_("b", "default") >>> getter(d) 'default' @@ -245,12 +244,12 @@ def with_fluid_partial( @overload -def with_fluid_partial( # noqa: F811 # redefinition of unused function +def with_fluid_partial( # redefinition of unused function func: Callable[_P, _T], *args: Any, **kwargs: Any ) -> Callable[_P, _T]: ... -def with_fluid_partial( # noqa: F811 # redefinition of unused function +def with_fluid_partial( # redefinition of unused function func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any ) -> Union[Callable[..., Any], Callable[[Callable[..., Any]], Callable[..., Any]]]: """Add a `partial` attribute to the decorated function. @@ -269,7 +268,6 @@ def with_fluid_partial( # noqa: F811 # redefinition of unused function >>> @with_fluid_partial ... def add(a, b): ... return a + b - ... >>> add.partial(1)(2) 3 """ @@ -288,12 +286,12 @@ def optional_lru_cache( @overload -def optional_lru_cache( # noqa: F811 # redefinition of unused function +def optional_lru_cache( # redefinition of unused function func: Callable[_P, _T], *, maxsize: Optional[int] = 128, typed: bool = False ) -> Callable[_P, _T]: ... -def optional_lru_cache( # noqa: F811 # redefinition of unused function +def optional_lru_cache( # redefinition of unused function func: Optional[Callable[_P, _T]] = None, *, maxsize: Optional[int] = 128, typed: bool = False ) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]: """Wrap :func:`functools.lru_cache` to fall back to the original function if arguments are not hashable. @@ -303,7 +301,6 @@ def optional_lru_cache( # noqa: F811 # redefinition of unused function ... def func(a, b): ... print(f"Inside func({a}, {b})") ... return a + b - ... >>> print(func(1, 3)) Inside func(1, 3) 4 @@ -346,14 +343,11 @@ def register_subclasses(*subclasses: Type) -> Callable[[Type], Type]: >>> import abc >>> class MyVirtualSubclassA: ... pass - ... >>> class MyVirtualSubclassB: - ... pass - ... + ... pass >>> @register_subclasses(MyVirtualSubclassA, MyVirtualSubclassB) ... class MyBaseClass(abc.ABC): - ... pass - ... + ... pass >>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass(MyVirtualSubclassB, MyBaseClass) True @@ -749,7 +743,7 @@ def __setattr__(self, name: str, value: Any) -> None: def __iter__(self) -> Iterator[T]: return self.iterator - def map(self, func: Callable) -> XIterable[Any]: # noqa # A003: shadowing a python builtin + def map(self, func: Callable) -> XIterable[Any]: # A003: shadowing a python builtin """Apply a callable to every iterator element. Equivalent to ``map(func, self)``. @@ -789,7 +783,7 @@ def map(self, func: Callable) -> XIterable[Any]: # noqa # A003: shadowing a py raise ValueError(f"Invalid function or callable: '{func}'.") return XIterable(map(func, self.iterator)) - def filter( # noqa # A003: shadowing a python builtin + def filter( # A003: shadowing a python builtin self, func: Callable[..., bool] ) -> XIterable[T]: """Filter elements with callables. @@ -822,7 +816,7 @@ def if_isinstance(self, *types: Type) -> XIterable[T]: Equivalent to ``xiter(item for item in self if isinstance(item, types))``. Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) >>> list(it.if_isinstance(int, float)) [1, 3.3] @@ -835,7 +829,7 @@ def if_not_isinstance(self, *types: Type) -> XIterable[T]: Equivalent to ``xiter(item for item in self if not isinstance(item, types))``. Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) >>> list(it.if_not_isinstance(int, float)) ['2', [4, 5], {6, 7}] @@ -942,18 +936,18 @@ def if_hasattr(self, *names: str) -> XIterable[T]: Equivalent to ``filter(attrchecker(names), self)``. Examples: - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) - >>> list(it.if_hasattr('__len__')) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) + >>> list(it.if_hasattr("__len__")) ['2', [4, 5], {6, 7}] - >>> it = xiter([1, '2', 3.3, [4, 5], {6, 7}]) - >>> list(it.if_hasattr('__len__', 'index')) + >>> it = xiter([1, "2", 3.3, [4, 5], {6, 7}]) + >>> list(it.if_hasattr("__len__", "index")) ['2', [4, 5]] """ return XIterable(filter(attrchecker(*names), self.iterator)) - def getattr( # noqa # A003: shadowing a python builtin + def getattr( # A003: shadowing a python builtin self, *names: str, default: Any = NOTHING ) -> XIterable[Any]: """Get provided attributes from each item in a sequence. @@ -968,13 +962,13 @@ def getattr( # noqa # A003: shadowing a python builtin Examples: >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y']) + >>> Point = namedtuple("Point", ["x", "y"]) >>> it = xiter([Point(1.0, -1.0), Point(2.0, -2.0), Point(3.0, -3.0)]) - >>> list(it.getattr('y')) + >>> list(it.getattr("y")) [-1.0, -2.0, -3.0] >>> it = xiter([Point(1.0, -1.0), Point(2.0, -2.0), Point(3.0, -3.0)]) - >>> list(it.getattr('x', 'z', default=None)) + >>> list(it.getattr("x", "z", default=None)) [(1.0, None), (2.0, None), (3.0, None)] """ @@ -991,7 +985,7 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl For detailed information check :func:`toolz.itertoolz.pluck` reference. - >>> it = xiter([('a', 1), ('b', 2), ('c', 3)]) + >>> it = xiter([("a", 1), ("b", 2), ("c", 3)]) >>> list(it.getitem(0)) ['a', 'b', 'c'] @@ -999,7 +993,7 @@ def getitem(self, *indices: Union[int, str], default: Any = NOTHING) -> XIterabl ... dict(name="AA", age=20, country="US"), ... dict(name="BB", age=30, country="UK"), ... dict(name="CC", age=40, country="EU"), - ... dict(country="CH") + ... dict(country="CH"), ... ]) >>> list(it.getitem("name", "age", default=None)) [('AA', 20), ('BB', 30), ('CC', 40), (None, None)] @@ -1023,12 +1017,12 @@ def chain(self, *others: Iterable) -> XIterable[Union[T, S]]: For detailed information check :func:`itertools.chain` reference. Examples: - >>> it_a, it_b = xiter(range(2)), xiter(['a', 'b']) + >>> it_a, it_b = xiter(range(2)), xiter(["a", "b"]) >>> list(it_a.chain(it_b)) [0, 1, 'a', 'b'] >>> it_a = xiter(range(2)) - >>> list(it_a.chain(['a', 'b'], ['A', 'B'])) + >>> list(it_a.chain(["a", "b"], ["A", "B"])) [0, 1, 'a', 'b', 'A', 'B'] """ @@ -1092,7 +1086,7 @@ def product( For detailed information check :func:`itertools.product` reference. Examples: - >>> it_a, it_b = xiter([0, 1]), xiter(['a', 'b']) + >>> it_a, it_b = xiter([0, 1]), xiter(["a", "b"]) >>> list(it_a.product(it_b)) [(0, 'a'), (0, 'b'), (1, 'a'), (1, 'b')] @@ -1175,7 +1169,7 @@ def take_nth(self, n: int) -> XIterable[T]: raise ValueError(f"Only positive integer numbers are accepted (provided: {n}).") return XIterable(toolz.itertoolz.take_nth(n, self.iterator)) - def zip( # noqa # A003: shadowing a python builtin + def zip( # A003: shadowing a python builtin self, *others: Iterable, fill: Any = NOTHING ) -> XIterable[Tuple[T, S]]: """Zip iterators. @@ -1189,16 +1183,16 @@ def zip( # noqa # A003: shadowing a python builtin Examples: >>> it_a = xiter(range(3)) - >>> it_b = ['a', 'b', 'c'] + >>> it_b = ["a", "b", "c"] >>> list(it_a.zip(it_b)) [(0, 'a'), (1, 'b'), (2, 'c')] >>> it = xiter(range(3)) - >>> list(it.zip(['a', 'b', 'c'], ['A', 'B', 'C'])) + >>> list(it.zip(["a", "b", "c"], ["A", "B", "C"])) [(0, 'a', 'A'), (1, 'b', 'B'), (2, 'c', 'C')] >>> it = xiter(range(5)) - >>> list(it.zip(['a', 'b', 'c'], ['A', 'B', 'C'], fill=None)) + >>> list(it.zip(["a", "b", "c"], ["A", "B", "C"], fill=None)) [(0, 'a', 'A'), (1, 'b', 'B'), (2, 'c', 'C'), (3, None, None), (4, None, None)] """ @@ -1216,7 +1210,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: For detailed information check :func:`zip` reference. Examples: - >>> it = xiter([('a', 1), ('b', 2), ('c', 3)]) + >>> it = xiter([("a", 1), ("b", 2), ("c", 3)]) >>> list(it.unzip()) [('a', 'b', 'c'), (1, 2, 3)] @@ -1296,7 +1290,7 @@ def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: >>> list(it.unique()) [1, 2, 3] - >>> it = xiter(['cat', 'mouse', 'dog', 'hen']) + >>> it = xiter(["cat", "mouse", "dog", "hen"]) >>> list(it.unique(key=len)) ['cat', 'mouse'] @@ -1347,29 +1341,29 @@ def groupby( For detailed information check :func:`toolz.itertoolz.groupby` reference. Examples: - >>> it = xiter([(1.0, -1.0), (1.0,-2.0), (2.2, -3.0)]) + >>> it = xiter([(1.0, -1.0), (1.0, -2.0), (2.2, -3.0)]) >>> list(it.groupby([0])) [(1.0, [(1.0, -1.0), (1.0, -2.0)]), (2.2, [(2.2, -3.0)])] >>> data = [ - ... {'x': 1.0, 'y': -1.0, 'z': 1.0}, - ... {'x': 1.0, 'y': -2.0, 'z': 1.0}, - ... {'x': 2.2, 'y': -3.0, 'z': 2.2} + ... {"x": 1.0, "y": -1.0, "z": 1.0}, + ... {"x": 1.0, "y": -2.0, "z": 1.0}, + ... {"x": 2.2, "y": -3.0, "z": 2.2}, ... ] - >>> list(xiter(data).groupby(['x'])) + >>> list(xiter(data).groupby(["x"])) [(1.0, [{'x': 1.0, 'y': -1.0, 'z': 1.0}, {'x': 1.0, 'y': -2.0, 'z': 1.0}]), (2.2, [{'x': 2.2, 'y': -3.0, 'z': 2.2}])] - >>> list(xiter(data).groupby(['x', 'z'])) + >>> list(xiter(data).groupby(["x", "z"])) [((1.0, 1.0), [{'x': 1.0, 'y': -1.0, 'z': 1.0}, {'x': 1.0, 'y': -2.0, 'z': 1.0}]), ((2.2, 2.2), [{'x': 2.2, 'y': -3.0, 'z': 2.2}])] >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y', 'z']) + >>> Point = namedtuple("Point", ["x", "y", "z"]) >>> data = [Point(1.0, -2.0, 1.0), Point(1.0, -2.0, 1.0), Point(2.2, 3.0, 2.0)] - >>> list(xiter(data).groupby('x')) + >>> list(xiter(data).groupby("x")) [(1.0, [Point(x=1.0, y=-2.0, z=1.0), Point(x=1.0, y=-2.0, z=1.0)]), (2.2, [Point(x=2.2, y=3.0, z=2.0)])] - >>> list(xiter(data).groupby('x', 'z')) + >>> list(xiter(data).groupby("x", "z")) [((1.0, 1.0), [Point(x=1.0, y=-2.0, z=1.0), Point(x=1.0, y=-2.0, z=1.0)]), ((2.2, 2.0), [Point(x=2.2, y=3.0, z=2.0)])] - >>> it = xiter(['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']) + >>> it = xiter(["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]) >>> list(it.groupby(len)) [(5, ['Alice', 'Edith', 'Frank']), (3, ['Bob', 'Dan']), (7, ['Charlie'])] @@ -1432,8 +1426,12 @@ def reduce(self, bin_op_func: Callable[[Any, T], Any], *, init: Any = None) -> A >>> it.reduce((lambda accu, i: accu + i), init=0) 10 - >>> it = xiter(['a', 'b', 'c', 'd', 'e']) - >>> sorted(it.reduce((lambda accu, item: (accu or set()) | {item} if item in 'aeiou' else accu))) + >>> it = xiter(["a", "b", "c", "d", "e"]) + >>> sorted( + ... it.reduce( + ... (lambda accu, item: (accu or set()) | {item} if item in "aeiou" else accu) + ... ) + ... ) ['a', 'e'] """ @@ -1558,33 +1556,37 @@ def reduceby( For detailed information check :func:`toolz.itertoolz.reduceby` reference. Examples: - >>> it = xiter([(1.0, -1.0), (1.0,-2.0), (2.2, -3.0)]) + >>> it = xiter([(1.0, -1.0), (1.0, -2.0), (2.2, -3.0)]) >>> list(it.reduceby((lambda accu, _: accu + 1), [0], init=0)) [(1.0, 2), (2.2, 1)] >>> data = [ - ... {'x': 1.0, 'y': -1.0, 'z': 1.0}, - ... {'x': 1.0, 'y': -2.0, 'z': 1.0}, - ... {'x': 2.2, 'y': -3.0, 'z': 2.2} + ... {"x": 1.0, "y": -1.0, "z": 1.0}, + ... {"x": 1.0, "y": -2.0, "z": 1.0}, + ... {"x": 2.2, "y": -3.0, "z": 2.2}, ... ] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ['x'], init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ["x"], init=0)) [(1.0, 2), (2.2, 1)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ['x', 'z'], init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), ["x", "z"], init=0)) [((1.0, 1.0), 2), ((2.2, 2.2), 1)] >>> from collections import namedtuple - >>> Point = namedtuple('Point', ['x', 'y', 'z']) + >>> Point = namedtuple("Point", ["x", "y", "z"]) >>> data = [Point(1.0, -2.0, 1.0), Point(1.0, -2.0, 1.0), Point(2.2, 3.0, 2.0)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), 'x', init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), "x", init=0)) [(1.0, 2), (2.2, 1)] - >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), 'x', 'z', init=0)) + >>> list(xiter(data).reduceby((lambda accu, _: accu + 1), "x", "z", init=0)) [((1.0, 1.0), 2), ((2.2, 2.0), 1)] - >>> it = xiter(['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']) - >>> list(it.reduceby(lambda nvowels, name: nvowels + sum(i in 'aeiou' for i in name), len, init=0)) + >>> it = xiter(["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]) + >>> list( + ... it.reduceby( + ... lambda nvowels, name: nvowels + sum(i in "aeiou" for i in name), len, init=0 + ... ) + ... ) [(5, 4), (3, 2), (7, 3)] - """ # noqa: RST203, RST301 # sphinx.napoleon conventions confuse RST validator + """ # sphinx.napoleon conventions confuse RST validator if (not callable(key) and not isinstance(key, (int, str, list))) or not all( isinstance(i, str) for i in attr_keys ): diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index fe5f9e1474..750505becf 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -14,7 +14,6 @@ """Visitor classes to work with IR trees.""" - from __future__ import annotations import collections.abc diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d7598d7dd8..a556582314 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -167,10 +167,10 @@ def __repr__(self) -> str: def __getitem__(self, index: int) -> int: ... @overload - def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> UnitRange: # redefine unused ... - def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused + def __getitem__(self, index: int | slice) -> int | UnitRange: # redefine unused assert UnitRange.is_finite(self) if isinstance(index, slice): start, stop, step = index.indices(len(self)) @@ -275,9 +275,7 @@ def unit_range(r: RangeLike) -> UnitRange: NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType -NamedSlice: TypeAlias = ( - slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ -) +NamedSlice: TypeAlias = slice # once slice is generic we should do: slice[NamedIndex, NamedIndex, Literal[1]], see https://peps.python.org/pep-0696/ AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange | NamedSlice AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] @@ -431,17 +429,17 @@ def is_empty(self) -> bool: def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload - def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> Self: # redefine unused ... @overload - def __getitem__( # noqa: F811 # redefine unused + def __getitem__( # redefine unused self, index: Dimension ) -> tuple[Dimension, _Rng]: ... - def __getitem__( # noqa: F811 # redefine unused + def __getitem__( # redefine unused self, index: int | slice | Dimension - ) -> NamedRange | Domain: # noqa: F811 # redefine unused + ) -> NamedRange | Domain: # redefine unused if isinstance(index, int): return self.dims[index], self.ranges[index] elif isinstance(index, slice): @@ -1006,11 +1004,11 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) >>> promote_dims([I, J], [I, J, K]) == [I, J, K] True - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index f8e7b9bff8..47f2bc3264 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -262,12 +262,10 @@ def as_field( raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}.") else: origin = {} - actual_domain = common.domain( - [ - (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) - for d, s in zip(domain, data.shape) - ] - ) + actual_domain = common.domain([ + (d, (-(start_offset := origin.get(d, 0)), s - start_offset)) + for d, s in zip(domain, data.shape) + ]) else: if origin: raise ValueError(f"Cannot specify origin for domain {domain}") diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 022899d23e..40624ebd73 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -106,7 +106,9 @@ def domain_intersection( Example: >>> I = common.Dimension("I") - >>> domain_intersection(common.domain({I:(0,5)}), common.domain({I:(1,3)})) # doctest: +ELLIPSIS + >>> domain_intersection( + ... common.domain({I: (0, 5)}), common.domain({I: (1, 3)}) + ... ) # doctest: +ELLIPSIS Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) """ return functools.reduce( @@ -127,23 +129,23 @@ def intersect_domains( Example: >>> I = common.Dimension("I") >>> J = common.Dimension("J") - >>> res = intersect_domains(common.domain({I:(0,5), J:(1,2)}), common.domain({I:(1,3), J:(0,3)}), ignore_dims=J) - >>> assert res == (common.domain({I:(1,3), J:(1,2)}), common.domain({I:(1,3), J:(0,3)})) + >>> res = intersect_domains( + ... common.domain({I: (0, 5), J: (1, 2)}), + ... common.domain({I: (1, 3), J: (0, 3)}), + ... ignore_dims=J, + ... ) + >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)})) """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) - intersection_without_ignore_dims = domain_intersection( - *[ - common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) - for domain in domains - ] - ) + intersection_without_ignore_dims = domain_intersection(*[ + common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple]) + for domain in domains + ]) return tuple( - common.Domain( - *[ - (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) - for d, r in domain - ] - ) + common.Domain(*[ + (d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1]) + for d, r in domain + ]) for domain in domains ) @@ -206,7 +208,12 @@ def _named_slice_to_named_range( ) -> common.NamedRange | common.NamedSlice: assert hasattr(idx, "start") and hasattr(idx, "stop") if common.is_named_slice(idx): - idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = idx.start[0], idx.start[1], idx.stop[0], idx.stop[1] # type: ignore[attr-defined] + idx_start_0, idx_start_1, idx_stop_0, idx_stop_1 = ( + idx.start[0], + idx.start[1], + idx.stop[0], + idx.stop[1], + ) # type: ignore[attr-defined] if idx_start_0 != idx_stop_0: raise IndexError( f"Dimensions slicing mismatch between '{idx_start_0.value}' and '{idx_stop_0.value}'." diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 1cb0cca8b2..6e9e72ffed 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -61,9 +61,9 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = cls_.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = embedded_common.domain_intersection( - *[f.domain for f in fields if common.is_field(f)] - ) + domain_intersection = embedded_common.domain_intersection(*[ + f.domain for f in fields if common.is_field(f) + ]) transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] for f in fields: @@ -497,10 +497,12 @@ def _hypercube( # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func( - fbuiltins.abs, NdArrayField.__abs__ # type: ignore[attr-defined] + fbuiltins.abs, + NdArrayField.__abs__, # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.power, NdArrayField.__pow__ # type: ignore[attr-defined] + fbuiltins.power, + NdArrayField.__pow__, # type: ignore[attr-defined] ) # TODO gamma @@ -514,13 +516,16 @@ def _hypercube( NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name)) NdArrayField.register_builtin_func( - fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined] + fbuiltins.minimum, + _make_builtin("minimum", "minimum"), # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined] + fbuiltins.maximum, + _make_builtin("maximum", "maximum"), # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined] + fbuiltins.fmod, + _make_builtin("fmod", "fmod"), # type: ignore[attr-defined] ) NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 47695d7c4b..328a32396a 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -42,7 +42,9 @@ class ScanOperator(EmbeddedOperator[_R, _P]): init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] axis: common.Dimension - def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + def __call__( + self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar + ) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun scan_range = embedded_context.closure_column_range.get() assert self.axis == scan_range[0] scan_axis = scan_range[0] @@ -50,9 +52,9 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel domain_intersection = _intersect_scan_args(*all_args) non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) - out_domain = common.Domain( - *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] - ) + out_domain = common.Domain(*[ + scan_range if nr[0] == scan_axis else nr for nr in domain_intersection + ]) if scan_axis not in out_domain.dims: # even if the scan dimension is not in the input, we can scan over it out_domain = common.Domain(*out_domain, (scan_range)) @@ -80,11 +82,11 @@ def scan_loop(hpos): def _get_out_domain( - out: common.MutableField | tuple[common.MutableField | tuple, ...] + out: common.MutableField | tuple[common.MutableField | tuple, ...], ) -> common.Domain: - return embedded_common.domain_intersection( - *[f.domain for f in utils.flatten_nested_tuple((out,))] - ) + return embedded_common.domain_intersection(*[ + f.domain for f in utils.flatten_nested_tuple((out,)) + ]) def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): @@ -148,15 +150,15 @@ def impl(target: common.MutableField, source: common.Field): def _intersect_scan_args( - *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> common.Domain: - return embedded_common.domain_intersection( - *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] - ) + return embedded_common.domain_intersection(*[ + arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg) + ]) def _get_array_ns( - *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...], ) -> ModuleType: for arg in utils.flatten_nested_tuple(args): if hasattr(arg, "array_ns"): diff --git a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py index 653456f6c5..afa8b730b1 100644 --- a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py +++ b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py @@ -27,17 +27,15 @@ class RemoveDocstrings(ast.NodeTransformer): >>> def example_docstring(): ... a = 1 ... "This is a docstring" + ... ... def example_docstring_2(): - ... a = 2.0 - ... "This is a new docstring" - ... return a + ... a = 2.0 + ... "This is a new docstring" + ... return a + ... ... a = example_docstring_2() ... return a - >>> print(ast.unparse( - ... RemoveDocstrings.apply( - ... ast.parse(inspect.getsource(example_docstring)) - ... ) - ... )) + >>> print(ast.unparse(RemoveDocstrings.apply(ast.parse(inspect.getsource(example_docstring))))) def example_docstring(): a = 1 diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index 8b079bb8c1..966b234e79 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -61,11 +61,7 @@ class SingleAssignTargetPass(NodeYielder): ... a = b = 1 ... return a, b >>> - >>> print(ast.unparse( - ... SingleAssignTargetPass.apply( - ... ast.parse(inspect.getsource(foo)) - ... ) - ... )) + >>> print(ast.unparse(SingleAssignTargetPass.apply(ast.parse(inspect.getsource(foo))))) def foo(): __sat_tmp0 = 1 a = __sat_tmp0 diff --git a/src/gt4py/next/ffront/ast_passes/single_static_assign.py b/src/gt4py/next/ffront/ast_passes/single_static_assign.py index ee1e29a8e8..02545e360b 100644 --- a/src/gt4py/next/ffront/ast_passes/single_static_assign.py +++ b/src/gt4py/next/ffront/ast_passes/single_static_assign.py @@ -107,11 +107,7 @@ class SingleStaticAssignPass(ast.NodeTransformer): ... a = 3 + a ... return a - >>> print(ast.unparse( - ... SingleStaticAssignPass.apply( - ... ast.parse(inspect.getsource(foo)) - ... ) - ... )) + >>> print(ast.unparse(SingleStaticAssignPass.apply(ast.parse(inspect.getsource(foo))))) def foo(): aᐞ0 = 1 aᐞ1 = 2 + aᐞ0 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 0cf1611bb1..a152439a65 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -514,17 +514,17 @@ def program( Examples: >>> @program # noqa: F821 # doctest: +SKIP - ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 + ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 ... field_op(in_field, out=out_field) - >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP + >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP >>> # the backend can optionally be passed if already decided >>> # not passing it will result in embedded execution by default >>> # the above is equivalent to >>> @program(backend="roundtrip") # noqa: F821 # doctest: +SKIP - ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 + ... def program(in_field: Field[[TDim], float64], out_field: Field[[TDim], float64]): # noqa: F821 ... field_op(in_field, out=out_field) - >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP + >>> program(in_field, out=out_field) # noqa: F821 # doctest: +SKIP """ def program_inner(definition: types.FunctionType) -> Program: @@ -624,7 +624,7 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: def __gt_itir__(self) -> itir.FunctionDefinition: if hasattr(self, "__cached_itir"): - return getattr(self, "__cached_itir") # noqa: B009 + return getattr(self, "__cached_itir") itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply(self.foast_node) @@ -642,9 +642,10 @@ def as_program( # of arg and kwarg types # TODO(tehrengruber): check foast operator has no out argument that clashes # with the out argument of the program we generate here. - hash_ = eve_utils.content_hash( - (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items())) - ) + hash_ = eve_utils.content_hash(( + tuple(arg_types), + tuple((name, arg) for name, arg in kwarg_types.items()), + )) try: return self._program_cache[hash_] except KeyError: @@ -773,14 +774,14 @@ def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): Examples: >>> @field_operator # doctest: +SKIP - ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 + ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 ... ... >>> field_op(in_field, out=out_field) # noqa: F821 # doctest: +SKIP >>> # the backend can optionally be passed if already decided >>> # not passing it will result in embedded execution by default >>> @field_operator(backend="roundtrip") # doctest: +SKIP - ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 + ... def field_op(in_field: Field[[TDim], float64]) -> Field[[TDim], float64]: # noqa: F821 ... ... """ @@ -846,9 +847,9 @@ def scan_operator( >>> KDim = gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL) >>> inp = gtx.as_field([KDim], np.ones((10,))) >>> out = gtx.as_field([KDim], np.zeros((10,))) - >>> @gtx.scan_operator(axis=KDim, forward=True, init=0.) + >>> @gtx.scan_operator(axis=KDim, forward=True, init=0.0) ... def scan_operator(carry: float, val: float) -> float: - ... return carry+val + ... return carry + val >>> scan_operator(inp, out=out, offset_provider={}) # doctest: +SKIP >>> out.array() # doctest: +SKIP array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 3c4b935224..544fd1c054 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple # noqa: A004 +from builtins import bool, float, int, tuple from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np @@ -146,9 +146,7 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple( - where(mask, t, f) for t, f in zip(true_field, false_field) - ) # type: ignore[return-value] # `tuple` is not `_R` + return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) @@ -256,8 +254,8 @@ def astype( def _make_unary_math_builtin(name): def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) - # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: E800 # commented code - # return getattr(math, name)(value)# noqa: E800 # commented code + # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # commented code + # return getattr(math, name)(value)# commented code raise NotImplementedError() impl.__name__ = name diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 322a6df2e0..dde6211076 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -51,8 +51,8 @@ def __str__(self): # class Symbol(eve.GenericNode, LocatedNode, Generic[SymbolT]): # class Symbol(LocatedNode, Generic[SymbolT]): - id: Coerced[SymbolName] # noqa: A003 # shadowing a python builtin - type: Union[SymbolT, ts.DeferredType] # noqa A003 + id: Coerced[SymbolName] # shadowing a python builtin + type: Union[SymbolT, ts.DeferredType] # A003 namespace: dialect_ast_enums.Namespace = dialect_ast_enums.Namespace( dialect_ast_enums.Namespace.LOCAL ) @@ -75,11 +75,11 @@ class Symbol(LocatedNode, Generic[SymbolT]): class Expr(LocatedNode): - type: ts.TypeSpec = ts.DeferredType(constraint=None) # noqa A003 + type: ts.TypeSpec = ts.DeferredType(constraint=None) # A003 class Name(Expr): - id: Coerced[SymbolRef] # noqa: A003 # shadowing a python builtin + id: Coerced[SymbolRef] # shadowing a python builtin class Constant(Expr): @@ -157,7 +157,7 @@ class Stmt(LocatedNode): ... class Starred(Expr): - id: Union[FieldSymbol, TupleSymbol, ScalarSymbol] # noqa: A003 # shadowing a python builtin + id: Union[FieldSymbol, TupleSymbol, ScalarSymbol] # shadowing a python builtin class Assign(Stmt): @@ -198,29 +198,27 @@ def _collect_common_symbols(cls: type[IfStmt], instance: IfStmt) -> None: class FunctionDefinition(LocatedNode, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 # shadowing a python builtin + id: Coerced[SymbolName] # shadowing a python builtin params: list[DataSymbol] body: BlockStmt closure_vars: list[Symbol] - type: Union[ts.FunctionType, ts.DeferredType] = ts.DeferredType( # noqa: A003 - constraint=ts.FunctionType - ) + type: Union[ts.FunctionType, ts.DeferredType] = ts.DeferredType(constraint=ts.FunctionType) class FieldOperator(LocatedNode, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 # shadowing a python builtin + id: Coerced[SymbolName] # shadowing a python builtin definition: FunctionDefinition - type: Union[ts_ffront.FieldOperatorType, ts.DeferredType] = ts.DeferredType( # noqa: A003 + type: Union[ts_ffront.FieldOperatorType, ts.DeferredType] = ts.DeferredType( constraint=ts_ffront.FieldOperatorType ) class ScanOperator(LocatedNode, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 # shadowing a python builtin + id: Coerced[SymbolName] # shadowing a python builtin axis: Constant forward: Constant init: Constant definition: FunctionDefinition # scan pass - type: Union[ts_ffront.ScanOperatorType, ts.DeferredType] = ts.DeferredType( # noqa: A003 + type: Union[ts_ffront.ScanOperatorType, ts.DeferredType] = ts.DeferredType( constraint=ts_ffront.ScanOperatorType ) diff --git a/src/gt4py/next/ffront/foast_introspection.py b/src/gt4py/next/ffront/foast_introspection.py index 404b99d1a0..08efa426ea 100644 --- a/src/gt4py/next/ffront/foast_introspection.py +++ b/src/gt4py/next/ffront/foast_introspection.py @@ -30,23 +30,23 @@ def deduce_stmt_return_kind(node: foast.Stmt) -> StmtReturnKind: Example with ``StmtReturnKind.UNCONDITIONAL_RETURN``:: if cond: - return 1 + return 1 else: - return 2 + return 2 Example with ``StmtReturnKind.CONDITIONAL_RETURN``:: if cond: - return 1 + return 1 else: - result = 2 + result = 2 Example with ``StmtReturnKind.NO_RETURN``:: if cond: - result = 1 + result = 1 else: - result = 2 + result = 2 """ if isinstance(node, foast.IfStmt): return_kinds = ( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6af1570fc9..575b8f9399 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -41,7 +41,9 @@ def with_altered_scalar_kind( >>> print(with_altered_scalar_kind(scalar_t, ts.ScalarKind.BOOL)) bool - >>> field_t = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + >>> field_t = ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ) >>> print(with_altered_scalar_kind(field_t, ts.ScalarKind.FLOAT32)) Field[[I], float32] """ @@ -67,9 +69,14 @@ def construct_tuple_type( Examples: --------- >>> from gt4py.next import Dimension - >>> mask_type = ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)) + >>> mask_type = ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ... ) >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)] - >>> false_branch_types = [ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), ts.ScalarType(kind=ts.ScalarKind)] + >>> false_branch_types = [ + ... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), + ... ts.ScalarType(kind=ts.ScalarKind), + ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ @@ -104,18 +111,20 @@ def promote_to_mask_type( >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) - >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)) + >>> promote_to_mask_type( + ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) + ... ) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) - >>> promote_to_mask_type(ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I,J], dtype=dtype)) + >>> promote_to_mask_type( + ... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype) + ... ) FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) """ if isinstance(input_type, ts.ScalarType) or not all( item in input_type.dims for item in mask_type.dims ): return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type - return type_info.promote( - input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype) - ) # type: ignore + return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore else: return input_type @@ -233,8 +242,9 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla DeferredType(constraint=None) >>> typed_fieldop = FieldOperatorTypeDeduction.apply(untyped_fieldop) - >>> assert typed_fieldop.body.stmts[0].value.type == ts.FieldType(dtype=ts.ScalarType( - ... kind=ts.ScalarKind.FLOAT64), dims=[IDim]) + >>> assert typed_fieldop.body.stmts[0].value.type == ts.FieldType( + ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), dims=[IDim] + ... ) """ @classmethod @@ -797,7 +807,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) raise errors.DSLError( node.location, - f"Incompatible field argument in call to '{str(node.func)}'. " + f"Incompatible field argument in call to '{node.func!s}'. " f"Expected a field with dimension '{reduction_dim}', got " f"'{field_dims_str}'.", ) @@ -861,7 +871,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_integral(arg_1): raise errors.DSLError( node.location, - f"Incompatible argument in call to '{str(node.func)}': " + f"Incompatible argument in call to '{node.func!s}': " f"expected integer for offset field dtype, got '{arg_1.dtype}'. " f"{node.location}", ) @@ -869,7 +879,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: if arg_0.source not in arg_1.dims: raise errors.DSLError( node.location, - f"Incompatible argument in call to '{str(node.func)}': " + f"Incompatible argument in call to '{node.func!s}': " f"'{arg_0.source}' not in list of offset field dimensions '{arg_1.dims}'. " f"{node.location}", ) @@ -890,7 +900,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_logical(mask_type): raise errors.DSLError( node.location, - f"Incompatible argument in call to '{str(node.func)}': expected " + f"Incompatible argument in call to '{node.func!s}': expected " f"a field with dtype 'bool', got '{mask_type}'.", ) @@ -908,7 +918,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: ): raise errors.DSLError( node.location, - f"Return arguments need to be of same type in '{str(node.func)}', got " + f"Return arguments need to be of same type in '{node.func!s}', got " f"'{node.args[1].type}' and '{node.args[2].type}'.", ) else: @@ -920,7 +930,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError( node.location, - f"Incompatible argument in call to '{str(node.func)}'.", + f"Incompatible argument in call to '{node.func!s}'.", ) from ex return foast.Call( @@ -940,7 +950,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): raise errors.DSLError( node.location, - f"Incompatible broadcast dimension type in '{str(node.func)}': expected " + f"Incompatible broadcast dimension type in '{node.func!s}': expected " f"all broadcast dimensions to be of type 'Dimension'.", ) @@ -949,7 +959,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): raise errors.DSLError( node.location, - f"Incompatible broadcast dimensions in '{str(node.func)}': expected " + f"Incompatible broadcast dimensions in '{node.func!s}': expected " f"broadcast dimension(s) '{set(arg_dims).difference(set(broadcast_dims))}' missing", ) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index dfaefb7211..a5de72093b 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -54,7 +54,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> >>> IDim = Dimension("IDim") >>> def fieldop(inp: Field[[IDim], "float64"]): - ... return inp + ... return inp >>> >>> parsed = FieldOperatorParser.apply_to_function(fieldop) >>> lowered = FieldOperatorLowering.apply(parsed) @@ -62,7 +62,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS + >>> lowered.params # doctest: +ELLIPSIS [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 0831fc3bb2..1ff3acc205 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -66,14 +66,16 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): If a syntax error is encountered, it will point to the location in the source code. >>> def wrong_syntax(inp: Field[[IDim], int]): - ... for i in [1, 2, 3]: # for is not part of the field operator syntax + ... for i in [1, 2, 3]: # for is not part of the field operator syntax ... tmp = inp ... return tmp >>> - >>> try: # doctest: +ELLIPSIS + >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) ... except errors.DSLError as err: - ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") + ... print( + ... f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})" + ... ) Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 2fffefe658..61b957d5f7 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -26,8 +26,14 @@ def to_tuples_of_iterator(expr: itir.Expr | str, arg_type: ts.TypeSpec): Supports arbitrary nesting. - >>> print(to_tuples_of_iterator("arg", ts.TupleType(types=[ts.FieldType(dims=[], - ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + >>> print( + ... to_tuples_of_iterator( + ... "arg", + ... ts.TupleType( + ... types=[ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))] + ... ), + ... ) + ... ) # doctest: +ELLIPSIS (λ(__toi_...) → {(↑(λ(it) → (·it)[0]))(__toi_...)})(arg) """ param = f"__toi_{eve_utils.content_hash(expr)}" @@ -52,8 +58,14 @@ def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec): Supports arbitrary nesting. - >>> print(to_iterator_of_tuples("arg", ts.TupleType(types=[ts.FieldType(dims=[], - ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + >>> print( + ... to_iterator_of_tuples( + ... "arg", + ... ts.TupleType( + ... types=[ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))] + ... ), + ... ) + ... ) # doctest: +ELLIPSIS (λ(__iot_...) → (↑(λ(__iot_el_0) → {·__iot_el_0}))(__iot_...[0]))(arg) """ param = f"__iot_{eve_utils.content_hash(expr)}" @@ -62,7 +74,10 @@ def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec): ti_ffront.promote_scalars_to_zero_dim_field(type_) for type_ in type_info.primitive_constituents(arg_type) ] - assert all(isinstance(type_, ts.FieldType) and type_.dims == type_constituents[0].dims for type_ in type_constituents) # type: ignore[attr-defined] # ensure by assert above + assert all( + isinstance(type_, ts.FieldType) and type_.dims == type_constituents[0].dims + for type_ in type_constituents + ) # type: ignore[attr-defined] # ensure by assert above def fun(_, path): param_name = "__iot_el" @@ -124,16 +139,14 @@ def _process_elements_impl( current_el_type: ts.TypeSpec, ): if isinstance(current_el_type, ts.TupleType): - result = im.make_tuple( - *[ - _process_elements_impl( - process_func, - [im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs], - current_el_type.types[i], - ) - for i in range(len(current_el_type.types)) - ] - ) + result = im.make_tuple(*[ + _process_elements_impl( + process_func, + [im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs], + current_el_type.types[i], + ) + for i in range(len(current_el_type.types)) + ]) elif type_info.contains_local_field(current_el_type): raise NotImplementedError("Processing fields with local dimension is not implemented.") else: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index e092538c8f..99534b4e61 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -56,19 +56,19 @@ class ProgramLowering( >>> float64 = float >>> IDim = Dimension("IDim") >>> - >>> def fieldop(inp: Field[[IDim], "float64"]) -> Field[[IDim], "float64"]: - ... ... + >>> def fieldop(inp: Field[[IDim], "float64"]) -> Field[[IDim], "float64"]: ... >>> def program(inp: Field[[IDim], "float64"], out: Field[[IDim], "float64"]): - ... fieldop(inp, out=out) + ... fieldop(inp, out=out) >>> >>> parsed = ProgramParser.apply_to_function(program) # doctest: +SKIP >>> fieldop_def = ir.FunctionDefinition( ... id="fieldop", ... params=[ir.Sym(id="inp")], - ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]) + ... expr=ir.FunCall(fun=ir.SymRef(id="deref"), pos_only_args=[ir.SymRef(id="inp")]), + ... ) # doctest: +SKIP + >>> lowered = ProgramLowering.apply( + ... parsed, [fieldop_def], grid_type=GridType.CARTESIAN ... ) # doctest: +SKIP - >>> lowered = ProgramLowering.apply(parsed, [fieldop_def], - ... grid_type=GridType.CARTESIAN) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP >>> lowered.id # doctest: +SKIP @@ -231,7 +231,6 @@ def _construct_itir_domain_arg( node_domain: Optional[past.Expr], slices: Optional[list[past.Slice]] = None, ) -> itir.FunCall: - assert isinstance(out_field.type, ts.TypeSpec) out_field_types = type_info.primitive_constituents(out_field.type).to_list() out_dims = cast(ts.FieldType, out_field_types[0]).dims diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index 4ff8265f70..0b72f4c3a5 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -29,8 +29,8 @@ class LocatedNode(Node): class Symbol(eve.GenericNode, LocatedNode, Generic[SymbolT]): - id: Coerced[SymbolName] # noqa: A003 - type: Union[SymbolT, ts.DeferredType] # noqa A003 + id: Coerced[SymbolName] + type: Union[SymbolT, ts.DeferredType] # A003 namespace: dialect_ast_enums.Namespace = dialect_ast_enums.Namespace( dialect_ast_enums.Namespace.LOCAL ) @@ -50,7 +50,7 @@ class Symbol(eve.GenericNode, LocatedNode, Generic[SymbolT]): class Expr(LocatedNode): - type: Optional[ts.TypeSpec] = None # noqa A003 + type: Optional[ts.TypeSpec] = None # A003 class BinOp(Expr): @@ -60,7 +60,7 @@ class BinOp(Expr): class Name(Expr): - id: Coerced[SymbolRef] # noqa: A003 + id: Coerced[SymbolRef] class Call(Expr): @@ -97,8 +97,8 @@ class Stmt(LocatedNode): ... class Program(LocatedNode, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 - type: Union[ts_ffront.ProgramType, ts.DeferredType] # noqa A003 + id: Coerced[SymbolName] + type: Union[ts_ffront.ProgramType, ts.DeferredType] # A003 params: list[DataSymbol] body: list[Call] closure_vars: list[Symbol] diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index baf3037d5e..57f4af8d7c 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -106,11 +106,11 @@ class SourceDefinition: >>> def foo(a): ... return a >>> src_def = SourceDefinition.from_function(foo) - >>> print(src_def) # doctest:+ELLIPSIS + >>> print(src_def) # doctest:+ELLIPSIS SourceDefinition(source='def foo(a):...', filename='...', line_offset=0, column_offset=0) >>> source, filename, starting_line = src_def - >>> print(source) # doctest:+ELLIPSIS + >>> print(source) # doctest:+ELLIPSIS def foo(a): return a ... @@ -139,10 +139,10 @@ class SymbolNames: """ params: set[str] - locals: set[str] # noqa: A003 # shadowing a python builtin + locals: set[str] # shadowing a python builtin imported: set[str] nonlocals: set[str] - globals: set[str] # noqa: A003 # shadowing a python builtin + globals: set[str] # shadowing a python builtin @functools.cached_property def all_locals(self) -> set[str]: diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index c25b7dd829..2072c4164a 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -171,7 +171,7 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType -------- >>> _scan_param_promotion( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64) + ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), ... ) FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ @@ -180,7 +180,9 @@ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: assert isinstance(dtype, ts.ScalarType) try: el_type = reduce( - lambda type_, idx: type_.types[idx], path, arg # type: ignore[attr-defined] + lambda type_, idx: type_.types[idx], + path, + arg, # type: ignore[attr-defined] ) return ts.FieldType(dims=type_info.extract_dims(el_type), dtype=dtype) except (IndexError, AttributeError): diff --git a/src/gt4py/next/ffront/type_translation.py b/src/gt4py/next/ffront/type_translation.py index dc3b56176e..4972df357a 100644 --- a/src/gt4py/next/ffront/type_translation.py +++ b/src/gt4py/next/ffront/type_translation.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.type_system.type_translation import ( # noqa: F401 +from gt4py.next.type_system.type_translation import ( from_type_hint as from_type_hint, from_value as from_value, get_scalar_kind as get_scalar_kind, diff --git a/src/gt4py/next/iterator/atlas_utils.py b/src/gt4py/next/iterator/atlas_utils.py index a23a4b5148..50605c1ef1 100644 --- a/src/gt4py/next/iterator/atlas_utils.py +++ b/src/gt4py/next/iterator/atlas_utils.py @@ -47,7 +47,7 @@ def dtype(self): def shape(self): return (self.atlas_connectivity.rows, self.atlas_connectivity.maxcols) - def max(self): # noqa: A003 + def max(self): maximum = -1 for i in range(self.shape[0]): for j in range(self.shape[1]): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 80fc539283..409773de26 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -573,7 +573,7 @@ def execute_shift( def _is_list_of_complete_offsets( - complete_offsets: list[tuple[Any, Any]] + complete_offsets: list[tuple[Any, Any]], ) -> TypeGuard[list[CompleteOffset]]: return all( isinstance(tag, Tag) and isinstance(offset, (int, np.integer)) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 10caecc591..ce45af0870 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -41,7 +41,7 @@ def __hash__(self) -> int: class Sym(Node): # helper - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] # TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the # type inference. kind: typing.Literal["Iterator", "Value", None] = None @@ -68,7 +68,7 @@ class Expr(Node): ... class Literal(Expr): value: str - type: str # noqa: A003 + type: str @datamodels.validator("type") def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): @@ -89,7 +89,7 @@ class AxisLiteral(Expr): class SymRef(Expr): - id: Coerced[SymbolRef] # noqa: A003 + id: Coerced[SymbolRef] class Lambda(Expr, SymbolTableTrait): @@ -103,7 +103,7 @@ class FunCall(Expr): class FunctionDefinition(Node, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] params: List[Sym] expr: Expr @@ -215,7 +215,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] function_definitions: List[FunctionDefinition] params: List[Sym] closures: List[StencilClosure] diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f6655e9d41..44407164db 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -243,9 +243,7 @@ class let: -------- >>> str(let("a", "b")("a")) # doctest: +ELLIPSIS '(λ(a) → a)(b)' - >>> str(let(("a", 1), - ... ("b", 2) - ... )(plus("a", "b"))) + >>> str(let(("a", 1), ("b", 2))(plus("a", "b"))) '(λ(a, b) → a + b)(1, 2)' """ @@ -301,7 +299,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: """ Make a literal node from a value. - >>> literal_from_value(1.) + >>> literal_from_value(1.0) Literal(value='1.0', type='float64') >>> literal_from_value(1) Literal(value='1', type='int32') diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 786b91bcc5..23f6620f60 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -16,6 +16,7 @@ Inspired by P. Yelland, “A New Approach to Optimal Code Formatting”, 2015 """ + # TODO(tehrengruber): add support for printing the types of itir.Sym, itir.Literal nodes from __future__ import annotations diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 51daffed05..84e4c5562d 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -110,7 +110,7 @@ class Flag(enum.Flag): INLINE_TRIVIAL_LET = enum.auto() @classmethod - def all(self): # noqa: A003 # shadowing a python builtin + def all(self): # shadowing a python builtin return functools.reduce(operator.or_, self.__members__.values()) ignore_tuple_size: bool diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index f9cf272c45..32714232a6 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -244,23 +244,28 @@ def extract_subexpression( >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 >>> new_expr, extracted_subexprs, _ = extract_subexpression( - ... expr, predicate, UIDGenerator(prefix="_subexpr")) + ... expr, predicate, UIDGenerator(prefix="_subexpr") + ... ) >>> print(new_expr) _subexpr_1 + (_subexpr_1 + z) >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y` The order of the extraction can be configured using `deepest_expr_first`. By default, the nodes closer to the root are eliminated first: - >>> expr = im.plus(im.plus(im.plus("x", "y"), im.plus("x", "y")), im.plus(im.plus("x", "y"), im.plus("x", "y"))) - >>> new_expr, extracted_subexprs, ignored_children = extract_subexpression(expr, predicate, - ... UIDGenerator(prefix="_subexpr"), deepest_expr_first=False) + >>> expr = im.plus( + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... ) + >>> new_expr, extracted_subexprs, ignored_children = extract_subexpression( + ... expr, predicate, UIDGenerator(prefix="_subexpr"), deepest_expr_first=False + ... ) >>> print(new_expr) _subexpr_1 + _subexpr_1 >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y + (x + y)` Since `(x+y)` is a child of one of the expressions it is ignored: @@ -270,13 +275,21 @@ def extract_subexpression( Setting `deepest_expr_first` will extract nodes deeper in the tree first: - >>> expr = im.plus(im.plus(im.plus("x", "y"), im.plus("x", "y")), im.plus(im.plus("x", "y"), im.plus("x", "y"))) - >>> new_expr, extracted_subexprs, _ = extract_subexpression(expr, predicate, - ... UIDGenerator(prefix="_subexpr"), once_only=True, deepest_expr_first=True) + >>> expr = im.plus( + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... im.plus(im.plus("x", "y"), im.plus("x", "y")), + ... ) + >>> new_expr, extracted_subexprs, _ = extract_subexpression( + ... expr, + ... predicate, + ... UIDGenerator(prefix="_subexpr"), + ... once_only=True, + ... deepest_expr_first=True, + ... ) >>> print(new_expr) _subexpr_1 + _subexpr_1 + (_subexpr_1 + _subexpr_1) >>> for sym, subexpr in extracted_subexprs.items(): - ... print(f"`{sym}`: `{subexpr}`") + ... print(f"`{sym}`: `{subexpr}`") `_subexpr_1`: `x + y` Note that this requires `once_only` to be set right now. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 4f4fd053b2..d42ba82574 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -57,7 +57,7 @@ class Temporary(ir.Node): """Iterator IR extension: declaration of a temporary buffer.""" - id: Coerced[eve.SymbolName] # noqa: A003 + id: Coerced[eve.SymbolName] domain: Optional[ir.Expr] = None dtype: Optional[Any] = None @@ -318,7 +318,9 @@ def always_extract_heuristics(_): domain=AUTO_DOMAIN, stencil=stencil, output=im.ref(tmp_sym.id), - inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + inputs=[ + closure_param_arg_mapping[param.id] for param in lift_expr.args + ], # type: ignore[attr-defined] location=current_closure.location, ) ) @@ -452,12 +454,10 @@ def from_expr(cls, node: ir.Node): return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above def as_expr(self): - return im.call(self.grid_type)( - *[ - im.call("named_range")(ir.AxisLiteral(value=d), r.start, r.stop) - for d, r in self.ranges.items() - ] - ) + return im.call(self.grid_type)(*[ + im.call("named_range")(ir.AxisLiteral(value=d), r.start, r.stop) + for d, r in self.ranges.items() + ]) def domain_union(domains: list[SymbolicDomain]) -> SymbolicDomain: diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 0b89fe6d98..f6beec3571 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -24,7 +24,7 @@ # TODO(tehrengruber): Reduce complexity of the function by removing the different options here # and introduce a generic predicate argument for the `eligible_params` instead. -def inline_lambda( # noqa: C901 # see todo above +def inline_lambda( # see todo above node: ir.FunCall, opcount_preserving=False, force_inline_lift_args=False, @@ -42,7 +42,7 @@ def inline_lambda( # noqa: C901 # see todo above for i, param in enumerate(node.fun.params): # TODO(tehrengruber): allow inlining more complicated zero-op expressions like - # ignore_shift(...)(it_sym) # noqa: E800 + # ignore_shift(...)(it_sym) if ref_counts[param.id] > 1 and not isinstance( node.args[i], (ir.SymRef, ir.Literal, ir.OffsetLiteral) ): @@ -83,7 +83,7 @@ def inline_lambda( # noqa: C901 # see todo above # TODO(tehrengruber): find a better way of generating new symbols # in `name_map` that don't collide with each other. E.g. this # must still work: - # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) # noqa: E800 + # (lambda arg, arg_: (lambda arg_: ...)(arg))(a, b) name_map: dict[ir.SymRef, str] = {} def new_name(name): diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 683a57561c..a189538ca1 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -618,24 +618,23 @@ def visit_Sym(self, node: ir.Sym, **kwargs) -> Type: result = TypeVar.fresh() if node.kind: kind = {"Iterator": Iterator(), "Value": Value()}[node.kind] - self.constraints.add( - (Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), result) - ) + self.constraints.add(( + Val(kind=kind, current_loc=TypeVar.fresh(), defined_loc=TypeVar.fresh()), + result, + )) if node.dtype: assert node.dtype is not None dtype: Primitive | List = Primitive(name=node.dtype[0]) if node.dtype[1]: dtype = List(dtype=dtype) - self.constraints.add( - ( - Val( - dtype=dtype, - current_loc=TypeVar.fresh(), - defined_loc=TypeVar.fresh(), - ), - result, - ) - ) + self.constraints.add(( + Val( + dtype=dtype, + current_loc=TypeVar.fresh(), + defined_loc=TypeVar.fresh(), + ), + result, + )) return result def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: @@ -751,18 +750,16 @@ def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: dtype_ = TypeVar.fresh() size = TypeVar.fresh() it = self.visit(node.args[1], **kwargs) - self.constraints.add( - ( - it, - Val( - kind=Iterator(), - dtype=dtype_, - size=size, - current_loc=current_loc_in, - defined_loc=current_loc_out, - ), - ) - ) + self.constraints.add(( + it, + Val( + kind=Iterator(), + dtype=dtype_, + size=size, + current_loc=current_loc_in, + defined_loc=current_loc_out, + ), + )) lst = List( dtype=dtype_, max_length=max_length, @@ -780,16 +777,14 @@ def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: size = TypeVar.fresh() - self.constraints.add( - ( - val_arg_type, - Val( - kind=Value(), - dtype=TypeVar.fresh(), - size=size, - ), - ) - ) + self.constraints.add(( + val_arg_type, + Val( + kind=Value(), + dtype=TypeVar.fresh(), + size=size, + ), + )) return Val( kind=Value(), @@ -832,12 +827,10 @@ def _visit_shift(self, node: ir.FunCall, **kwargs) -> Type: def _visit_domain(self, node: ir.FunCall, **kwargs) -> Type: for arg in node.args: - self.constraints.add( - ( - Val(kind=Value(), dtype=NAMED_RANGE_DTYPE, size=Scalar()), - self.visit(arg, **kwargs), - ) - ) + self.constraints.add(( + Val(kind=Value(), dtype=NAMED_RANGE_DTYPE, size=Scalar()), + self.visit(arg, **kwargs), + )) return Val(kind=Value(), dtype=DOMAIN_DTYPE, size=Scalar()) def _visit_cartesian_domain(self, node: ir.FunCall, **kwargs) -> Type: @@ -891,50 +884,45 @@ def visit_StencilClosure( output = self.visit(node.output, **kwargs) output_dtype = TypeVar.fresh() output_loc = TypeVar.fresh() - self.constraints.add( - (domain, Val(kind=Value(), dtype=Primitive(name="domain"), size=Scalar())) - ) - self.constraints.add( - ( - output, - Val( - kind=Iterator(), - dtype=output_dtype, - size=Column(), - defined_loc=output_loc, - ), - ) - ) + self.constraints.add(( + domain, + Val(kind=Value(), dtype=Primitive(name="domain"), size=Scalar()), + )) + self.constraints.add(( + output, + Val( + kind=Iterator(), + dtype=output_dtype, + size=Column(), + defined_loc=output_loc, + ), + )) inputs: list[Type] = self.visit(node.inputs, **kwargs) stencil_params = [] for input_ in inputs: stencil_param = Val(current_loc=output_loc, defined_loc=TypeVar.fresh()) - self.constraints.add( - ( - input_, - Val( - kind=stencil_param.kind, - dtype=stencil_param.dtype, - size=stencil_param.size, - # closure input and stencil param differ in `current_loc` - current_loc=ANYWHERE, - # TODO(tehrengruber): Seems to break for scalars. Use `TypeVar.fresh()`? - defined_loc=stencil_param.defined_loc, - ), - ) - ) + self.constraints.add(( + input_, + Val( + kind=stencil_param.kind, + dtype=stencil_param.dtype, + size=stencil_param.size, + # closure input and stencil param differ in `current_loc` + current_loc=ANYWHERE, + # TODO(tehrengruber): Seems to break for scalars. Use `TypeVar.fresh()`? + defined_loc=stencil_param.defined_loc, + ), + )) stencil_params.append(stencil_param) - self.constraints.add( - ( - stencil, - FunctionType( - args=Tuple.from_elems(*stencil_params), - ret=Val(kind=Value(), dtype=output_dtype, size=Column()), - ), - ) - ) + self.constraints.add(( + stencil, + FunctionType( + args=Tuple.from_elems(*stencil_params), + ret=Val(kind=Value(), dtype=output_dtype, size=Column()), + ), + )) return Closure(output=output, inputs=Tuple.from_elems(*inputs)) def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): @@ -1005,9 +993,10 @@ def infer_all( ) if reindex: - unified_types, unsatisfiable_constraints = reindex_vars( - (unified_types, unsatisfiable_constraints) - ) + unified_types, unsatisfiable_constraints = reindex_vars(( + unified_types, + unsatisfiable_constraints, + )) result = { id_: unified_type diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index bfb3b0d474..42988ad8db 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -14,7 +14,6 @@ """Python bindings generator for C++ functions.""" - from __future__ import annotations from typing import Any, Sequence, TypeVar, Union diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 2c0511ebf4..2c60b32ee3 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -119,13 +119,11 @@ def visit_LinkDependency(self, dep: LinkDependency): cfg = "" if dep.name == "nanobind": - cfg = "\n".join( - [ - "nanobind_build_library(nanobind-static)", - f"nanobind_compile_options({dep.target})", - f"nanobind_link_options({dep.target})", - ] - ) + cfg = "\n".join([ + "nanobind_build_library(nanobind-static)", + f"nanobind_compile_options({dep.target})", + f"nanobind_link_options({dep.target})", + ]) lnk = f"target_link_libraries({dep.target} PUBLIC {lib_name})" return cfg + "\n" + lnk diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 810952d0ef..638648f7b3 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -14,7 +14,6 @@ """Caching for compiled backend artifacts.""" - import hashlib import pathlib import tempfile diff --git a/src/gt4py/next/otf/compilation/common.py b/src/gt4py/next/otf/compilation/common.py index 37950f8186..784295a55e 100644 --- a/src/gt4py/next/otf/compilation/common.py +++ b/src/gt4py/next/otf/compilation/common.py @@ -14,7 +14,6 @@ """Shared build system functionality.""" - from __future__ import annotations import importlib diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index bd7f59e7aa..e9c7f49c26 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -122,8 +122,14 @@ def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryD Examples: --------- - >>> libs_a = (interface.LibraryDependency("foo", "1.2.3"), interface.LibraryDependency("common", "1.0.0")) - >>> libs_b = (interface.LibraryDependency("common", "1.0.0"), interface.LibraryDependency("bar", "1.2.3")) + >>> libs_a = ( + ... interface.LibraryDependency("foo", "1.2.3"), + ... interface.LibraryDependency("common", "1.0.0"), + ... ) + >>> libs_b = ( + ... interface.LibraryDependency("common", "1.0.0"), + ... interface.LibraryDependency("bar", "1.2.3"), + ... ) >>> _unique_libs(*libs_a, *libs_b) (LibraryDependency(name='foo', version='1.2.3'), LibraryDependency(name='common', version='1.0.0'), LibraryDependency(name='bar', version='1.2.3')) """ diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 4bdb4bbb41..2ea2fe9254 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -39,10 +39,10 @@ def make_step(function: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT --------- >>> @make_step ... def times_two(x: int) -> int: - ... return x * 2 + ... return x * 2 >>> def stringify(x: int) -> str: - ... return str(x) + ... return str(x) >>> # create a workflow int -> int -> str >>> times_two.chain(stringify)(3) @@ -102,25 +102,21 @@ class NamedStepSequence( >>> import dataclasses >>> def parse(x: str) -> int: - ... return int(x) + ... return int(x) >>> def plus_half(x: int) -> float: - ... return x + 0.5 + ... return x + 0.5 >>> def stringify(x: float) -> str: - ... return str(x) + ... return str(x) >>> @dataclasses.dataclass(frozen=True) ... class ParseOpPrint(NamedStepSequence[str, str]): - ... parse: Workflow[str, int] - ... op: Workflow[int, float] - ... print: Workflow[float, str] + ... parse: Workflow[str, int] + ... op: Workflow[int, float] + ... print: Workflow[float, str] - >>> pop = ParseOpPrint( - ... parse=parse, - ... op=plus_half, - ... print=stringify - ... ) + >>> pop = ParseOpPrint(parse=parse, op=plus_half, print=stringify) >>> pop.step_order ['parse', 'op', 'print'] @@ -129,7 +125,7 @@ class NamedStepSequence( '73.5' >>> def plus_tenth(x: int) -> float: - ... return x + 0.1 + ... return x + 0.1 >>> pop.replace(op=plus_tenth)(73) @@ -169,13 +165,13 @@ class StepSequence(ChainableWorkflowMixin[StartT, EndT]): Examples: --------- >>> def plus_one(x: int) -> int: - ... return x + 1 + ... return x + 1 >>> def plus_half(x: int) -> float: - ... return x + 0.5 + ... return x + 0.5 >>> def stringify(x: float) -> str: - ... return str(x) + ... return str(x) >>> StepSequence.start(plus_one).chain(plus_half).chain(stringify)(73) '74.5' @@ -222,8 +218,8 @@ class CachedStep( Examples: --------- >>> def heavy_computation(x: int) -> int: - ... print("This might take a while...") - ... return x + ... print("This might take a while...") + ... return x >>> cached_step = CachedStep(step=heavy_computation) @@ -241,9 +237,7 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field( - default=hash - ) # type: ignore[assignment] + hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py index a62f50fc44..9cdcc381d9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py @@ -50,6 +50,6 @@ class ReturnStmt(Stmt): class ImperativeFunctionDefinition(Node, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] params: List[Sym] fun: List[Stmt] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 4bb54aad49..0b465047cb 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -48,7 +48,7 @@ class CastExpr(Expr): class Literal(Expr): value: str - type: str # noqa: A003 + type: str class IntegralConstant(Expr): @@ -70,13 +70,13 @@ class FunCall(Expr): class FunctionDefinition(Node, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] params: list[Sym] expr: Expr class ScanPassDefinition(Node, SymbolTableTrait): - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] params: list[Sym] expr: Expr forward: bool @@ -128,7 +128,7 @@ class ScanExecution(Node): class TemporaryAllocation(Node): - id: SymbolName # noqa: A003 + id: SymbolName dtype: str domain: Union[SymRef, CartesianDomain, UnstructuredDomain] @@ -160,7 +160,7 @@ class TagDefinition(Node): class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: SymbolName # noqa: A003 + id: SymbolName params: list[Sym] function_definitions: list[ Union[FunctionDefinition, ScanPassDefinition, ImperativeFunctionDefinition] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py index cb9aeffb90..c70ce00dba 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py @@ -22,11 +22,11 @@ class Node(eve.Node): class Sym(Node): # helper - id: Coerced[SymbolName] # noqa: A003 + id: Coerced[SymbolName] class Expr(Node): ... class SymRef(Expr): - id: Coerced[SymbolRef] # noqa: A003 + id: Coerced[SymbolRef] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index 74fbbfc93f..361cc3c9d6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -276,9 +276,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): args = [self.visit(arg, **kwargs) for arg in node.args] for param, arg in zip(params, args): if param.id in self.sym_table: - kwargs["localized_symbols"][ - param.id - ] = f"{param.id}_{self.uids.sequential_id()}_local" + kwargs["localized_symbols"][param.id] = ( + f"{param.id}_{self.uids.sequential_id()}_local" + ) self.imp_list_ir.append( InitStmt( lhs=gtfn_ir_common.Sym(id=kwargs["localized_symbols"][param.id]), diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 46861197fe..4c3af5129b 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -192,8 +192,8 @@ def _preprocess_program( lift_mode = runtime_lift_mode or self.lift_mode if runtime_lift_mode and runtime_lift_mode != self.lift_mode: warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - f"overriden to be {str(runtime_lift_mode)} at runtime." + f"GTFN Backend was configured for LiftMode `{self.lift_mode!s}`, but " + f"overriden to be {runtime_lift_mode!s} at runtime." ) if not self.enable_itir_transforms: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 842080f8ae..98eff62d60 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -489,7 +489,7 @@ def visit_StencilClosure( @staticmethod def _merge_scans( - executions: list[Union[StencilExecution, ScanExecution]] + executions: list[Union[StencilExecution, ScanExecution]], ) -> list[Union[StencilExecution, ScanExecution]]: def merge(a: ScanExecution, b: ScanExecution) -> ScanExecution: assert a.backend == b.backend diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 0c280202b8..37921cdd35 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -24,6 +24,7 @@ For more information refer to ``gt4py/docs/functional/architecture/007-Program-Processors.md`` """ + from __future__ import annotations import functools @@ -237,7 +238,8 @@ class ProgramBackend( def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: return is_processor_kind( - obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + obj, + ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory(obj) @@ -245,5 +247,6 @@ def is_program_backend_for( obj: Callable, device: core_defs.DeviceTypeT ) -> TypeGuard[ProgramBackend[core_defs.DeviceTypeT]]: return is_processor_kind( - obj, ProgramExecutor # type: ignore[type-abstract] # ProgramExecutor is abstract + obj, + ProgramExecutor, # type: ignore[type-abstract] # ProgramExecutor is abstract ) and next_allocators.is_field_allocator_factory_for(obj, device) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 1263cff502..75b4b3eda8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -218,7 +218,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) -> Args: sdfg: The SDFG for which we want to get the arguments. - """ # noqa: D401 + """ offset_provider = kwargs["offset_provider"] on_gpu = kwargs.get("on_gpu", False) @@ -298,10 +298,13 @@ def build_sdfg_from_itir( for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: - _, frameinfo = warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg." - ), getframeinfo( - currentframe() # type: ignore + _, frameinfo = ( + warnings.warn( + f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + ), + getframeinfo( + currentframe() # type: ignore + ), ) nested_sdfg.debuginfo = dace.dtypes.DebugInfo( start_line=frameinfo.lineno, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6ab371bf2b..d67d014b9d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -1158,9 +1158,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: map_ranges = { index_name: f"0:{offset_provider.max_neighbors}", } - src_subset = ",".join( - [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] - ) + src_subset = ",".join([ + f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims + ]) self.context.state.add_mapped_tasklet( "deref", map_ranges, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 0c3fd741d5..567b8b9356 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -180,7 +180,7 @@ def add_mapped_nested_sdfg( def unique_name(prefix): - unique_id = getattr(unique_name, "_unique_id", 0) # noqa: B010 # static variable + unique_id = getattr(unique_name, "_unique_id", 0) # static variable setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 # static variable return f"{prefix}_{unique_id}" @@ -198,7 +198,7 @@ def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dac def flatten_list(node_list: list[Any]) -> list[Any]: return list( - itertools.chain.from_iterable( - [flatten_list(e) if e.__class__ == list else [e] for e in node_list] - ) + itertools.chain.from_iterable([ + flatten_list(e) if e.__class__ == list else [e] for e in node_list + ]) ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 4a65f6d049..e490e83ddb 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -103,16 +103,14 @@ def extract_connectivity_args( def compilation_hash(otf_closure: stages.ProgramCall) -> int: """Given closure compute a hash uniquely determining if we need to recompile.""" offset_provider = otf_closure.kwargs["offset_provider"] - return hash( - ( - otf_closure.program, - # As the frontend types contain lists they are not hashable. As a workaround we just - # use content_hash here. - content_hash(tuple(from_value(arg) for arg in otf_closure.args)), - id(offset_provider) if offset_provider else None, - otf_closure.kwargs.get("column_axis", None), - ) - ) + return hash(( + otf_closure.program, + # As the frontend types contain lists they are not hashable. As a workaround we just + # use content_hash here. + content_hash(tuple(from_value(arg) for arg in otf_closure.args)), + id(offset_provider) if offset_provider else None, + otf_closure.kwargs.get("column_axis", None), + )) class GTFNCompileWorkflowFactory(factory.Factory): diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5cfb901ff1..fd11a421c0 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -29,7 +29,7 @@ def _number_to_ordinal_number(number: int) -> str: Convert number into ordinal number. >>> for i in range(0, 5): - ... print(_number_to_ordinal_number(i)) + ... print(_number_to_ordinal_number(i)) 0th 1st 2nd @@ -150,22 +150,24 @@ def apply_to_primitive_constituents( >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> tuple_type = ts.TupleType(types=[int_type, int_type]) - >>> print(apply_to_primitive_constituents(tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type))) + >>> print( + ... apply_to_primitive_constituents( + ... tuple_type, lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type) + ... ) + ... ) tuple[Field[[], int64], Field[[], int64]] """ if isinstance(symbol_type, ts.TupleType): - return tuple_constructor( - *[ - apply_to_primitive_constituents( - el, - fun, - _path=(*_path, i), - with_path_arg=with_path_arg, - tuple_constructor=tuple_constructor, - ) - for i, el in enumerate(symbol_type.types) - ] - ) + return tuple_constructor(*[ + apply_to_primitive_constituents( + el, + fun, + _path=(*_path, i), + with_path_arg=with_path_arg, + tuple_constructor=tuple_constructor, + ) + for i, el in enumerate(symbol_type.types) + ]) if with_path_arg: return fun(symbol_type, _path) # type: ignore[call-arg] # mypy not aware of `with_path_arg` else: @@ -298,7 +300,9 @@ def is_type_or_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> >>> field_type = ts.FieldType(dims=[], dtype=scalar_type) >>> is_type_or_tuple_of_type(field_type, ts.FieldType) True - >>> is_type_or_tuple_of_type(ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType)) + >>> is_type_or_tuple_of_type( + ... ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType) + ... ) True >>> is_type_or_tuple_of_type(scalar_type, ts.FieldType) False @@ -318,7 +322,9 @@ def is_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> TypeGua >>> field_type = ts.FieldType(dims=[], dtype=scalar_type) >>> is_tuple_of_type(field_type, ts.FieldType) False - >>> is_tuple_of_type(ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType)) + >>> is_tuple_of_type( + ... ts.TupleType(types=[scalar_type, field_type]), (ts.ScalarType, ts.FieldType) + ... ) True >>> is_tuple_of_type(ts.TupleType(types=[scalar_type]), ts.FieldType) False @@ -381,38 +387,37 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: Examples: --------- >>> is_concretizable( - ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... to_type=ts.ScalarType(kind=ts.ScalarKind.INT64) + ... ts.ScalarType(kind=ts.ScalarKind.INT64), to_type=ts.ScalarType(kind=ts.ScalarKind.INT64) ... ) True >>> is_concretizable( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... to_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... to_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ) False >>> is_concretizable( ... ts.DeferredType(constraint=None), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) True >>> is_concretizable( ... ts.DeferredType(constraint=ts.DataType), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) True >>> is_concretizable( ... ts.DeferredType(constraint=ts.OffsetType), - ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]) + ... to_type=ts.FieldType(dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL), dims=[]), ... ) False >>> is_concretizable( ... ts.DeferredType(constraint=ts.TypeSpec), - ... to_type=ts.DeferredType(constraint=ts.ScalarType) + ... to_type=ts.DeferredType(constraint=ts.ScalarType), ... ) True @@ -437,17 +442,14 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp >>> dtype = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> I, J, K = (common.Dimension(value=dim) for dim in ["I", "J", "K"]) >>> promoted: ts.FieldType = promote( - ... ts.FieldType(dims=[I, J], dtype=dtype), - ... ts.FieldType(dims=[I, J, K], dtype=dtype), - ... dtype + ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[I, J, K], dtype=dtype), dtype ... ) >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True >>> promote( - ... ts.FieldType(dims=[I, J], dtype=dtype), - ... ts.FieldType(dims=[K], dtype=dtype) - ... ) # doctest: +ELLIPSIS + ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) + ... ) # doctest: +ELLIPSIS Traceback (most recent call last): ... ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. @@ -636,7 +638,7 @@ def function_signature_incompatibilities( @function_signature_incompatibilities.register -def function_signature_incompatibilities_func( # noqa: C901 +def function_signature_incompatibilities_func( func_type: ts.FunctionType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec], @@ -700,7 +702,12 @@ def function_signature_incompatibilities_field( # TODO: This code does not handle ellipses for dimensions. Fix it. assert field_type.dims is not ... if field_type.dims and source_dim not in field_type.dims: - yield f"Incompatible offset can not shift field defined on " f"{', '.join([dim.value for dim in field_type.dims])} from " f"{source_dim.value} to target dim(s): " f"{', '.join([dim.value for dim in target_dims])}" + yield ( + f"Incompatible offset can not shift field defined on " + f"{', '.join([dim.value for dim in field_type.dims])} from " + f"{source_dim.value} to target dim(s): " + f"{', '.join([dim.value for dim in target_dims])}" + ) def accepts_args( @@ -724,7 +731,7 @@ def accepts_args( ... pos_only_args=[bool_type], ... pos_or_kw_args={"foo": bool_type}, ... kw_only_args={}, - ... returns=ts.VoidType() + ... returns=ts.VoidType(), ... ) >>> accepts_args(func_type, with_args=[bool_type], with_kwargs={"foo": bool_type}) True diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index ec459906e0..21932afd70 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -21,10 +21,10 @@ class RecursionGuard: Context manager to guard against inifinite recursion. >>> def foo(i): - ... with RecursionGuard(i): - ... if i % 2 == 0: - ... foo(i) - ... return i + ... with RecursionGuard(i): + ... if i % 2 == 0: + ... foo(i) + ... return i >>> foo(3) 3 >>> foo(2) # doctest:+ELLIPSIS diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index c62c9998e9..d23d880c58 100644 --- a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -18,7 +18,7 @@ from . import cartesian from .cartesian import layout -from .cartesian.interface import empty, from_array, full, ones, zeros # noqa: F401 +from .cartesian.interface import empty, from_array, full, ones, zeros from .cartesian.layout import from_name, register diff --git a/src/gt4py/storage/cartesian/layout.py b/src/gt4py/storage/cartesian/layout.py index 26e34e35d6..65b1967448 100644 --- a/src/gt4py/storage/cartesian/layout.py +++ b/src/gt4py/storage/cartesian/layout.py @@ -73,7 +73,7 @@ def check_layout(layout_map, strides): def layout_maker_factory( - base_layout: Tuple[int, ...] + base_layout: Tuple[int, ...], ) -> Callable[[Tuple[str, ...]], Tuple[int, ...]]: def layout_maker(dimensions: Tuple[str, ...]) -> Tuple[int, ...]: mask = [dim in dimensions for dim in "IJK"] diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py index 6b8c02e41c..5814daa495 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_exec_info.py @@ -49,7 +49,10 @@ def advection_def( @staticmethod def diffusion_def( - in_phi: gtscript.Field[float], out_phi: gtscript.Field[float], *, alpha: float # type: ignore + in_phi: gtscript.Field[float], + out_phi: gtscript.Field[float], + *, + alpha: float, # type: ignore ): with computation(PARALLEL), interval(...): # type: ignore # noqa lap1 = ( diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 4ac239fdd2..79056c2914 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -142,7 +142,11 @@ def native_functions(field_a: Field3D, field_b: Field3D): field_b = ( trunc_res if isfinite(trunc_res) - else field_a if isinf(trunc_res) else field_b if isnan(trunc_res) else 0.0 + else field_a + if isinf(trunc_res) + else field_b + if isnan(trunc_res) + else 0.0 ) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 6110e29cdb..10019343ab 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -165,9 +165,11 @@ def definition(field_a, field_b, field_c, field_out, *, weight, alpha_factor): factor = alpha_factor else: factor = 1.0 - field_out = factor * field_a[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - (1 - factor) * (field_b[0, 0, 0] - weight * field_c[0, 0, 0]) + field_out = ( + factor + * field_a[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - (1 - factor) * (field_b[0, 0, 0] - weight * field_c[0, 0, 0]) + ) def validation( field_a, field_b, field_c, field_out, *, weight, alpha_factor, domain, origin, **kwargs @@ -225,9 +227,10 @@ def definition(u, diffusion, *, weight): laplacian = 4.0 * u[0, 0, 0] - (u[1, 0, 0] + u[-1, 0, 0] + u[0, 1, 0] + u[0, -1, 0]) flux_i = laplacian[1, 0, 0] - laplacian[0, 0, 0] flux_j = laplacian[0, 1, 0] - laplacian[0, 0, 0] - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -290,9 +293,10 @@ def definition(u, diffusion, *, weight): with computation(PARALLEL), interval(...): laplacian = lap_op(u=u) flux_i, flux_j = fwd_diff(field=laplacian) - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -330,9 +334,10 @@ def definition(u, diffusion, *, weight): flux_j = fwd_diff_op_y(field=laplacian) else: flux_i, flux_j = fwd_diff_op_xy(field=laplacian) - diffusion = u[ # noqa: F841 # Local name is assigned to but never used - 0, 0, 0 - ] - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + diffusion = ( + u[0, 0, 0] # noqa: F841 # Local name is assigned to but never used + - weight * (flux_i[0, 0, 0] - flux_i[-1, 0, 0] + flux_j[0, 0, 0] - flux_j[0, -1, 0]) + ) def validation(u, diffusion, *, weight, domain, origin, **kwargs): laplacian = 4.0 * u[1:-1, 1:-1, :] - ( @@ -792,9 +797,7 @@ class TestVariableKRead(gt_testing.StencilTestSuite): def definition(field_in, field_out, index): with computation(PARALLEL), interval(1, None): - field_out = field_in[ # noqa: F841 # Local name is assigned to but never used - 0, 0, index - ] + field_out = field_in[0, 0, index] # noqa: F841 # Local name is assigned to but never used def validation(field_in, field_out, index, *, domain, origin): field_out[:, :, 1:] = field_in[:, :, (np.arange(field_in.shape[-1]) + index)[1:]] diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py index 83195a898a..934597f4e5 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_defir_to_gtir.py @@ -45,7 +45,8 @@ def defir_to_gtir(): def test_stencil_definition( - defir_to_gtir, ijk_domain # noqa: F811 [redefinition, reason: fixture] + defir_to_gtir, + ijk_domain, # noqa: F811 [redefinition, reason: fixture] ): stencil_definition = ( TDefinition(name="definition", domain=ijk_domain, fields=["a", "b"]) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 00b91eb2cf..f7d46cc8b2 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -1202,16 +1202,14 @@ def definition_func(inout_field: gtscript.Field[float]): @pytest.mark.parametrize( "id_case,import_line", list( - enumerate( - [ - "import gt4py", - "from externals import EXTERNAL", - "from gt4py.cartesian import __gtscript__", - "from gt4py.cartesian import __externals__", - "from gt4py.cartesian.gtscript import computation", - "from gt4py.cartesian.externals import EXTERNAL", - ] - ) + enumerate([ + "import gt4py", + "from externals import EXTERNAL", + "from gt4py.cartesian import __gtscript__", + "from gt4py.cartesian import __externals__", + "from gt4py.cartesian.gtscript import computation", + "from gt4py.cartesian.externals import EXTERNAL", + ]) ), ) def test_wrong_imports(self, id_case, import_line): @@ -1240,19 +1238,17 @@ class TestDTypes: @pytest.mark.parametrize( "id_case,test_dtype", list( - enumerate( - [ - bool, - np.bool_, - int, - np.int32, - np.int64, - float, - np.float32, - np.float64, - np.dtype((np.float32, (3,))), - ] - ) + enumerate([ + bool, + np.bool_, + int, + np.int32, + np.int64, + float, + np.float32, + np.float64, + np.dtype((np.float32, (3,))), + ]) ), ) def test_all_legal_dtypes_instance(self, id_case, test_dtype): diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 8cfff12df4..1a51cad736 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,13 +312,15 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( # noqa: B018 - nodes=[ - SymbolChildNode(name="foo"), - SymbolRefChildNode(name="foo"), - SymbolRefChildNode(name="foo"), - ], - ), + ( + SymbolTableRootNode( # noqa: B018 + nodes=[ + SymbolChildNode(name="foo"), + SymbolRefChildNode(name="foo"), + SymbolRefChildNode(name="foo"), + ], + ), + ) SymbolTableRootNode( nodes=[ SymbolChildNode(name="outer_scope"), diff --git a/tests/conftest.py b/tests/conftest.py index 29ce8e32e9..53752a6cd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,5 @@ """Global configuration of pytest for collecting and running tests.""" - # Ignore hidden folders and disabled tests collect_ignore_glob = [".*", "_disabled*"] diff --git a/tests/eve_tests/conftest.py b/tests/eve_tests/conftest.py index 7db9737905..d49b46da10 100644 --- a/tests/eve_tests/conftest.py +++ b/tests/eve_tests/conftest.py @@ -14,7 +14,6 @@ """Global configuration of test generation and execution with pytest.""" - import pytest from . import definitions diff --git a/tests/eve_tests/unit_tests/test_extended_typing.py b/tests/eve_tests/unit_tests/test_extended_typing.py index d90a577bf9..ec2d5a8fe8 100644 --- a/tests/eve_tests/unit_tests/test_extended_typing.py +++ b/tests/eve_tests/unit_tests/test_extended_typing.py @@ -326,12 +326,10 @@ def test_is_actual_wrong_type(t): (List[int], type(List[int])), ] if sys.version_info >= (3, 9): - ACTUAL_TYPE_SAMPLES.extend( - [ - (tuple[int, float], types.GenericAlias), # type: ignore[misc] # ignore false positive bug: https://github.com/python/mypy/issues/11098 - (list[int], types.GenericAlias), - ] - ) + ACTUAL_TYPE_SAMPLES.extend([ + (tuple[int, float], types.GenericAlias), # type: ignore[misc] # ignore false positive bug: https://github.com/python/mypy/issues/11098 + (list[int], types.GenericAlias), + ]) @pytest.mark.parametrize(["instance", "expected"], ACTUAL_TYPE_SAMPLES) @@ -515,9 +513,7 @@ class MissingRef: ... globalns={"Annotated": Annotated, "Callable": Callable}, localns={"MissingRef": MissingRef}, ) - ) == Callable[ - [int], MissingRef - ] or ( # some patch versions of cpython3.9 show weird behaviors + ) == Callable[[int], MissingRef] or ( # some patch versions of cpython3.9 show weird behaviors sys.version_info >= (3, 9) and sys.version_info < (3, 10) and (ref == Callable[[Annotated[int, "Foo"]], MissingRef]) diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py index d9977f0d3a..c60c134d77 100644 --- a/tests/eve_tests/unit_tests/test_type_validation.py +++ b/tests/eve_tests/unit_tests/test_type_validation.py @@ -153,15 +153,13 @@ class SampleDataClass: class SampleSlottedDataClass: b: float - SAMPLE_TYPE_DEFINITIONS.append( - ( - SampleSlottedDataClass, - [SampleSlottedDataClass(1.0), SampleSlottedDataClass(1)], - [object(), float(1.2), int(1), "1.2", SampleSlottedDataClass], - None, - None, - ) - ) + SAMPLE_TYPE_DEFINITIONS.append(( + SampleSlottedDataClass, + [SampleSlottedDataClass(1.0), SampleSlottedDataClass(1)], + [object(), float(1.2), int(1), "1.2", SampleSlottedDataClass], + None, + None, + )) @pytest.mark.parametrize("validator", VALIDATORS) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index aca601d74e..00a193be0b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -48,11 +48,14 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU) - OPTIONAL_PROCESSORS.append( - pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu - ) - ), + ( + OPTIONAL_PROCESSORS.append( + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=pytest.mark.requires_gpu, + ) + ), + ) @pytest.fixture( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 5ded38abdb..5050008ef1 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -779,7 +779,7 @@ def test_scan_nested_tuple_output(forward, cartesian_case): @gtx.scan_operator(axis=KDim, forward=forward, init=init) def simple_scan_operator( - carry: tuple[int32, tuple[int32, int32]] + carry: tuple[int32, tuple[int32, int32]], ) -> tuple[int32, tuple[int32, int32]]: return (carry[0] + 1, (carry[1][0] + 1, carry[1][1] + 1)) @@ -812,12 +812,10 @@ def test_scan_nested_tuple_input(cartesian_case): def prev_levels_iterator(i): return range(i + 1) - expected = np.asarray( - [ - reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init) - for i in range(k_size) - ] - ) + expected = np.asarray([ + reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init) + for i in range(k_size) + ]) @gtx.scan_operator(axis=KDim, forward=True, init=init) def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: @@ -832,12 +830,10 @@ def test_scan_different_domain_in_tuple(cartesian_case): i_size = cartesian_case.default_sizes[IDim] k_size = cartesian_case.default_sizes[KDim] - inp1_np = np.ones( - ( - i_size + 1, - k_size, - ) - ) # i_size bigger than in the other argument + inp1_np = np.ones(( + i_size + 1, + k_size, + )) # i_size bigger than in the other argument inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) @@ -846,16 +842,14 @@ def test_scan_different_domain_in_tuple(cartesian_case): def prev_levels_iterator(i): return range(i + 1) - expected = np.asarray( - [ - reduce( - lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], - prev_levels_iterator(k), - init, - ) - for k in range(k_size) - ] - ).transpose() + expected = np.asarray([ + reduce( + lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ]).transpose() @gtx.scan_operator(axis=KDim, forward=True, init=init) def scan_op(carry: float, a: tuple[float, float]) -> float: @@ -883,16 +877,14 @@ def test_scan_tuple_field_scalar_mixed(cartesian_case): def prev_levels_iterator(i): return range(i + 1) - expected = np.asarray( - [ - reduce( - lambda prev, k: prev + 1.0 + inp2_np[:, k], - prev_levels_iterator(k), - init, - ) - for k in range(k_size) - ] - ).transpose() + expected = np.asarray([ + reduce( + lambda prev, k: prev + 1.0 + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ]).transpose() @gtx.scan_operator(axis=KDim, forward=True, init=init) def scan_op(carry: float, a: tuple[float, float]) -> float: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 583246f50f..19758d1f09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -95,14 +95,14 @@ def reduction_e_field(edge_f: cases.EField) -> cases.VField: @gtx.field_operator def reduction_ek_field( - edge_f: common.Field[[Edge, KDim], np.int32] + edge_f: common.Field[[Edge, KDim], np.int32], ) -> common.Field[[Vertex, KDim], np.int32]: return neighbor_sum(edge_f(V2E), axis=V2EDim) @gtx.field_operator def reduction_ke_field( - edge_f: common.Field[[KDim, Edge], np.int32] + edge_f: common.Field[[KDim, Edge], np.int32], ) -> common.Field[[KDim, Vertex], np.int32]: return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -202,7 +202,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): fencil, ref=lambda edge_f: 3 * np.sum( - -edge_f[v2e_table] ** 2 * 2, + -(edge_f[v2e_table] ** 2) * 2, axis=1, initial=0, where=v2e_table != common._DEFAULT_SKIP_VALUE, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index df0009d0d4..19cbbb4ba2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -72,9 +72,9 @@ def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatField): shift_by_one(in_field, out=out_field[:-1]) - in_field = cases.allocate(cartesian_case, shift_by_one_program, "in_field").extend( - {IDim: (0, 1)} - )() + in_field = cases.allocate(cartesian_case, shift_by_one_program, "in_field").extend({ + IDim: (0, 1) + })() out_field = cases.allocate(cartesian_case, shift_by_one_program, "out_field")() cases.verify( @@ -107,12 +107,10 @@ def test_copy_restricted_execution(cartesian_case, copy_restrict_program_def): cases.verify_with_default_data( cartesian_case, copy_restrict_program, - ref=lambda in_field: np.array( - [ - in_field[i] if i in range(1, 2) else 0 - for i in range(0, cartesian_case.default_sizes[IDim]) - ] - ), + ref=lambda in_field: np.array([ + in_field[i] if i in range(1, 2) else 0 + for i in range(0, cartesian_case.default_sizes[IDim]) + ]), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 03e1af27dd..998b351255 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -51,9 +51,9 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): build_node_to_edge_connectivity(mesh) build_median_dual_mesh(mesh) - edges_per_node = max( - [mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size)] - ) + edges_per_node = max([ + mesh.nodes.edge_connectivity.cols(node) for node in range(0, fs_nodes.size) + ]) self.mesh = mesh self.fs_edges = fs_edges diff --git a/tests/next_tests/past_common_fixtures.py b/tests/next_tests/past_common_fixtures.py index 3ac931f319..756d81b6d9 100644 --- a/tests/next_tests/past_common_fixtures.py +++ b/tests/next_tests/past_common_fixtures.py @@ -44,7 +44,7 @@ def identity(in_field: gtx.Field[[IDim], "float64"]) -> gtx.Field[[IDim], "float def make_tuple_op(): @gtx.field_operator() def make_tuple_op_impl( - inp: gtx.Field[[IDim], float64] + inp: gtx.Field[[IDim], float64], ) -> Tuple[gtx.Field[[IDim], float64], gtx.Field[[IDim], float64]]: return inp, inp diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index df49486cc2..805966a349 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -377,45 +377,37 @@ def test_cartesian_remap_implementation(): @pytest.mark.parametrize( "new_dims,field,expected_domain", [ - ( - ( - (D0,), - common._field( - np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) - ), - Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), - ) - ), - ( - ( - (D0, D1), - common._field( - np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) - ), - Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange.infinite())), - ) - ), - ( - ( - (D0, D1), - common._field( - np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) - ), - Domain(dims=(D0, D1), ranges=(UnitRange.infinite(), UnitRange(0, 10))), - ) - ), - ( - ( - (D0, D1, D2), - common._field( - np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) - ), - Domain( - dims=(D0, D1, D2), - ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), - ), - ) - ), + (( + (D0,), + common._field( + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(D0,), ranges=(UnitRange(0, 10),)), + )), + (( + (D0, D1), + common._field( + np.arange(10), domain=common.Domain(dims=(D0,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(D0, D1), ranges=(UnitRange(0, 10), UnitRange.infinite())), + )), + (( + (D0, D1), + common._field( + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) + ), + Domain(dims=(D0, D1), ranges=(UnitRange.infinite(), UnitRange(0, 10))), + )), + (( + (D0, D1, D2), + common._field( + np.arange(10), domain=common.Domain(dims=(D1,), ranges=(UnitRange(0, 10),)) + ), + Domain( + dims=(D0, D1, D2), + ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), + ), + )), ], ) def test_field_broadcast(new_dims, field, expected_domain): @@ -761,12 +753,10 @@ def test_connectivity_field_inverse_image_2d_domain(): c2v_conn = common._connectivity( np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]), - domain=common.domain( - [ - common.named_range((C, (C_START, C_STOP))), - common.named_range((C2V, (C2V_START, C2V_STOP))), - ] - ), + domain=common.domain([ + common.named_range((C, (C_START, C_STOP))), + common.named_range((C2V, (C2V_START, C2V_STOP))), + ]), codomain=V, ) @@ -842,12 +832,10 @@ def test_connectivity_field_inverse_image_2d_domain_skip_values(): c2v_conn = common._connectivity( np.asarray([[-1, 0, 2, -1], [1, 1, 2, 2], [2, 2, -1, -1], [-1, 2, -1, -1]]), - domain=common.domain( - [ - common.named_range((C, (C_START, C_STOP))), - common.named_range((C2V, (C2V_START, C2V_STOP))), - ] - ), + domain=common.domain([ + common.named_range((C, (C_START, C_STOP))), + common.named_range((C2V, (C2V_START, C2V_STOP))), + ]), codomain=V, skip_value=-1, ) diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py index 60a382d989..6840821d74 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -59,13 +59,11 @@ def test_str(loc_plain, message): def test_str_snippet(loc_snippet, message): - pattern = r"\n".join( - [ - f"{message}", - ' File ".*", line.*', - " # This very line of comment should be shown in the snippet.", - r" \^\^\^\^\^\^\^\^\^\^\^\^\^\^", - ] - ) + pattern = r"\n".join([ + f"{message}", + ' File ".*", line.*', + " # This very line of comment should be shown in the snippet.", + r" \^\^\^\^\^\^\^\^\^\^\^\^\^\^", + ]) s = str(errors.DSLError(loc_snippet, message)) assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index a0035348ad..9ebd991e36 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -526,13 +526,15 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl def test_builtin_int_constructors(): - def int_constrs() -> tuple[ - int32, - int32, - int64, - int32, - int64, - ]: + def int_constrs() -> ( + tuple[ + int32, + int32, + int64, + int32, + int64, + ] + ): return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) @@ -550,15 +552,17 @@ def int_constrs() -> tuple[ def test_builtin_float_constructors(): - def float_constrs() -> tuple[ - float, - float, - float32, - float64, - float, - float32, - float64, - ]: + def float_constrs() -> ( + tuple[ + float, + float, + float32, + float64, + float, + float32, + float64, + ] + ): return ( 0.1, float(0.1), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 96ecc19c0b..e527d18e4c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -35,6 +35,7 @@ arctan(), sqrt(), exp(), log(), isfinite(), isinf(), isnan(), floor(), ceil(), trunc() - evaluation test cases """ + import re import pytest diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 087dee815e..3f42307a7e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -54,18 +54,21 @@ def test_indent(): def test_cost(): - assert PrettyPrinter()._cost(["This is a single line."]) < PrettyPrinter()._cost( - ["These are", "multiple", "short", "lines."] - ) - assert PrettyPrinter()._cost(["This is a short line."]) < PrettyPrinter()._cost( - [ - "This is a very long line; longer than the maximum allowed line length. " - "So it should get a penalty for its length." - ] - ) - assert PrettyPrinter()._cost( - ["Equal length!", "Equal length!", "Equal length!"] - ) < PrettyPrinter()._cost(["Unequal length.", "Short…", "Looooooooooooooooooong…"]) + assert PrettyPrinter()._cost(["This is a single line."]) < PrettyPrinter()._cost([ + "These are", + "multiple", + "short", + "lines.", + ]) + assert PrettyPrinter()._cost(["This is a short line."]) < PrettyPrinter()._cost([ + "This is a very long line; longer than the maximum allowed line length. " + "So it should get a penalty for its length." + ]) + assert PrettyPrinter()._cost([ + "Equal length!", + "Equal length!", + "Equal length!", + ]) < PrettyPrinter()._cost(["Unequal length.", "Short…", "Looooooooooooooooooong…"]) def test_optimum(): diff --git a/tests/storage_tests/conftest.py b/tests/storage_tests/conftest.py index ead6bb60e8..55bf13afe9 100644 --- a/tests/storage_tests/conftest.py +++ b/tests/storage_tests/conftest.py @@ -14,7 +14,6 @@ """Global configuration of storage test generation and execution (with Hypothesis and pytest).""" - import hypothesis as hyp From f22c1497b64c63be714353c959a9d48b3d3fae0d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 6 Mar 2024 16:07:00 +0000 Subject: [PATCH 35/50] fix tests --- .../embedded_tests/test_nd_array_field.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 805966a349..94796dd444 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -497,11 +497,11 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): def test_absolute_indexing_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - indexed_field_1 = field[JDim(8) : JDim(10), IDim(5) : IDim(9)] - expected = field[(IDim, UnitRange(5, 9)), (JDim, UnitRange(8, 10))] + indexed_field_1 = field[D1(8) : D1(10), D0(5) : D0(9)] + expected = field[(D0, UnitRange(5, 9)), (D1, UnitRange(8, 10))] assert common.is_field(indexed_field_1) assert indexed_field_1 == expected @@ -509,11 +509,11 @@ def test_absolute_indexing_dim_sliced(): def test_absolute_indexing_dim_sliced_single_slice(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - indexed_field_1 = field[KDim(11)] - indexed_field_2 = field[(KDim, 11)] + indexed_field_1 = field[D2(11)] + indexed_field_2 = field[(D2, 11)] assert common.is_field(indexed_field_1) assert indexed_field_1 == indexed_field_2 @@ -521,21 +521,21 @@ def test_absolute_indexing_dim_sliced_single_slice(): def test_absolute_indexing_wrong_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) - with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'JDim' and 'IDim'."): - field[JDim(8) : IDim(10)] + with pytest.raises(IndexError, match="Dimensions slicing mismatch between 'D1' and 'D0'."): + field[D1(8) : D0(10)] def test_absolute_indexing_empty_dim_sliced(): domain = common.Domain( - dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) + dims=(D0, D1, D2), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) field = common._field(np.ones((5, 10, 15)), domain=domain) with pytest.raises(IndexError, match="Lower bound needs to be specified"): - field[: IDim(10)] + field[: D0(10)] def test_absolute_indexing_value_return(): From 80f21ff4cd4e4c95fd4716d08c4bbafe96eaf9ab Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 7 Mar 2024 09:06:21 +0000 Subject: [PATCH 36/50] describe algorithm --- src/gt4py/next/embedded/nd_array_field.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 57b51c69b1..4656f4220b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -640,13 +640,16 @@ def _concat_where( ) mask_dim = mask_field.domain.dims[0] + # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) + # compute the consecutive ranges (first relative, then domain) of true and false values mask_values, mask_relative_ranges = zip(*_compute_mask_ranges(mask_field.ndarray)) mask_domains = ( _relative_ranges_to_domain((relative_range,), mask_field.domain) for relative_range in mask_relative_ranges ) + # mask domains intersected with the respective fields intersected_domains = ( embedded_common.domain_intersection( t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain @@ -654,6 +657,7 @@ def _concat_where( for mask_value, mask_domain in zip(mask_values, mask_domains) ) + # remove the empty domains from the beginning and end mask_values, intersected_domains = tuple( zip(*_trim_empty_domains(list(zip(mask_values, intersected_domains)))) ) or ([], []) @@ -662,11 +666,13 @@ def _concat_where( f"In 'concat_where', cannot concatenate the following 'Domain's: {list(intersected_domains)}." ) + # slice the fields with the domain ranges transformed = [ t_broadcasted[d] if v else f_broadcasted[d] for v, d in zip(mask_values, intersected_domains) ] + # stack the fields together if transformed: return _concat(*transformed, dim=mask_dim) else: From c9dfc8d2e449e19a19884f9c4eb99de8b29d2c7c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 7 Mar 2024 09:08:39 +0000 Subject: [PATCH 37/50] add docstring to concat_where, but test not working --- src/gt4py/next/ffront/fbuiltins.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 10e364ab18..756cba85c4 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -88,6 +88,7 @@ class BuiltInFunction(Generic[_R, _P]): def __post_init__(self): object.__setattr__(self, "name", f"{self.function.__module__}.{self.function.__name__}") + object.__setattr__(self, "__doc__", self.function.__doc__) def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: impl = self.dispatch(*args) @@ -208,6 +209,26 @@ def concat_where( false_field: common.Field | core_defs.ScalarT | Tuple, /, ) -> common.Field | Tuple: + """ + Concatenates two field fields based on a 1D mask. + + The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. + Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. + + TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator + Example: + >>> I = common.Dimension("I") + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2], domain={I: (0, 2)}) + >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) + >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) + + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) + >>> false_field = common._field( + ... [4], domain={I: (2, 3)} + ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values + """ raise NotImplementedError() From 14acc84f81744bf8438f28d3e3f80d4bc615af39 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 7 Mar 2024 09:41:44 +0000 Subject: [PATCH 38/50] add tests --- src/gt4py/next/embedded/common.py | 2 -- src/gt4py/next/embedded/nd_array_field.py | 2 +- .../unit_tests/embedded_tests/test_common.py | 36 +++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 4c2b480ad2..8acca2b154 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -97,7 +97,6 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) -# TODO tests def domain_intersection( *domains: common.Domain, ) -> common.Domain: @@ -118,7 +117,6 @@ def domain_intersection( ) -# TODO tests def intersect_domains( *domains: common.Domain, ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 4656f4220b..6705155624 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -572,11 +572,11 @@ def _to_field( ) -# TODO move to common and test def _intersect_fields( *fields: common.Field | core_defs.Scalar, ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, ) -> tuple[common.Field, ...]: + # TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations nd_array_class = _get_nd_array_class(*fields) promoted_dims = common.promote_dims(*[f.domain.dims for f in fields if common.is_field(f)]) broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields] diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 91f15ee936..a06a6a7873 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -24,6 +24,8 @@ canonicalize_any_index_sequence, iterate_domain, sub_domain, + intersect_domains, + domain_intersection, ) @@ -180,3 +182,37 @@ def test_slicing(slices, expected): else: testee = canonicalize_any_index_sequence(slices) assert testee == expected + + +def test_domain_intersection(): + # see also tests in unit_tests/test_common.py for tests with 2 domains: `dom0 & dom1` + testee = (common.domain({I: (0, 5)}), common.domain({I: (1, 3)}), common.domain({I: (0, 3)})) + + result = domain_intersection(*testee) + + expected = testee[0] & testee[1] & testee[2] + assert result == expected + + +def test_intersect_domains(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = intersect_domains(*testee, ignore_dims=J) + + expected = (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + assert result == expected + + +def test_intersect_domains_ignore_dims_none(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = intersect_domains(*testee) + + expected = (domain_intersection(*testee),) * 2 + assert result == expected + + +def test_intersect_domains_ignore_all_dims(): + testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) + result = intersect_domains(*testee, ignore_dims=(I, J)) + + expected = testee + assert result == expected From 071a9e00facfd48cd5f0c2bb4c5a876b3e1a1b98 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 7 Mar 2024 10:11:35 +0000 Subject: [PATCH 39/50] switch back to list[tuple] --- src/gt4py/next/embedded/nd_array_field.py | 40 ++++++++++++++--------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 6705155624..cf6797813b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -18,7 +18,7 @@ import functools from collections.abc import Callable, Sequence from types import ModuleType -from typing import ClassVar +from typing import ClassVar, Iterable import numpy as np from numpy import typing as npt @@ -548,8 +548,11 @@ def _compute_mask_ranges( return res -def _trim_empty_domains(lst: list[tuple[bool, common.Domain]]) -> list[tuple[bool, common.Domain]]: +def _trim_empty_domains( + lst: Iterable[tuple[bool, common.Domain]], +) -> list[tuple[bool, common.Domain]]: """Remove empty domains from beginning and end of the list.""" + lst = list(lst) if not lst: return lst if lst[0][1].is_empty: @@ -644,32 +647,37 @@ def _concat_where( t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values, mask_relative_ranges = zip(*_compute_mask_ranges(mask_field.ndarray)) - mask_domains = ( - _relative_ranges_to_domain((relative_range,), mask_field.domain) - for relative_range in mask_relative_ranges + mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = ( + _compute_mask_ranges(mask_field.ndarray) + ) + mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( + (mask, _relative_ranges_to_domain((relative_range,), mask_field.domain)) + for mask, relative_range in mask_values_to_relative_range_mapping ) # mask domains intersected with the respective fields - intersected_domains = ( - embedded_common.domain_intersection( - t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain + mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( + ( + mask_value, + embedded_common.domain_intersection( + t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain + ), ) - for mask_value, mask_domain in zip(mask_values, mask_domains) + for mask_value, mask_domain in mask_values_to_domain_mapping ) # remove the empty domains from the beginning and end - mask_values, intersected_domains = tuple( - zip(*_trim_empty_domains(list(zip(mask_values, intersected_domains)))) - ) or ([], []) - if any(d.is_empty for d in intersected_domains): + mask_values_to_intersected_domains_mapping = _trim_empty_domains( + mask_values_to_intersected_domains_mapping + ) + if any(d.is_empty for _, d in mask_values_to_intersected_domains_mapping): raise embedded_exceptions.NonContiguousDomain( - f"In 'concat_where', cannot concatenate the following 'Domain's: {list(intersected_domains)}." + f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." ) # slice the fields with the domain ranges transformed = [ t_broadcasted[d] if v else f_broadcasted[d] - for v, d in zip(mask_values, intersected_domains) + for v, d in mask_values_to_intersected_domains_mapping ] # stack the fields together From a6186fc09052ff259358e4c2644e42d0919d0a1b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 11 Mar 2024 10:07:13 +0000 Subject: [PATCH 40/50] add unit_range.is_empty --- src/gt4py/next/common.py | 8 +++++++- tests/next_tests/unit_tests/test_common.py | 24 +++++++++++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d8683a45d5..04c00e4a90 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -160,6 +160,12 @@ def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: # classmethod since TypeGuards requires the guarded obj as separate argument return obj.start is not Infinity.NEGATIVE + @property + def is_empty(self) -> bool: + return ( + self.start == 0 and self.stop == 0 + ) # post_init ensures that empty is represented as UnitRange(0, 0) + def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -422,7 +428,7 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: @property def is_empty(self) -> bool: - return any(rng == UnitRange(0, 0) for rng in self.ranges) + return any(rng.is_empty for rng in self.ranges) @overload def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 7650e90c3c..7ce30c4bb2 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,7 +14,6 @@ import operator from typing import Optional, Pattern -import numpy as np import pytest from gt4py.next.common import ( @@ -22,7 +21,6 @@ DimensionKind, Domain, Infinity, - NamedRange, UnitRange, domain, named_range, @@ -92,10 +90,12 @@ def test_unbounded_max_min(value): assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE -def test_empty_range(): +@pytest.mark.parametrize("empty_range", [UnitRange(1, 0), UnitRange(1, -1)]) +def test_empty_range(empty_range): expected = UnitRange(0, 0) - assert UnitRange(1, 1) == expected - assert UnitRange(1, -1) == expected + + assert empty_range == expected + assert empty_range.is_empty @pytest.fixture @@ -257,6 +257,20 @@ def test_domain_length(a_domain): assert len(a_domain) == 3 +@pytest.mark.parametrize( + "empty_domain, expected", + [ + (Domain(), False), + (Domain((IDim, UnitRange(0, 10))), False), + (Domain((IDim, UnitRange(0, 0))), True), + (Domain((IDim, UnitRange(0, 0)), (JDim, UnitRange(0, 1))), True), + (Domain((IDim, UnitRange(0, 1)), (JDim, UnitRange(0, 0))), True), + ], +) +def test_empty_domain(empty_domain, expected): + assert empty_domain.is_empty == expected + + @pytest.mark.parametrize( "domain_like", [ From 6413753578ac59175b5f78eeabd9e74c532cf499 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 11 Mar 2024 11:50:09 +0000 Subject: [PATCH 41/50] add more tests --- .../embedded_tests/test_nd_array_field.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 94796dd444..0975b81688 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -954,6 +954,20 @@ def test_hypercube(index_array, expected): ([6, 7, 8, 9, 10], None), ([42, 7, 42, 9, 42], None), ), + ( + # parts of mask_ranges are concatenated + ([True, True, False, False], None), + ([1, 2], {D0: (1, 3)}), + ([3, 4], {D0: (1, 3)}), + ([1, 4], {D0: (1, 3)}), + ), + ( + # parts of mask_ranges are concatenated and yield non-contiguous domain + ([True, False, True, False], None), + ([1, 2], {D0: (0, 2)}), + ([3, 4], {D0: (2, 4)}), + None, + ), ], ) def test_concat_where( From c798f97845a03fd8c5dba18b28f4652eb8fbb1f2 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 11 Mar 2024 13:57:06 +0000 Subject: [PATCH 42/50] address more review comments --- src/gt4py/next/embedded/common.py | 28 +++++++++---------- src/gt4py/next/embedded/exceptions.py | 2 ++ src/gt4py/next/embedded/nd_array_field.py | 16 +++++++---- .../unit_tests/embedded_tests/test_common.py | 13 ++++++--- 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 8acca2b154..6b8be7bf45 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -104,11 +104,11 @@ def domain_intersection( Return the intersection of the given domains. Example: - >>> I = common.Dimension("I") - >>> domain_intersection( - ... common.domain({I: (0, 5)}), common.domain({I: (1, 3)}) - ... ) # doctest: +ELLIPSIS - Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) + >>> I = common.Dimension("I") + >>> domain_intersection( + ... common.domain({I: (0, 5)}), common.domain({I: (1, 3)}) + ... ) # doctest: +ELLIPSIS + Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),)) """ return functools.reduce( operator.and_, @@ -117,7 +117,7 @@ def domain_intersection( ) -def intersect_domains( +def restrict_to_intersection( *domains: common.Domain, ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, ) -> tuple[common.Domain, ...]: @@ -125,14 +125,14 @@ def intersect_domains( Return the with each other intersected domains, ignoring 'ignore_dims' dimensions for the intersection. Example: - >>> I = common.Dimension("I") - >>> J = common.Dimension("J") - >>> res = intersect_domains( - ... common.domain({I: (0, 5), J: (1, 2)}), - ... common.domain({I: (1, 3), J: (0, 3)}), - ... ignore_dims=J, - ... ) - >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)})) + >>> I = common.Dimension("I") + >>> J = common.Dimension("J") + >>> res = intersect_domains( + ... common.domain({I: (0, 5), J: (1, 2)}), + ... common.domain({I: (1, 3), J: (0, 3)}), + ... ignore_dims=J, + ... ) + >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)})) """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) intersection_without_ignore_dims = domain_intersection(*[ diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index bddea25712..93a65c1943 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -39,6 +39,8 @@ def __init__( class NonContiguousDomain(gt4py_exceptions.GT4PyError): + """Describes an error where a domain would become non-contiguous after an operation.""" + msg: str def __init__(self, msg: str): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index cf6797813b..ffab70cd0b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -540,7 +540,9 @@ def _compute_mask_ranges( ind = 0 res = [] for i in range(1, mask.shape[0]): - if (mask_i := bool(mask[i].item())) != cur: + if ( + mask_i := bool(mask[i].item()) + ) != cur: # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy res.append((cur, common.UnitRange(ind, i))) cur = mask_i ind = i @@ -579,12 +581,13 @@ def _intersect_fields( *fields: common.Field | core_defs.Scalar, ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, ) -> tuple[common.Field, ...]: - # TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations + # TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations; + # currently blocked, because requiring the `_to_field` function, see comment there. nd_array_class = _get_nd_array_class(*fields) - promoted_dims = common.promote_dims(*[f.domain.dims for f in fields if common.is_field(f)]) + promoted_dims = common.promote_dims(*(f.domain.dims for f in fields if common.is_field(f))) broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields] - intersected_domains = embedded_common.intersect_domains( + intersected_domains = embedded_common.restrict_to_intersection( *[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims ) @@ -597,7 +600,7 @@ def _intersect_fields( ) -def _concat_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: +def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: if not domains: return common.Domain() dim_start = domains[0][dim][1].start @@ -619,7 +622,7 @@ def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty ): raise ValueError("Fields to concatenate must not overlap.") - new_domain = _concat_domains(*[f.domain for f in fields], dim=dim) + new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) if new_domain is None: raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") nd_array_class = _get_nd_array_class(*fields) @@ -646,6 +649,7 @@ def _concat_where( # intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) + # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils # compute the consecutive ranges (first relative, then domain) of true and false values mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = ( _compute_mask_ranges(mask_field.ndarray) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index a06a6a7873..9765273f94 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -24,7 +24,7 @@ canonicalize_any_index_sequence, iterate_domain, sub_domain, - intersect_domains, + restrict_to_intersection, domain_intersection, ) @@ -194,9 +194,14 @@ def test_domain_intersection(): assert result == expected +def test_domain_intersection_empty(): + result = domain_intersection() + assert result == common.Domain() + + def test_intersect_domains(): testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) - result = intersect_domains(*testee, ignore_dims=J) + result = restrict_to_intersection(*testee, ignore_dims=J) expected = (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) assert result == expected @@ -204,7 +209,7 @@ def test_intersect_domains(): def test_intersect_domains_ignore_dims_none(): testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) - result = intersect_domains(*testee) + result = restrict_to_intersection(*testee) expected = (domain_intersection(*testee),) * 2 assert result == expected @@ -212,7 +217,7 @@ def test_intersect_domains_ignore_dims_none(): def test_intersect_domains_ignore_all_dims(): testee = (common.domain({I: (0, 5), J: (1, 2)}), common.domain({I: (1, 3), J: (1, 3)})) - result = intersect_domains(*testee, ignore_dims=(I, J)) + result = restrict_to_intersection(*testee, ignore_dims=(I, J)) expected = testee assert result == expected From 54428f7be1be726b46f4ba5558dc4675386c5061 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 11 Mar 2024 14:03:00 +0000 Subject: [PATCH 43/50] steal refactoring from nfarabullini/as_offset_embedded --- src/gt4py/next/ffront/experimental.py | 33 ++++++------------- .../ffront/foast_passes/type_deduction.py | 6 ++-- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 39da80a5de..f53a75973a 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -12,30 +12,17 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +from gt4py.next import common +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset -from gt4py.next.type_system import type_specifications as ts +@BuiltInFunction +def as_offset( + offset_: FieldOffset, + field: common.Field, + /, +) -> common.ConnectivityField: + raise NotImplementedError() -@dataclass -class BuiltInFunction: - __gt_type: ts.FunctionType - def __call__(self, *args, **kwargs): - """Act as an empty place holder for the built in function.""" - - def __gt_type__(self): - return self.__gt_type - - -as_offset = BuiltInFunction( - ts.FunctionType( - pos_only_args=[ - ts.DeferredType(constraint=ts.OffsetType), - ts.DeferredType(constraint=ts.FieldType), - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=ts.DeferredType(constraint=ts.OffsetType), - ) -) +EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset"] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index de23319b02..0e7c430596 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -20,6 +20,7 @@ from gt4py.next.common import DimensionKind from gt4py.next.ffront import ( # noqa dialect_ast_enums, + experimental, fbuiltins, type_info as ti_ffront, type_specifications as ts_ffront, @@ -727,7 +728,8 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: isinstance(new_func.type, ts.FunctionType) and not type_info.is_concrete(return_type) and isinstance(new_func, foast.Name) - and new_func.id in fbuiltins.FUN_BUILTIN_NAMES + and new_func.id + in (fbuiltins.FUN_BUILTIN_NAMES + experimental.EXPERIMENTAL_FUN_BUILTIN_NAMES) ): visitor = getattr(self, f"_visit_{new_func.id}") return visitor(new_node, **kwargs) @@ -941,8 +943,6 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: location=node.location, ) - _visit_concat_where = _visit_where - def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts From c178f2580a94ccbc52c5eb9191d9aac58b239be4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 11 Mar 2024 14:08:19 +0000 Subject: [PATCH 44/50] move to experimental --- src/gt4py/next/embedded/nd_array_field.py | 4 +-- src/gt4py/next/ffront/experimental.py | 37 +++++++++++++++++++++-- src/gt4py/next/ffront/fbuiltins.py | 30 ------------------ 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ffab70cd0b..8c7408133f 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -31,7 +31,7 @@ context as embedded_context, exceptions as embedded_exceptions, ) -from gt4py.next.ffront import fbuiltins +from gt4py.next.ffront import experimental, fbuiltins from gt4py.next.iterator import embedded as itir_embedded @@ -693,7 +693,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation def _make_reduction( diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index f53a75973a..0b615a60ea 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -12,8 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from typing import Tuple + +from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset +from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset, WhereBuiltinFunction @BuiltInFunction @@ -25,4 +28,34 @@ def as_offset( raise NotImplementedError() -EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset"] +@WhereBuiltinFunction +def concat_where( + mask: common.Field, + true_field: common.Field | core_defs.ScalarT | Tuple, + false_field: common.Field | core_defs.ScalarT | Tuple, + /, +) -> common.Field | Tuple: + """ + Concatenates two field fields based on a 1D mask. + + The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. + Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. + + TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator + Example: + >>> I = common.Dimension("I") + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2], domain={I: (0, 2)}) + >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) + >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) + + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) + >>> false_field = common._field( + ... [4], domain={I: (2, 3)} + ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values + """ + raise NotImplementedError() + + +EXPERIMENTAL_FUN_BUILTIN_NAMES = ["as_offset", "concat_where"] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 756cba85c4..dc7bbf94e4 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -202,36 +202,6 @@ def where( raise NotImplementedError() -@WhereBuiltinFunction -def concat_where( - mask: common.Field, - true_field: common.Field | core_defs.ScalarT | Tuple, - false_field: common.Field | core_defs.ScalarT | Tuple, - /, -) -> common.Field | Tuple: - """ - Concatenates two field fields based on a 1D mask. - - The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. - Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. - - TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator - Example: - >>> I = common.Dimension("I") - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2], domain={I: (0, 2)}) - >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) - >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) - - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) - >>> false_field = common._field( - ... [4], domain={I: (2, 3)} - ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values - """ - raise NotImplementedError() - - @BuiltInFunction def astype( value: common.Field | core_defs.ScalarT | Tuple, From a99561e8cfaa4cd99a340d5da250c8859332ec3e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 13 Mar 2024 16:18:45 +0000 Subject: [PATCH 45/50] address review comments --- src/gt4py/next/common.py | 2 -- src/gt4py/next/embedded/exceptions.py | 8 ++++---- src/gt4py/next/embedded/nd_array_field.py | 8 ++++---- src/gt4py/next/ffront/experimental.py | 24 +++++++++++----------- tests/next_tests/unit_tests/test_common.py | 4 ++-- 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 04c00e4a90..b8b8bf03e6 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -160,7 +160,6 @@ def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: # classmethod since TypeGuards requires the guarded obj as separate argument return obj.start is not Infinity.NEGATIVE - @property def is_empty(self) -> bool: return ( self.start == 0 and self.stop == 0 @@ -426,7 +425,6 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: # classmethod since TypeGuards requires the guarded obj as separate argument return all(UnitRange.is_finite(rng) for rng in obj.ranges) - @property def is_empty(self) -> bool: return any(rng.is_empty for rng in self.ranges) diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 93a65c1943..9306e9a002 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -41,8 +41,8 @@ def __init__( class NonContiguousDomain(gt4py_exceptions.GT4PyError): """Describes an error where a domain would become non-contiguous after an operation.""" - msg: str + detail: str - def __init__(self, msg: str): - super().__init__(f"Operation would result in a non-contiguous domain: `{msg}`.") - self.msg = msg + def __init__(self, detail: str): + super().__init__(f"Operation would result in a non-contiguous domain: `{detail}`.") + self.detail = detail diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 8c7408133f..370fe5f2ef 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -557,9 +557,9 @@ def _trim_empty_domains( lst = list(lst) if not lst: return lst - if lst[0][1].is_empty: + if lst[0][1].is_empty(): return _trim_empty_domains(lst[1:]) - if lst[-1][1].is_empty: + if lst[-1][1].is_empty(): return _trim_empty_domains(lst[:-1]) return lst @@ -619,7 +619,7 @@ def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: if ( len(fields) > 1 - and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty + and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() ): raise ValueError("Fields to concatenate must not overlap.") new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) @@ -673,7 +673,7 @@ def _concat_where( mask_values_to_intersected_domains_mapping = _trim_empty_domains( mask_values_to_intersected_domains_mapping ) - if any(d.is_empty for _, d in mask_values_to_intersected_domains_mapping): + if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): raise embedded_exceptions.NonContiguousDomain( f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." ) diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 0b615a60ea..b69a118713 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -41,19 +41,19 @@ def concat_where( The resulting domain is the concatenation of the mask subdomains with the domains of the respective true or false fields. Empty domains at the beginning or end are ignored, but the interior must result in a consecutive domain. - TODO: I can't get this doctest to run, even after copying the __doc__ in the decorator + TODO(havogt): I can't get this doctest to run, even after copying the __doc__ in the decorator Example: - >>> I = common.Dimension("I") - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2], domain={I: (0, 2)}) - >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) - >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) - - >>> mask = common._field([True, False, True], domain={I: (0, 3)}) - >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) - >>> false_field = common._field( - ... [4], domain={I: (2, 3)} - ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values + >>> I = common.Dimension("I") + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2], domain={I: (0, 2)}) + >>> false_field = common._field([3, 4, 5], domain={I: (1, 4)}) + >>> assert concat_where(mask, true_field, false_field) == _field([1, 3], domain={I: (0, 2)}) + + >>> mask = common._field([True, False, True], domain={I: (0, 3)}) + >>> true_field = common._field([1, 2, 3], domain={I: (0, 3)}) + >>> false_field = common._field( + ... [4], domain={I: (2, 3)} + ... ) # error because of non-consecutive domain: missing I(1), but has I(0) and I(2) values """ raise NotImplementedError() diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 7ce30c4bb2..ce940131c3 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -95,7 +95,7 @@ def test_empty_range(empty_range): expected = UnitRange(0, 0) assert empty_range == expected - assert empty_range.is_empty + assert empty_range.is_empty() @pytest.fixture @@ -268,7 +268,7 @@ def test_domain_length(a_domain): ], ) def test_empty_domain(empty_domain, expected): - assert empty_domain.is_empty == expected + assert empty_domain.is_empty() == expected @pytest.mark.parametrize( From 5755d6203c9ca80fdfa8601ebd48d673cf944cf7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 13 Mar 2024 19:04:57 +0000 Subject: [PATCH 46/50] add a test (wip) --- src/gt4py/next/ffront/fbuiltins.py | 3 - .../ffront/foast_passes/type_deduction.py | 2 + src/gt4py/next/ffront/foast_to_itir.py | 6 +- .../ffront_tests/test_concat_where.py | 56 +++++++++++++++++++ 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index dc7bbf94e4..24c6a07ae3 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -25,7 +25,6 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, embedded from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS -from gt4py.next.ffront.experimental import as_offset # noqa: F401 [unused-import] from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -295,9 +294,7 @@ def impl( "min_over", "broadcast", "where", - "concat_where", "astype", - "as_offset", *MATH_BUILTIN_NAMES, ] diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 0e7c430596..db7477d8e4 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -943,6 +943,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: location=node.location, ) + _visit_concat_where = _visit_where + def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: arg_type = cast(ts.FieldType | ts.ScalarType, node.args[0].type) broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 84b0966452..70e01eb7b0 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -24,6 +24,7 @@ lowering_utils, type_specifications as ts_ffront, ) +from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind from gt4py.next.iterator import ir as itir @@ -309,7 +310,10 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: return self._visit_shift(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in FUN_BUILTIN_NAMES: + elif ( + isinstance(node.func, foast.Name) + and node.func.id in FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + ): visitor = getattr(self, f"_visit_{node.func.id}") return visitor(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py new file mode 100644 index 0000000000..246f6bc31d --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -0,0 +1,56 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest +from next_tests.integration_tests.cases import ( + C2E, + E2V, + V2E, + E2VDim, + IDim, + Ioff, + JDim, + KDim, + Koff, + V2EDim, + Vertex, + cartesian_case, + unstructured_case, +) +from gt4py import next as gtx +from gt4py.next.ffront.experimental import concat_where +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) + + +def test_boundary_same_size_fields(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.zeros(interior.shape) + ref[:, :, 0] = boundary.asnumpy()[:, :, 0] + ref[:, :, 1:] = interior.asnumpy()[:, :, 1:] + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) From 4ebf2d61555f2932ab62c477f3fc972ebaaef9e8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 13 Mar 2024 19:08:35 +0000 Subject: [PATCH 47/50] add todos --- .../feature_tests/ffront_tests/test_concat_where.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index 246f6bc31d..d8d39e7ccc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -54,3 +54,9 @@ def testee( ref[:, :, 1:] = interior.asnumpy()[:, :, 1:] cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +# TODO: +# - IJField as boundary +# - IJKField with 1 level as boundary +# - mask that contains multiple regions of true/false From 2c89af31a796d7d575d75c05ae3c57e3529f8071 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 14 Mar 2024 09:57:30 +0000 Subject: [PATCH 48/50] fix refactoring bugs --- src/gt4py/next/common.py | 2 +- src/gt4py/next/embedded/common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b8b8bf03e6..dff5bb1f06 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -426,7 +426,7 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: return all(UnitRange.is_finite(rng) for rng in obj.ranges) def is_empty(self) -> bool: - return any(rng.is_empty for rng in self.ranges) + return any(rng.is_empty() for rng in self.ranges) @overload def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 6b8be7bf45..f09e1b1ac3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -127,7 +127,7 @@ def restrict_to_intersection( Example: >>> I = common.Dimension("I") >>> J = common.Dimension("J") - >>> res = intersect_domains( + >>> res = restrict_to_intersection( ... common.domain({I: (0, 5), J: (1, 2)}), ... common.domain({I: (1, 3), J: (0, 3)}), ... ignore_dims=J, From 8b268ed1f14199cb2ce1bff11732a4a588bfc264 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 14 Mar 2024 11:56:44 +0000 Subject: [PATCH 49/50] add tests in field_operators --- src/gt4py/next/embedded/nd_array_field.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 6 +- src/gt4py/next/ffront/foast_to_itir.py | 8 +- .../ffront_tests/test_concat_where.py | 120 +++++++++++++++--- .../ffront_tests/test_execution.py | 24 ---- .../feature_tests/ffront_tests/test_where.py | 118 +++++++++++++++++ 6 files changed, 229 insertions(+), 49 deletions(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 370fe5f2ef..ac60f127e1 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -693,7 +693,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) def _make_reduction( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 24c6a07ae3..8d9552ab6d 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -58,6 +58,10 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.FieldType elif t is common.Dimension: return ts.DimensionType + elif t is FieldOffset: + return ts.OffsetType + elif t is common.ConnectivityField: + return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType elif t is type: @@ -147,7 +151,7 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple(self(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 70e01eb7b0..fbd879ca7c 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -310,9 +310,8 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: return self._visit_shift(node, **kwargs) elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: return self._visit_math_built_in(node, **kwargs) - elif ( - isinstance(node.func, foast.Name) - and node.func.id in FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES + elif isinstance(node.func, foast.Name) and node.func.id in ( + FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES ): visitor = getattr(self, f"_visit_{node.func.id}") return visitor(node, **kwargs) @@ -371,8 +370,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: node.type, ) - def _visit_concat_where(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map("if_", *node.args) + _visit_concat_where = _visit_where def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: return self.visit(node.args[0], **kwargs) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py index d8d39e7ccc..9da6d260e5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_concat_where.py @@ -13,21 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np +from typing import Tuple import pytest from next_tests.integration_tests.cases import ( - C2E, - E2V, - V2E, - E2VDim, - IDim, - Ioff, - JDim, KDim, - Koff, - V2EDim, - Vertex, cartesian_case, - unstructured_case, ) from gt4py import next as gtx from gt4py.next.ffront.experimental import concat_where @@ -49,14 +39,108 @@ def testee( boundary = cases.allocate(cartesian_case, testee, "boundary")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - ref = np.zeros(interior.shape) - ref[:, :, 0] = boundary.asnumpy()[:, :, 0] - ref[:, :, 1:] = interior.asnumpy()[:, :, 1:] + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + ) cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) -# TODO: -# - IJField as boundary -# - IJKField with 1 level as boundary -# - mask that contains multiple regions of true/false +def test_boundary_horizontal_slice(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary.asnumpy()[:, :, np.newaxis], + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_boundary_single_layer(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return concat_where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary", sizes={KDim: 1})() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + np.broadcast_to(boundary.asnumpy(), interior.shape), + interior.asnumpy(), + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +def test_alternating_mask(cartesian_case): + @gtx.field_operator + def testee(k: cases.KField, f0: cases.IJKField, f1: cases.IJKField) -> cases.IJKField: + return concat_where(k % 2 == 0, f1, f0) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + f0 = cases.allocate(cartesian_case, testee, "f0")() + f1 = cases.allocate(cartesian_case, testee, "f1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where(k.asnumpy()[np.newaxis, np.newaxis, :] % 2 == 0, f1.asnumpy(), f0.asnumpy()) + + cases.verify(cartesian_case, testee, k, f0, f1, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, + interior0: cases.IJKField, + boundary0: cases.IJField, + interior1: cases.IJKField, + boundary1: cases.IJField, + ) -> Tuple[cases.IJKField, cases.IJKField]: + return concat_where(k == 0, (boundary0, boundary1), (interior0, interior1)) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref0 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy()[:, :, np.newaxis], + interior0.asnumpy(), + ) + ref1 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary1.asnumpy()[:, :, np.newaxis], + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, + k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 5db9886966..905975b80b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -30,7 +30,6 @@ int64, minimum, neighbor_sum, - where, ) from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn @@ -1053,29 +1052,6 @@ def program_domain_tuple( ) -@pytest.mark.uses_cartesian_shift -def test_where_k_offset(cartesian_case): - @gtx.field_operator - def fieldop_where_k_offset( - inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType] - ) -> cases.IKField: - return where(k_index > 0, inp(Koff[-1]), 2) - - @gtx.program - def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): - fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) - - inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - k_index = cases.allocate( - cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() - )() - out = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - - ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) - - cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) - - def test_undefined_symbols(cartesian_case): with pytest.raises(errors.DSLError, match="Undeclared symbol"): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py new file mode 100644 index 0000000000..2fc31e6574 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_where.py @@ -0,0 +1,118 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +from typing import Tuple +import pytest +from next_tests.integration_tests.cases import ( + IDim, + JDim, + KDim, + Koff, + cartesian_case, +) +from gt4py import next as gtx +from gt4py.next.ffront.fbuiltins import where, broadcast +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + exec_alloc_descriptor, +) + + +@pytest.mark.uses_cartesian_shift +def test_where_k_offset(cartesian_case): + @gtx.field_operator + def fieldop_where_k_offset( + inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType] + ) -> cases.IKField: + return where(k_index > 0, inp(Koff[-1]), 2) + + @gtx.program + def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + + inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() + k_index = cases.allocate( + cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() + )() + out = cases.allocate(cartesian_case, fieldop_where_k_offset, cases.RETURN)() + + ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), out.asnumpy()) + + cases.verify(cartesian_case, prog, inp, k_index, out=out, ref=ref) + + +def test_same_size_fields(cartesian_case): + # Note boundaries can only be implemented with `where` if both fields have the same size, see `concat_where` + @gtx.field_operator + def testee( + k: cases.KField, interior: cases.IJKField, boundary: cases.IJKField + ) -> cases.IJKField: + return where(k == 0, boundary, interior) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior = cases.allocate(cartesian_case, testee, "interior")() + boundary = cases.allocate(cartesian_case, testee, "boundary")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, boundary.asnumpy(), interior.asnumpy() + ) + + cases.verify(cartesian_case, testee, k, interior, boundary, out=out, ref=ref) + + +@pytest.mark.uses_tuple_returns +def test_with_tuples(cartesian_case): + @gtx.field_operator + def testee( + k: cases.KField, + interior0: cases.IJKField, + boundary0: cases.IJField, + interior1: cases.IJKField, + boundary1: cases.IJField, + ) -> Tuple[cases.IJKField, cases.IJKField]: + return where( + broadcast(k, (IDim, JDim, KDim)) == 0, (boundary0, boundary1), (interior0, interior1) + ) + + k = cases.allocate(cartesian_case, testee, "k", strategy=cases.IndexInitializer())() + interior0 = cases.allocate(cartesian_case, testee, "interior0")() + boundary0 = cases.allocate(cartesian_case, testee, "boundary0")() + interior1 = cases.allocate(cartesian_case, testee, "interior1")() + boundary1 = cases.allocate(cartesian_case, testee, "boundary1")() + out = cases.allocate(cartesian_case, testee, cases.RETURN)() + + ref0 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary0.asnumpy()[:, :, np.newaxis], + interior0.asnumpy(), + ) + ref1 = np.where( + k.asnumpy()[np.newaxis, np.newaxis, :] == 0, + boundary1.asnumpy()[:, :, np.newaxis], + interior1.asnumpy(), + ) + + cases.verify( + cartesian_case, + testee, + k, + interior0, + boundary0, + interior1, + boundary1, + out=out, + ref=(ref0, ref1), + ) From c2738b61ad368b9917ed5dccbeeb1b5ed88c3334 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 14 Mar 2024 12:59:58 +0000 Subject: [PATCH 50/50] add type ignore --- src/gt4py/next/embedded/nd_array_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ac60f127e1..5a07328531 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -693,7 +693,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] def _make_reduction(