Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Nov 15, 2023
1 parent 2444e1f commit 290e63e
Showing 1 changed file with 9 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest

import gt4py.next as gtx
from gt4py.next import common
from gt4py.next.program_processors.runners import dace_iterator, gtfn

from next_tests.integration_tests import cases
Expand All @@ -26,26 +27,19 @@

@pytest.mark.requires_gpu
@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu])
def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures
def test_copy(fieldview_backend): # noqa: F811 # fixtures
import cupy as cp

@gtx.field_operator(backend=fieldview_backend)
def testee(a: cases.IJKField) -> cases.IJKField:
return a

inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32)
outp_arr = cp.zeros_like(inp_arr)
inp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], inp_arr)
outp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], outp_arr)

testee(inp, out=outp, offset_provider={})
assert cp.allclose(inp_arr, outp_arr)

inp_field = gtx.full(
[cases.IDim, cases.JDim, cases.KDim], fill_value=3, allocator=fieldview_backend
)
out_field = gtx.zeros(
[cases.IDim, cases.JDim, cases.KDim], outp_arr, allocator=fieldview_backend
)
domain = {
cases.IDim: common.unit_range(3),
cases.JDim: common.unit_range(4),
cases.KDim: common.unit_range(5),
}
inp_field = gtx.full(domain, fill_value=3, allocator=fieldview_backend, dtype=cp.int32)
out_field = gtx.zeros(domain, allocator=fieldview_backend, dtype=cp.int32)
testee(inp_field, out=out_field, offset_provider={})
assert cp.allclose(inp_field.ndarray, out_field.ndarray)

0 comments on commit 290e63e

Please sign in to comment.