Skip to content

Commit

Permalink
Revert "feat[next]: Extend astype to work with tuples (#1352)"
Browse files Browse the repository at this point in the history
This reverts commit 67a6188.
  • Loading branch information
ninaburg authored Nov 17, 2023
1 parent 67a6188 commit e8584ae
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 114 deletions.
6 changes: 1 addition & 5 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,7 @@ def where(


@builtin_function
def astype(
field: Field | gt4py_defs.ScalarT | Tuple[Field, ...],
type_: type,
/,
) -> Field | Tuple[Field, ...]:
def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field:
raise NotImplementedError()


Expand Down
11 changes: 3 additions & 8 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,12 +823,10 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
value.type, (ts.FieldType, ts.ScalarType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -837,11 +835,8 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
)

return foast.Call(
Expand Down
35 changes: 6 additions & 29 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,12 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:

def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = node.args[0], node.args[1].id
return self._process_elements(
lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
Expand Down Expand Up @@ -400,32 +403,6 @@ def _map(self, op, *args, **kwargs):

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)

def _process_elements(
self,
process_func: Callable[[itir.Expr], itir.Expr],
obj: foast.Expr,
current_el_type: ts.TypeSpec,
current_el_expr: itir.Expr = im.ref("expr"),
):
"""Recursively applies a processing function to all primitive constituents of a tuple."""
if isinstance(current_el_type, ts.TupleType):
# TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element.
return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
*[
self._process_elements(
process_func,
obj,
current_el_type.types[i],
im.tuple_get(i, current_el_expr),
)
for i in range(len(current_el_type.types))
]
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj)


class FieldOperatorLoweringError(Exception):
...
Original file line number Diff line number Diff line change
Expand Up @@ -325,76 +325,6 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


@pytest.mark.uses_tuple_returns
def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def field_op_returning_a_tuple(
a: cases.IFloatField, b: cases.IFloatField
) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]:
tup = (a, b)
return tup

@gtx.field_operator
def cast_tuple(
a: cases.IFloatField,
b: cases.IFloatField,
a_casted_to_int_outside_of_gt4py: cases.IField,
b_casted_to_int_outside_of_gt4py: cases.IField,
) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
result = astype(field_op_returning_a_tuple(a, b), int32)
return (
result[0] == a_casted_to_int_outside_of_gt4py,
result[1] == b_casted_to_int_outside_of_gt4py,
)

@gtx.field_operator
def cast_nested_tuple(
a: cases.IFloatField,
b: cases.IFloatField,
a_casted_to_int_outside_of_gt4py: cases.IField,
b_casted_to_int_outside_of_gt4py: cases.IField,
) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
result = astype((a, field_op_returning_a_tuple(a, b)), int32)
return (
result[0] == a_casted_to_int_outside_of_gt4py,
result[1][0] == a_casted_to_int_outside_of_gt4py,
result[1][1] == b_casted_to_int_outside_of_gt4py,
)

a = cases.allocate(cartesian_case, cast_tuple, "a")()
b = cases.allocate(cartesian_case, cast_tuple, "b")()
a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32))
b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32))
out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)()
out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()

cases.verify(
cartesian_case,
cast_tuple,
a,
b,
a_casted_to_int_outside_of_gt4py,
b_casted_to_int_outside_of_gt4py,
out=out_tuple,
ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)),
)

cases.verify(
cartesian_case,
cast_nested_tuple,
a,
b,
a_casted_to_int_outside_of_gt4py,
b_casted_to_int_outside_of_gt4py,
out=out_nested_tuple,
ref=(
np.full_like(a, True, dtype=bool),
np.full_like(a, True, dtype=bool),
np.full_like(b, True, dtype=bool),
),
)


def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]):

def test_astype_wrong_value_type():
def simple_astype(a: Field[[TDim], float64]):
# we just use broadcast here but anything that is not a field, scalar or tuple thereof works
return astype(broadcast, bool)
# we just use a tuple here but anything that is not a field or scalar works
return astype((1, 2), bool)

with pytest.raises(errors.DSLError) as exc_info:
_ = FieldOperatorParser.apply_to_function(simple_astype)
Expand Down

0 comments on commit e8584ae

Please sign in to comment.