diff --git a/src/mygrad/operation_base.py b/src/mygrad/operation_base.py index 3be681d1..655d3bc2 100644 --- a/src/mygrad/operation_base.py +++ b/src/mygrad/operation_base.py @@ -8,6 +8,7 @@ import numpy as np +from mygrad._numpy_version import NP_IS_V2 from mygrad._utils import SkipGradient, reduce_broadcast from mygrad.errors import InvalidBackprop, InvalidGradient from mygrad.typing import DTypeLike, Mask @@ -97,7 +98,10 @@ def grad_post_process_fn( if out.ndim == 0: # sum-reduction to a scalar produces a float - out = np.asarray(out) + if NP_IS_V2: + out = np.asarray(out) + else: # pragma: no cover + np.array(out, copy=False) return out @abstractmethod @@ -198,7 +202,10 @@ def backward( f"numpy arrays, got a gradient of type: {type(backed_grad)}" ) - backed_grad = np.asarray(backed_grad) + if NP_IS_V2: + backed_grad = np.asarray(backed_grad) + else: # pragma: no cover + np.array(backed_grad, copy=False) if self.where is not True: backed_grad = backed_grad * self.where diff --git a/src/mygrad/tensor_base.py b/src/mygrad/tensor_base.py index e06095f3..d3d8cbe7 100644 --- a/src/mygrad/tensor_base.py +++ b/src/mygrad/tensor_base.py @@ -29,6 +29,7 @@ import mygrad._utils.duplicating_graph as _dup import mygrad._utils.graph_tracking as _track import mygrad._utils.lock_management as _mem +from mygrad._numpy_version import NP_IS_V2 from mygrad._tensor_core_ops.indexing import GetItem, SetItem from mygrad._utils import WeakRef, WeakRefIterable, collect_all_tensors_and_clear_grads from mygrad.errors import DisconnectedView @@ -738,7 +739,12 @@ def __array_function__( def __array__( self, dtype: DTypeLike = None, copy: Optional[bool] = None ) -> np.ndarray: - return np.asarray(self.data, dtype=dtype, copy=copy) + if NP_IS_V2: + return np.asarray(self.data, dtype=dtype, copy=copy) + else: # pragma: no cover + if copy is None: + copy = False + return np.array(self.data, dtype=dtype, copy=copy) def __init__( self, @@ -798,18 +804,21 @@ def __init__( self._creator: Optional[Operation] = _creator - if copy is False: - self.data = np.asarray(x, dtype=dtype) # type: np.ndarray - if not isinstance(ndmin, Integral): - raise TypeError( - f"'{type(ndmin)}' object cannot be interpreted as an integer" - ) - if ndmin and self.data.ndim < ndmin: - self.data = self.data[(*(None for _ in range(ndmin - self.data.ndim)),)] + if not NP_IS_V2: + self.data = np.array(x, dtype=dtype, copy=copy, ndmin=ndmin) else: - self.data = np.array( - x, dtype=dtype, copy=copy, ndmin=ndmin - ) # type: np.ndarray + if copy is False: + self.data = np.asarray(x, dtype=dtype) + if not isinstance(ndmin, Integral): + raise TypeError( + f"'{type(ndmin)}' object cannot be interpreted as an integer" + ) + if ndmin and self.data.ndim < ndmin: + self.data = self.data[ + (*(None for _ in range(ndmin - self.data.ndim)),) + ] + else: + self.data = np.array(x, dtype=dtype, copy=copy, ndmin=ndmin) dtype = self.data.dtype.type is_float = issubclass(dtype, np.floating) # faster than `numpy.issubdtype`