Skip to content

Commit

Permalink
[Fix] Add ctx to the original ndarray and revise the usage of context…
Browse files Browse the repository at this point in the history
… to ctx (apache#16819)

* try to fix warning

try to fix warning

try to fix all warnings

use ctx

* try to fix warnings

* try fo fix warnings
  • Loading branch information
sxjscience authored Nov 15, 2019
1 parent b972406 commit 4a27b5c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 22 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _gather_type_ctx_info(args):
Context of the first appeared NDArray (for backward-compatibility)
"""
if isinstance(args, NDArray):
return False, True, {args.context}, args.context
return False, True, {args.ctx}, args.ctx
elif isinstance(args, Symbol):
return True, False, set(), None
elif isinstance(args, (list, tuple)):
Expand Down Expand Up @@ -1141,7 +1141,7 @@ def forward(self, x, *args):
if len(ctx_set) > 1:
raise ValueError('Find multiple contexts in the input, '
'After hybridized, the HybridBlock only supports one input '
'context. You can print the ele.context in the '
'context. You can print the ele.ctx in the '
'input arguments to inspect their contexts. '
'Find all contexts = {}'.format(ctx_set))
with ctx:
Expand Down Expand Up @@ -1324,7 +1324,7 @@ def __init__(self, outputs, inputs, params=None):

def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
with x.ctx:
return self._call_cached_op(x, *args)

assert isinstance(x, Symbol), \
Expand Down
8 changes: 4 additions & 4 deletions python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,10 @@ def _init_grad(self):
if self._grad_stype != 'default':
raise ValueError("mxnet.numpy.zeros does not support stype = {}"
.format(self._grad_stype))
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context)
self._grad = [_mx_np.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx)
for i in self._data]
else:
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.ctx,
stype=self._grad_stype) for i in self._data]

autograd.mark_variables(self._check_and_get(self._data, list),
Expand Down Expand Up @@ -522,7 +522,7 @@ def row_sparse_data(self, row_id):
raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
"because its storage type is %s. Please use data() instead." \
%(self.name, self._stype))
return self._get_row_sparse(self._data, row_id.context, row_id)
return self._get_row_sparse(self._data, row_id.ctx, row_id)

def list_row_sparse_data(self, row_id):
"""Returns copies of the 'row_sparse' parameter on all contexts, in the same order
Expand Down Expand Up @@ -897,7 +897,7 @@ def zero_grad(self):
if g.stype == 'row_sparse':
ndarray.zeros_like(g, out=g)
else:
arrays[g.context].append(g)
arrays[g.ctx].append(g)

if len(arrays) == 0:
return
Expand Down
37 changes: 27 additions & 10 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __repr__(self):
shape_info = 'x'.join(['%d' % x for x in self.shape])
return '\n%s\n<%s %s @%s>' % (str(self.asnumpy()),
self.__class__.__name__,
shape_info, self.context)
shape_info, self.ctx)

def __reduce__(self):
return NDArray, (None,), self.__getstate__()
Expand Down Expand Up @@ -729,14 +729,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
`squeeze_axes`: a sequence of axes to squeeze in the value array.
"""
if isinstance(value, numeric_types):
value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype)
value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype)
elif type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck
value_nd = value.as_in_context(self.context)
value_nd = value.as_in_context(self.ctx)
if value_nd.dtype != self.dtype:
value_nd = value_nd.astype(self.dtype)
else:
try:
value_nd = array(value, ctx=self.context, dtype=self.dtype)
value_nd = array(value, ctx=self.ctx, dtype=self.dtype)
except:
raise TypeError('{} does not support assignment with non-array-like '
'object {} of type {}'.format(self.__class__, value, type(value)))
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def _get_index_nd(self, key):

shape_nd_permut = tuple(self.shape[ax] for ax in axs_nd_permut)
converted_idcs_short = [
self._advanced_index_to_array(idx, ax_len, self.context)
self._advanced_index_to_array(idx, ax_len, self.ctx)
for idx, ax_len in zip(idcs_permut_short, shape_nd_permut)
]
bcast_idcs_permut_short = self._broadcast_advanced_indices(
Expand All @@ -1229,7 +1229,7 @@ def _get_index_nd(self, key):

# Get the ndim of advanced indexing subspace
converted_advanced_idcs = [
self._advanced_index_to_array(idx, ax_len, self.context)
self._advanced_index_to_array(idx, ax_len, self.ctx)
for idx, ax_len in zip(adv_idcs_nd, [self.shape[ax] for ax in adv_axs_nd])
]
bcast_advanced_shape = _broadcast_shapes(converted_advanced_idcs)
Expand Down Expand Up @@ -2433,6 +2433,23 @@ def context(self):
self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id)))
return Context(Context.devtype2str[dev_typeid.value], dev_id.value)

@property
def ctx(self):
"""Device context of the array. Has the same meaning as context.
Examples
--------
>>> x = mx.nd.array([1, 2, 3, 4])
>>> x.ctx
cpu(0)
>>> type(x.ctx)
<class 'mxnet.context.Context'>
>>> y = mx.nd.zeros((2,3), mx.gpu(0))
>>> y.ctx
gpu(0)
"""
return self.context

@property
def dtype(self):
"""Data-type of the array's elements.
Expand Down Expand Up @@ -2580,7 +2597,7 @@ def astype(self, dtype, copy=True):
if not copy and np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, ctx=self.context, dtype=dtype)
res = empty(self.shape, ctx=self.ctx, dtype=dtype)
self.copyto(res)
return res

Expand Down Expand Up @@ -2646,7 +2663,7 @@ def copy(self):
array([[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
"""
return self.copyto(self.context)
return self.copyto(self.ctx)

def slice_assign_scalar(self, value, begin, end, step):
"""
Expand Down Expand Up @@ -2904,7 +2921,7 @@ def _full(self, value):
"""
This is added as an NDArray class method in order to support polymorphism in NDArray and numpy.ndarray indexing
"""
return _internal._full(self.shape, value=value, ctx=self.context, dtype=self.dtype, out=self)
return _internal._full(self.shape, value=value, ctx=self.ctx, dtype=self.dtype, out=self)

def _scatter_set_nd(self, value_nd, indices):
"""
Expand Down Expand Up @@ -4542,7 +4559,7 @@ def concatenate(arrays, axis=0, always_copy=True):
assert shape_rest2 == arr.shape[axis+1:]
assert dtype == arr.dtype
ret_shape = shape_rest1 + (shape_axis,) + shape_rest2
ret = empty(ret_shape, ctx=arrays[0].context, dtype=dtype)
ret = empty(ret_shape, ctx=arrays[0].ctx, dtype=dtype)

idx = 0
begin = [0 for _ in ret_shape]
Expand Down
11 changes: 6 additions & 5 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,15 +921,15 @@ def __repr__(self):
elif dtype not in (_np.float32, _np.bool_):
array_str = array_str[:-1] + ', dtype={})'.format(dtype)

context = self.context
context = self.ctx
if context.device_type == 'cpu':
return array_str
return array_str[:-1] + ', ctx={})'.format(str(context))

def __str__(self):
"""Returns a string representation of the array."""
array_str = self.asnumpy().__str__()
context = self.context
context = self.ctx
if context.device_type == 'cpu' or self.ndim == 0:
return array_str
return '{array} @{ctx}'.format(array=array_str, ctx=context)
Expand Down Expand Up @@ -994,7 +994,7 @@ def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ,unused-ar
if not copy and _np.dtype(dtype) == self.dtype:
return self

res = empty(self.shape, dtype=dtype, ctx=self.context)
res = empty(self.shape, dtype=dtype, ctx=self.ctx)
self.copyto(res)
return res

Expand Down Expand Up @@ -1051,7 +1051,8 @@ def argmax(self, axis=None, out=None): # pylint: disable=arguments-differ

def as_in_context(self, context):
"""This function has been deprecated. Please refer to ``ndarray.as_in_ctx``."""
warnings.warn('ndarray.context has been renamed to ndarray.ctx', DeprecationWarning)
warnings.warn('ndarray.as_in_context has been renamed to'
' ndarray.as_in_ctx', DeprecationWarning)
return self.as_nd_ndarray().as_in_context(context).as_np_ndarray()

def as_in_ctx(self, ctx):
Expand Down Expand Up @@ -1864,7 +1865,7 @@ def _full(self, value):
Currently for internal use only. Implemented for __setitem__.
Assign to self an array of self's same shape and type, filled with value.
"""
return _mx_nd_np.full(self.shape, value, ctx=self.context, dtype=self.dtype, out=self)
return _mx_nd_np.full(self.shape, value, ctx=self.ctx, dtype=self.dtype, out=self)

# pylint: disable=redefined-outer-name
def _scatter_set_nd(self, value_nd, indices):
Expand Down

0 comments on commit 4a27b5c

Please sign in to comment.