Skip to content

Commit

Permalink
Add test for input arg with different domain
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 11, 2024
1 parent aed4d1e commit f722c14
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain
ref = out.asnumpy().copy() # ensure we are not overwriting out outside the domain
ref[1:9] = a.asnumpy()[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField):
cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={})


def test_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend)

size = cartesian_case.default_sizes[IDim]
Expand All @@ -266,3 +266,24 @@ def test_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
ref = inp.ndarray[1:-2]

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)

def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def):
@gtx.field_operator
def identity(a: cases.IField) -> cases.IField:
return a

@gtx.program
def copy_program(a: cases.IField, out: cases.IField):
identity(a, out=out, domain={IDim: (1, 9)})

inp = constructors.empty(
common.domain({IDim: (1, 9)}),
dtype=np.int32,
allocator=cartesian_case.allocator,
)
inp.ndarray[...] = 42
out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})()
ref = out.asnumpy().copy() # ensure we are not overwriting `out` outside the domain
ref[1:9] = inp.asnumpy()

cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref)

0 comments on commit f722c14

Please sign in to comment.