Skip to content

Commit

Permalink
code optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Jan 10, 2024
1 parent f2547e5 commit b2b8fa7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 49 deletions.
17 changes: 8 additions & 9 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,33 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict):
) from err

full_args = [*args]
full_kwargs = list(kwargs.items())
full_kwargs = {**kwargs}
for index, param in enumerate(self.past_node.params):
if param.id in self.bound_args.keys():
if index < len(full_args):
full_args.insert(index, self.bound_args[param.id])
else:
pos = list(self.past_node.type.definition.pos_or_kw_args).index(str(param.id))
full_kwargs.insert(pos, (str(param.id), self.bound_args[param.id]))
full_kwargs[str(param.id)] = self.bound_args[param.id]

return super()._process_args(tuple(full_args), dict(full_kwargs))
return super()._process_args(tuple(full_args), full_kwargs)

@functools.cached_property
def itir(self):
new_itir = super().itir
for new_clos in new_itir.closures:
for key in self.bound_args.keys():
new_args = [ref(inp.id) for inp in new_clos.inputs]
for key, value in self.bound_args.items():
index = next(
index
for index, closure_input in enumerate(new_clos.inputs)
if closure_input.id == key
)
new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator(
literal_from_value(value)
)
new_clos.inputs.pop(index)
new_args = [ref(inp.id) for inp in new_clos.inputs]
params = [sym(inp.id) for inp in new_clos.inputs]
for key, value in self.bound_args.items():
pos = list(self.past_node.type.definition.pos_or_kw_args).index(key)
new_args.insert(pos, promote_to_const_iterator(literal_from_value(value)))
expr = itir.FunCall(
fun=new_clos.stencil,
args=new_args,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

import gt4py.next as gtx
from gt4py.next import int32
from tests.next_tests.integration_tests import cases

from next_tests.integration_tests.cases import cartesian_case
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
reduction_setup,
)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + a + 1
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_with_bound_args_order_args(cartesian_case):
@gtx.field_operator
def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
if not condition:
scalar = 0
return a + scalar

@gtx.program(backend=cartesian_case.backend)
def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
fieldop_args(a, condition, scalar, out=out)

a = cases.allocate(cartesian_case, program_args, "a")()
scalar = int32(1)
ref = a.asnumpy() + scalar
out = cases.allocate(cartesian_case, program_args, "out")()

prog_bounds = program_args.with_bound_args(condition=True)
prog_bounds(a=a, scalar=scalar, out=out, offset_provider={})
np.allclose(out.asnumpy(), ref)
Original file line number Diff line number Diff line change
Expand Up @@ -898,46 +898,6 @@ def test_docstring(a: cases.IField):
cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)


def test_with_bound_args(cartesian_case):
@gtx.field_operator
def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
if not condition:
scalar = 0
return a + a + scalar

@gtx.program
def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
fieldop_bound_args(a, scalar, condition, out=out)

a = cases.allocate(cartesian_case, program_bound_args, "a")()
scalar = int32(1)
ref = a + a + 1
out = cases.allocate(cartesian_case, program_bound_args, "out")()

prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)


def test_with_bound_args_order_args(cartesian_case):
@gtx.field_operator
def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
if not condition:
scalar = 0
return a + scalar

@gtx.program(backend=cartesian_case.backend)
def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
fieldop_args(a, condition, scalar, out=out)

a = cases.allocate(cartesian_case, program_args, "a")()
scalar = int32(1)
ref = a.asnumpy() + scalar
out = cases.allocate(cartesian_case, program_args, "out")()

prog_bounds = program_args.with_bound_args(condition=True)
prog_bounds(a=a, scalar=scalar, out=out, offset_provider={})
np.allclose(out.asnumpy(), ref)

def test_domain(cartesian_case):
@gtx.field_operator
def fieldop_domain(a: cases.IField) -> cases.IField:
Expand Down

0 comments on commit b2b8fa7

Please sign in to comment.