Skip to content

Commit

Permalink
bug[next]: Bound args kwargs edit (#1411)
Browse files Browse the repository at this point in the history
* edits for BoundArgs with kwargs in correct order
  • Loading branch information
nfarabullini authored Jan 18, 2024
1 parent 6283ac9 commit 3edf21e
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 26 deletions.
17 changes: 11 additions & 6 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,27 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict):
) from err

full_args = [*args]
full_kwargs = {**kwargs}
for index, param in enumerate(self.past_node.params):
if param.id in self.bound_args.keys():
full_args.insert(index, self.bound_args[param.id])
if index < len(full_args):
full_args.insert(index, self.bound_args[param.id])
else:
full_kwargs[str(param.id)] = self.bound_args[param.id]

return super()._process_args(tuple(full_args), kwargs)
return super()._process_args(tuple(full_args), full_kwargs)

@functools.cached_property
def itir(self):
new_itir = super().itir
for new_clos in new_itir.closures:
for key in self.bound_args.keys():
new_args = [ref(inp.id) for inp in new_clos.inputs]
for key, value in self.bound_args.items():
index = next(
index
for index, closure_input in enumerate(new_clos.inputs)
if closure_input.id == key
)
new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator(
literal_from_value(value)
)
new_clos.inputs.pop(index)
new_args = [ref(inp.id) for inp in new_clos.inputs]
params = [sym(inp.id) for inp in new_clos.inputs]
for value in self.bound_args.values():
new_args.append(promote_to_const_iterator(literal_from_value(value)))
expr = itir.FunCall(
fun=new_clos.stencil,
args=new_args,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

import gt4py.next as gtx
from gt4py.next import int32

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import cartesian_case
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
reduction_setup,
)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + scalar
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_with_bound_args_order_args(cartesian_case):
@gtx.field_operator
def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
scalar = 0 if not condition else scalar
return a + scalar

@gtx.program(backend=cartesian_case.backend)
def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
fieldop_args(a, condition, scalar, out=out)

a = cases.allocate(cartesian_case, program_args, "a")()
out = cases.allocate(cartesian_case, program_args, "out")()

prog_bounds = program_args.with_bound_args(condition=True)
prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={})
np.allclose(out.asnumpy(), a.asnumpy() + int32(1))
Original file line number Diff line number Diff line change
Expand Up @@ -898,26 +898,6 @@ def test_docstring(a: cases.IField):
cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + a + 1
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_domain(cartesian_case):
@gtx.field_operator
def fieldop_domain(a: cases.IField) -> cases.IField:
Expand Down

0 comments on commit 3edf21e

Please sign in to comment.