Skip to content

Commit

Permalink
do not support transfer to same type
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Sep 25, 2024
1 parent 2c0164d commit 216ac2d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
40 changes: 21 additions & 19 deletions arraycontext/impl/pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,17 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
import numpy as np

import arraycontext.impl.pyopencl.taggable_cl_array as tga
if isinstance(expr.data, np.ndarray):
data = tga.to_device(self.queue, expr.data, allocator=self.allocator)
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)
if not isinstance(expr.data, np.ndarray):
raise ValueError("TransferToDeviceMapper: tried to transfer data that "
"is already on the device")

return super().map_data_wrapper(expr)
data = tga.to_device(self.queue, expr.data, allocator=self.allocator)
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


class TransferToHostMapper(CopyMapper):
Expand All @@ -159,16 +160,17 @@ def __init__(self, queue: CommandQueue, allocator: AllocatorBase = None) -> None

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
if isinstance(expr.data, tga.TaggableCLArray):
data = expr.data.get()
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

return super().map_data_wrapper(expr)
if not isinstance(expr.data, tga.TaggableCLArray):
raise ValueError("TransferToHostMapper: tried to transfer data that "
"is already on the host")

data = expr.data.get()
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)


def transfer_to_device(expr: ArrayOrNames, queue: CommandQueue,
Expand Down
10 changes: 4 additions & 6 deletions test/test_pytato_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,17 @@ def test_transfer(actx_factory):
assert ah != a
assert isinstance(ah.data, np.ndarray)

ahh = transfer_to_host(ah, actx.queue, actx.allocator)
assert isinstance(ahh.data, np.ndarray)
assert ah != ahh # copied DataWrappers compare unequal
assert ah != a
with pytest.raises(ValueError):
_ahh = transfer_to_host(ah, actx.queue, actx.allocator)

ad = transfer_to_device(ah, actx.queue, actx.allocator)
assert isinstance(ad.data, TaggableCLArray)
assert ad != ah
assert ad != a # copied DataWrappers compare unequal
assert np.array_equal(a.data.get(), ad.data.get())

add = transfer_to_device(ad, actx.queue, actx.allocator)
assert add != ad # copied DataWrappers compare unequal
with pytest.raises(ValueError):
_add = transfer_to_device(ad, actx.queue, actx.allocator)

# }}}

Expand Down

0 comments on commit 216ac2d

Please sign in to comment.