Skip to content

Commit

Permalink
Rely on __cuda_array_interface__ for tensor conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
Emilio Castillo committed Jun 15, 2020
1 parent 566e681 commit f6c4809
Showing 1 changed file with 3 additions and 30 deletions.
33 changes: 3 additions & 30 deletions chainer_pytorch_migration/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,7 @@ def asarray(tensor):
return cupy.ndarray(
tuple(tensor.shape),
dtype=to_numpy_dtype(tensor.dtype))
itemsize = tensor.element_size()
storage = tensor.storage()
memptr = cupy.cuda.MemoryPointer(
cupy.cuda.UnownedMemory(
storage.data_ptr(), storage.size() * itemsize, tensor,
),
tensor.storage_offset() * itemsize,
)
return cupy.ndarray(
tuple(tensor.shape),
dtype=to_numpy_dtype(tensor.dtype),
memptr=memptr,
strides=tuple(s * itemsize for s in tensor.stride()),
)
return cupy.array(tensor.detach(), copy=False)
if dev_type == 'cpu':
return tensor.detach().numpy()
raise ValueError('tensor on device "{}" is not supported', dev_type)
Expand All @@ -59,7 +46,7 @@ def astensor(array):
view are gone.
Note:
If the array has negative strides, a copy is made
aaIf the array has negative strides, a copy is made
"""
if array is None:
raise TypeError('array cannot be None')
Expand All @@ -78,28 +65,14 @@ def astensor(array):
device=array.device.id
)
return torch.as_tensor(
_ArrayWithCudaArrayInterfaceHavingStrides(array),
array,
device=array.device.id,
)
if isinstance(array, numpy.ndarray):
return torch.from_numpy(array)
raise TypeError('array of type {} is not supported'.format(type(array)))


# Workaround to avoid a bug in converting cupy.ndarray to torch.Tensor via
# __cuda_array_interface__. See: https://github.com/pytorch/pytorch/pull/24947
class _ArrayWithCudaArrayInterfaceHavingStrides:

def __init__(self, array):
self._array = array

@property
def __cuda_array_interface__(self):
d = self._array.__cuda_array_interface__
d['strides'] = self._array.strides
return d


def to_numpy_dtype(torch_dtype):
"""Convert PyTorch dtype to NumPy dtype.
Expand Down

0 comments on commit f6c4809

Please sign in to comment.