Skip to content

Commit

Permalink
back compat
Browse files Browse the repository at this point in the history
  • Loading branch information
rsokl committed Sep 7, 2024
1 parent 6f70e82 commit 6c8dc95
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
11 changes: 9 additions & 2 deletions src/mygrad/operation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np

from mygrad._numpy_version import NP_IS_V2
from mygrad._utils import SkipGradient, reduce_broadcast
from mygrad.errors import InvalidBackprop, InvalidGradient
from mygrad.typing import DTypeLike, Mask
Expand Down Expand Up @@ -97,7 +98,10 @@ def grad_post_process_fn(

if out.ndim == 0:
# sum-reduction to a scalar produces a float
out = np.asarray(out)
if NP_IS_V2:
out = np.asarray(out)
else: # pragma: no cover
np.array(out, copy=False)
return out

@abstractmethod
Expand Down Expand Up @@ -198,7 +202,10 @@ def backward(
f"numpy arrays, got a gradient of type: {type(backed_grad)}"
)

backed_grad = np.asarray(backed_grad)
if NP_IS_V2:
backed_grad = np.asarray(backed_grad)
else: # pragma: no cover
np.array(backed_grad, copy=False)

if self.where is not True:
backed_grad = backed_grad * self.where
Expand Down
33 changes: 21 additions & 12 deletions src/mygrad/tensor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import mygrad._utils.duplicating_graph as _dup
import mygrad._utils.graph_tracking as _track
import mygrad._utils.lock_management as _mem
from mygrad._numpy_version import NP_IS_V2
from mygrad._tensor_core_ops.indexing import GetItem, SetItem
from mygrad._utils import WeakRef, WeakRefIterable, collect_all_tensors_and_clear_grads
from mygrad.errors import DisconnectedView
Expand Down Expand Up @@ -738,7 +739,12 @@ def __array_function__(
def __array__(
self, dtype: DTypeLike = None, copy: Optional[bool] = None
) -> np.ndarray:
return np.asarray(self.data, dtype=dtype, copy=copy)
if NP_IS_V2:
return np.asarray(self.data, dtype=dtype, copy=copy)
else: # pragma: no cover
if copy is None:
copy = False
return np.array(self.data, dtype=dtype, copy=copy)

def __init__(
self,
Expand Down Expand Up @@ -798,18 +804,21 @@ def __init__(

self._creator: Optional[Operation] = _creator

if copy is False:
self.data = np.asarray(x, dtype=dtype) # type: np.ndarray
if not isinstance(ndmin, Integral):
raise TypeError(
f"'{type(ndmin)}' object cannot be interpreted as an integer"
)
if ndmin and self.data.ndim < ndmin:
self.data = self.data[(*(None for _ in range(ndmin - self.data.ndim)),)]
if not NP_IS_V2:
self.data = np.array(x, dtype=dtype, copy=copy, ndmin=ndmin)
else:
self.data = np.array(
x, dtype=dtype, copy=copy, ndmin=ndmin
) # type: np.ndarray
if copy is False:
self.data = np.asarray(x, dtype=dtype)
if not isinstance(ndmin, Integral):
raise TypeError(
f"'{type(ndmin)}' object cannot be interpreted as an integer"
)
if ndmin and self.data.ndim < ndmin:
self.data = self.data[
(*(None for _ in range(ndmin - self.data.ndim)),)
]
else:
self.data = np.array(x, dtype=dtype, copy=copy, ndmin=ndmin)

dtype = self.data.dtype.type
is_float = issubclass(dtype, np.floating) # faster than `numpy.issubdtype`
Expand Down

0 comments on commit 6c8dc95

Please sign in to comment.