Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Bound args kwargs edit #1411

Merged
merged 6 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,27 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict):
) from err

full_args = [*args]
full_kwargs = {**kwargs}
for index, param in enumerate(self.past_node.params):
if param.id in self.bound_args.keys():
full_args.insert(index, self.bound_args[param.id])
if index < len(full_args):
full_args.insert(index, self.bound_args[param.id])
else:
full_kwargs[str(param.id)] = self.bound_args[param.id]

return super()._process_args(tuple(full_args), kwargs)
return super()._process_args(tuple(full_args), full_kwargs)

@functools.cached_property
def itir(self):
new_itir = super().itir
for new_clos in new_itir.closures:
for key in self.bound_args.keys():
new_args = [ref(inp.id) for inp in new_clos.inputs]
for key, value in self.bound_args.items():
index = next(
index
for index, closure_input in enumerate(new_clos.inputs)
if closure_input.id == key
)
new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator(
literal_from_value(value)
)
new_clos.inputs.pop(index)
new_args = [ref(inp.id) for inp in new_clos.inputs]
params = [sym(inp.id) for inp in new_clos.inputs]
for value in self.bound_args.values():
new_args.append(promote_to_const_iterator(literal_from_value(value)))
expr = itir.FunCall(
fun=new_clos.stencil,
args=new_args,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

import gt4py.next as gtx
from gt4py.next import int32

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import cartesian_case
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
reduction_setup,
)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + scalar
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_with_bound_args_order_args(cartesian_case):
@gtx.field_operator
def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
scalar = 0 if not condition else scalar
return a + scalar

@gtx.program(backend=cartesian_case.backend)
def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
fieldop_args(a, condition, scalar, out=out)

a = cases.allocate(cartesian_case, program_args, "a")()
out = cases.allocate(cartesian_case, program_args, "out")()

prog_bounds = program_args.with_bound_args(condition=True)
prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={})
np.allclose(out.asnumpy(), a.asnumpy() + int32(1))
Original file line number Diff line number Diff line change
Expand Up @@ -898,26 +898,6 @@ def test_docstring(a: cases.IField):
cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + a + 1
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_domain(cartesian_case):
@gtx.field_operator
def fieldop_domain(a: cases.IField) -> cases.IField:
Expand Down