Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N committed Dec 31, 2024
1 parent 7fbd083 commit 827a4d3
Show file tree
Hide file tree
Showing 13 changed files with 7 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from gt4py.next.type_system import type_specifications as ts, type_translation



_P = ParamSpec("_P")
_R = TypeVar("_R")

Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,6 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> foast.Call:
f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.",
)


return_type = type_info.type_tree_map(
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
Expand Down
4 changes: 0 additions & 4 deletions src/gt4py/next/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

from types import ModuleType

from types import ModuleType
from typing import Callable

import numpy as np

from gt4py._core import definitions as core_defs
Expand Down Expand Up @@ -61,7 +58,6 @@ def impl(type_: ts.ScalarType) -> common.MutableField:
return impl(type_)



def get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> ModuleType:
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,7 @@ def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProvider
def set_at(expr: common.Field, domain: common.DomainLike, target: common.MutableField) -> None:
operators._tuple_assign_field(target, expr, common.domain(domain))


@runtime.if_stmt.register(EMBEDDED)
def if_stmt(cond: bool, true_branch: Callable[[], None], false_branch: Callable[[], None]) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

DimensionKind = common.DimensionKind


@noninstantiable
class Node(eve.Node):
location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False)
Expand Down Expand Up @@ -172,6 +173,7 @@ class FunctionDefinition(Node, SymbolTableTrait):
*TYPEBUILTINS,
}


class Stmt(Node): ...


Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from gt4py.next.program_processors import program_formatter



if TYPE_CHECKING:
# TODO(tehrengruber): remove cirular dependency and import unconditionally
from gt4py.next import backend as next_backend
Expand Down
10 changes: 0 additions & 10 deletions src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ def add_fundef(cls, fun):
def add_stmt(cls, stmt):
cls.body.append(stmt)

@classmethod
def add_stmt(cls, stmt):
cls.body.append(stmt)

def __enter__(self):
iterator.builtins.builtin_dispatch.push_key(TRACING)

Expand Down Expand Up @@ -251,11 +247,6 @@ def if_stmt(
)


@iterator.runtime.set_at.register(TRACING)
def set_at(expr, domain, target):
TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target))


def _contains_tuple_dtype_field(arg):
if isinstance(arg, tuple):
return any(_contains_tuple_dtype_field(el) for el in arg)
Expand Down Expand Up @@ -302,7 +293,6 @@ def _make_program_params(fun, args) -> list[Sym]:
return params



def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program:
"""
Transform fencil given as a callable into `itir.Program` using tracing.
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _transform_if(
return None



def _transform_by_pattern(
stmt: itir.Stmt,
predicate: Callable[[itir.Expr, int], bool],
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def apply_common_transforms(
#: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for
#: more details.
symbolic_domain_sizes: Optional[dict[str, str]] = None,

offset_provider_type: Optional[common.OffsetProviderType] = None,
) -> itir.Program:
# TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from gt4py.eve import codegen
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
from gt4py.next import allocators as next_allocators, backend as next_backend, common, config

from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir
from gt4py.next.iterator import ir as itir, transforms as itir_transforms
from gt4py.next.otf import stages, workflow
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,4 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:

raise TypeError(
"tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`, or with a function and collection."
)
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# GT4Py - GridTools Framework
#

# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
Expand Down Expand Up @@ -29,6 +28,7 @@
I = gtx.Dimension("I")
Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))


@fundef
def copy_stencil(inp):
return deref(inp)
Expand All @@ -54,6 +54,7 @@ def test_prog(program_processor):
if validate:
assert np.allclose(inp.asnumpy(), out.asnumpy())


@fendef
def index_program_simple(out, size):
set_at(
Expand Down
2 changes: 1 addition & 1 deletion tests/next_tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_apply_to_primitive_constituents():
)(tuple_type)

prim = type_info.apply_to_primitive_constituents(
lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), tuple_type
lambda primitive_type: ts.FieldType(dims=[], dtype=primitive_type), tuple_type
)

assert tree == prim
Expand Down

0 comments on commit 827a4d3

Please sign in to comment.