Skip to content

Commit

Permalink
refactor[next]: Use type specification for itir.Literal (GridTools#1529)
Browse files Browse the repository at this point in the history
Small PR in preparation of the new ITIR type system. Currently the type of a `itir.Literal` is stored as a string which blocks introducing a `type: ts.TypeSpecification` attribute in all `itir.Node`s. In order to keep the PR for the new type inference easy to review this  has been factored out.

```python
class Literal(Expr):
    value: str
    type: str

    @datamodels.validator("type")
    def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
        if value not in TYPEBUILTINS:
            raise ValueError(f"'{value}' is not a valid builtin type.")
```
is changed to
```python
class Literal(Expr):
    value: str
    type: ts.ScalarType
```
  • Loading branch information
tehrengruber authored Apr 17, 2024
1 parent a603bfe commit 9f44142
Show file tree
Hide file tree
Showing 21 changed files with 134 additions and 95 deletions.
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def _construct_itir_domain_arg(
else:
lower = self._visit_slice_bound(
slices[dim_i].lower if slices else None,
itir.Literal(value="0", type=itir.INTEGER_INDEX_BUILTIN),
im.literal("0", itir.INTEGER_INDEX_BUILTIN),
dim_size,
)
upper = self._visit_slice_bound(
Expand Down Expand Up @@ -458,7 +458,7 @@ def visit_Constant(self, node: past.Constant, **kwargs: Any) -> itir.Literal:
f"Scalars of kind '{node.type.kind}' not supported currently."
)
typename = node.type.kind.name.lower()
return itir.Literal(value=str(node.value), type=typename)
return im.literal(str(node.value), typename)

raise NotImplementedError("Only scalar literals supported currently.")

Expand Down
8 changes: 2 additions & 6 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.eve.concepts import SourceLocation
from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait
from gt4py.eve.utils import noninstantiable
from gt4py.next.type_system import type_specifications as ts


# TODO(havogt):
Expand Down Expand Up @@ -73,12 +74,7 @@ class Expr(Node): ...

class Literal(Expr):
value: str
type: str

@datamodels.validator("type")
def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if value not in TYPEBUILTINS:
raise ValueError(f"'{value}' is not a valid builtin type.")
type: ts.ScalarType


class NoneLiteral(Expr):
Expand Down
25 changes: 16 additions & 9 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti
SymRef(id=SymbolRef('a'))
>>> ensure_expr(3)
Literal(value='3', type='int32')
Literal(value='3', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None))
>>> ensure_expr(itir.OffsetLiteral(value="i"))
OffsetLiteral(value='i')
Expand Down Expand Up @@ -94,6 +94,13 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of
return str_or_offset


def ensure_type(type_: str | ts.TypeSpec | None) -> ts.TypeSpec | None:
if isinstance(type_, str):
return ts.ScalarType(kind=getattr(ts.ScalarKind, type_.upper()))
assert isinstance(type_, ts.TypeSpec) or type_ is None
return type_


class lambda_:
"""
Create a lambda from params and an expression.
Expand All @@ -118,7 +125,7 @@ class call:
Examples
--------
>>> call("plus")(1, 1)
FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')])
FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None)), Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None))])
"""

def __init__(self, expr):
Expand Down Expand Up @@ -291,22 +298,22 @@ def shift(offset, value=None):
return call(call("shift")(*args))


def literal(value: str, typename: str):
return itir.Literal(value=value, type=typename)
def literal(value: str, typename: str) -> itir.Literal:
return itir.Literal(value=value, type=ensure_type(typename))


def literal_from_value(val: core_defs.Scalar) -> itir.Literal:
"""
Make a literal node from a value.
>>> literal_from_value(1.0)
Literal(value='1.0', type='float64')
Literal(value='1.0', type=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
>>> literal_from_value(1)
Literal(value='1', type='int32')
Literal(value='1', type=ScalarType(kind=<ScalarKind.INT32: 32>, shape=None))
>>> literal_from_value(2147483648)
Literal(value='2147483648', type='int64')
Literal(value='2147483648', type=ScalarType(kind=<ScalarKind.INT64: 64>, shape=None))
>>> literal_from_value(True)
Literal(value='True', type='bool')
Literal(value='True', type=ScalarType(kind=<ScalarKind.BOOL: 1>, shape=None))
"""
if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673
raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.")
Expand All @@ -321,7 +328,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal:
typename = type_spec.kind.name.lower()
assert typename in itir.TYPEBUILTINS

return itir.Literal(value=str(val), type=typename)
return literal(str(val), typename)


def neighbors(offset, it):
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def SYM(self, value: lark_lexer.Token) -> ir.Sym:

def SYM_REF(self, value: lark_lexer.Token) -> Union[ir.SymRef, ir.Literal]:
if value.value in ("True", "False"):
return ir.Literal(value=value.value, type="bool")
return im.literal(value.value, "bool")
return ir.SymRef(id=value.value)

def INT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return im.literal_from_value(int(value.value))

def FLOAT_LITERAL(self, value: lark_lexer.Token) -> ir.Literal:
return ir.Literal(value=value.value, type="float64")
return im.literal(value.value, "float64")

def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
v: Union[int, str] = value.value[:-1]
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let
from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda
from gt4py.next.type_system import type_info


class UnknownLength:
Expand Down Expand Up @@ -232,7 +233,7 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[
and isinstance(node.args[0], ir.Literal)
):
# `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
assert node.args[0].type in ir.INTEGER_BUILTINS
assert type_info.is_integer(node.args[0].type)
make_tuple_call = node.args[1]
idx = int(node.args[0].value)
assert idx < len(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/inline_lifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def visit_FunCall(
assert len(node.args[0].fun.args) == 1
args = node.args[0].args
if len(args) == 0:
return ir.Literal(value="True", type="bool")
return im.literal_from_value(True)

res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]])
for arg in args[1:]:
Expand Down
8 changes: 3 additions & 5 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries
from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify
from gt4py.next.type_system import type_info


"""Constraint-based inference for the iterator IR."""
Expand Down Expand Up @@ -643,7 +644,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type:
return TypeVar.fresh()

def visit_Literal(self, node: ir.Literal, **kwargs) -> Val:
return Val(kind=Value(), dtype=Primitive(name=node.type))
return Val(kind=Value(), dtype=Primitive(name=node.type.kind.name.lower()))

def visit_AxisLiteral(self, node: ir.AxisLiteral, **kwargs) -> Val:
return Val(kind=Value(), dtype=AXIS_DTYPE, size=Scalar())
Expand Down Expand Up @@ -672,10 +673,7 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type:
# Calls to `tuple_get` are handled as being part of the grammar, not as function calls.
if len(node.args) != 2:
raise TypeError("'tuple_get' requires exactly two arguments.")
if (
not isinstance(node.args[0], ir.Literal)
or node.args[0].type != ir.INTEGER_INDEX_BUILTIN
):
if not isinstance(node.args[0], ir.Literal) or not type_info.is_integer(node.args[0].type):
raise TypeError(
f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
UnstructuredDomain,
)
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef
from gt4py.next.type_system import type_info


def pytype_to_cpptype(t: str) -> Optional[str]:
Expand Down Expand Up @@ -183,7 +184,7 @@ def _collect_offset_definitions(


def _literal_as_integral_constant(node: itir.Literal) -> IntegralConstant:
assert node.type in itir.INTEGER_BUILTINS
assert type_info.is_integer(node.type)
return IntegralConstant(value=int(node.value))


Expand All @@ -193,7 +194,7 @@ def _is_scan(node: itir.Node) -> TypeGuard[itir.FunCall]:

def _bool_from_literal(node: itir.Node) -> bool:
assert isinstance(node, itir.Literal)
assert node.type == "bool" and node.value in ("True", "False")
assert type_info.is_logical(node.type) and node.value in ("True", "False")
return node.value == "True"


Expand Down Expand Up @@ -296,7 +297,7 @@ def visit_Lambda(
)

def visit_Literal(self, node: itir.Literal, **kwargs: Any) -> Literal:
return Literal(value=node.value, type=node.type)
return Literal(value=node.value, type=node.type.kind.name.lower())

def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs: Any) -> OffsetLiteral:
return OffsetLiteral(value=node.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from gt4py.next.common import Connectivity
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.type_system import type_specifications as ts, type_translation as tt
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt

from .itir_to_tasklet import (
Context,
Expand Down Expand Up @@ -64,7 +64,7 @@ def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
"""
stencil_fobj = cast(FunCall, stencil)
is_forward = stencil_fobj.args[1]
assert isinstance(is_forward, Literal) and is_forward.type == "bool"
assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type)
init_carry = stencil_fobj.args[2]
assert isinstance(init_carry, Literal)
return is_forward.value == "True", init_carry
Expand Down
21 changes: 20 additions & 1 deletion src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,25 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool:
return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64]


def is_integer(symbol_type: ts.TypeSpec) -> bool:
"""
Check if ``symbol_type`` is an integral type.
Examples:
---------
>>> is_integer(ts.ScalarType(kind=ts.ScalarKind.INT32))
True
>>> is_integer(ts.ScalarType(kind=ts.ScalarKind.FLOAT32))
False
>>> is_integer(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)))
False
"""
return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in {
ts.ScalarKind.INT32,
ts.ScalarKind.INT64,
}


def is_integral(symbol_type: ts.TypeSpec) -> bool:
"""
Check if the dtype of ``symbol_type`` is an integral type.
Expand All @@ -236,7 +255,7 @@ def is_integral(symbol_type: ts.TypeSpec) -> bool:
>>> is_integral(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)))
True
"""
return extract_dtype(symbol_type).kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64]
return is_integer(extract_dtype(symbol_type))


def is_number(symbol_type: ts.TypeSpec) -> bool:
Expand Down
19 changes: 16 additions & 3 deletions tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gt4py.next.ffront.func_to_past import ProgramParser
from gt4py.next.ffront.past_to_itir import ProgramLowering
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_specifications as ts

from next_tests.past_common_fixtures import (
IDim,
Expand Down Expand Up @@ -59,7 +60,7 @@ def test_copy_lowering(copy_program_def, itir_identity_fundef):
fun=P(itir.SymRef, id=eve.SymbolRef("named_range")),
args=[
P(itir.AxisLiteral, value="IDim"),
P(itir.Literal, value="0", type="int32"),
P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)),
P(itir.SymRef, id=eve.SymbolRef("__out_size_0")),
],
)
Expand Down Expand Up @@ -118,8 +119,20 @@ def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef)
fun=P(itir.SymRef, id=eve.SymbolRef("named_range")),
args=[
P(itir.AxisLiteral, value="IDim"),
P(itir.Literal, value="1", type=itir.INTEGER_INDEX_BUILTIN),
P(itir.Literal, value="2", type=itir.INTEGER_INDEX_BUILTIN),
P(
itir.Literal,
value="1",
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
),
P(
itir.Literal,
value="2",
type=ts.ScalarType(
kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())
),
),
],
)
],
Expand Down
11 changes: 6 additions & 5 deletions tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from gt4py.next.iterator import ir
from gt4py.next.iterator.pretty_parser import pparse
from gt4py.next.iterator.ir_utils import ir_makers as im


def test_symref():
Expand Down Expand Up @@ -41,14 +42,14 @@ def test_arithmetic():
ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[
ir.Literal(value="1", type="int32"),
ir.Literal(value="2", type="int32"),
im.literal("1", "int32"),
im.literal("2", "int32"),
],
),
ir.Literal(value="3", type="int32"),
im.literal("3", "int32"),
],
),
ir.Literal(value="4", type="int32"),
im.literal("4", "int32"),
],
)
actual = pparse(testee)
Expand Down Expand Up @@ -115,7 +116,7 @@ def test_tuple_get():
testee = "x[42]"
expected = ir.FunCall(
fun=ir.SymRef(id="tuple_get"),
args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")],
args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")],
)
actual = pparse(testee)
assert actual == expected
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from gt4py.next.iterator import ir
from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat
from gt4py.next.iterator.ir_utils import ir_makers as im


def test_hmerge():
Expand Down Expand Up @@ -111,14 +112,14 @@ def test_arithmetic():
ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[
ir.Literal(value="1", type="int64"),
ir.Literal(value="2", type="int64"),
im.literal("1", "int64"),
im.literal("2", "int64"),
],
),
ir.Literal(value="3", type="int64"),
im.literal("3", "int64"),
],
),
ir.Literal(value="4", type="int64"),
im.literal("4", "int64"),
],
)
expected = "(1 + 2) × 3 / 4"
Expand All @@ -132,11 +133,11 @@ def test_associativity():
args=[
ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[ir.Literal(value="1", type="int64"), ir.Literal(value="2", type="int64")],
args=[im.literal("1", "int64"), im.literal("2", "int64")],
),
ir.FunCall(
fun=ir.SymRef(id="plus"),
args=[ir.Literal(value="3", type="int64"), ir.Literal(value="4", type="int64")],
args=[im.literal("3", "int64"), im.literal("4", "int64")],
),
],
)
Expand Down Expand Up @@ -204,7 +205,7 @@ def test_shift():
def test_tuple_get():
testee = ir.FunCall(
fun=ir.SymRef(id="tuple_get"),
args=[ir.Literal(value="42", type=ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")],
args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")],
)
expected = "x[42]"
actual = pformat(testee)
Expand Down
Loading

0 comments on commit 9f44142

Please sign in to comment.