Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
matthiasdiener committed Dec 3, 2024
1 parent 9c56443 commit fd95813
Showing 3 changed files with 29 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -56,6 +56,9 @@ jobs:
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- cupy" >> "$CONDA_ENVIRONMENT"
build_py_project_in_conda_env
python -m pip install mypy pytest
./run-mypy.sh
21 changes: 20 additions & 1 deletion arraycontext/impl/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,9 @@
THE SOFTWARE.
"""

from typing import Any
from typing import Any, overload

import numpy as np

import loopy as lp
from pytools.tag import ToTagSetConvertible
@@ -44,6 +46,7 @@
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)
@@ -83,12 +86,28 @@ def _get_fake_numpy_namespace(self):
def clone(self):
return type(self)()

@overload
def from_numpy(self, array: np.ndarray) -> Array:
...

@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
import cupy as cp
return cp.array(array)

@overload
def to_numpy(self, array: Array) -> np.ndarray:
...

@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
9 changes: 6 additions & 3 deletions arraycontext/impl/cupy/fake_numpy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


__copyright__ = """
Copyright (C) 2024 University of Illinois Board of Trustees
"""
@@ -164,7 +167,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
[(true_ary if kx_i == ky_i else false_ary)
and self.array_equal(x_i, y_i)
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y)],
in zip(serialized_x, serialized_y, strict=True)],
true_ary)

def arange(self, *args, **kwargs):
@@ -176,14 +179,14 @@ def linspace(self, *args, **kwargs):
return cp.linspace(*args, **kwargs)

def zeros_like(self, ary):
if isinstance(ary, (int, float, complex)):
if isinstance(ary, int | float | complex):
import cupy as cp
# Cupy does not support zeros_like with scalar arguments
ary = cp.array(ary)
return rec_map_array_container(cp.zeros_like, ary)

def ones_like(self, ary):
if isinstance(ary, (int, float, complex)):
if isinstance(ary, int | float | complex):
import cupy as cp
# Cupy does not support ones_like with scalar arguments
ary = cp.array(ary)

0 comments on commit fd95813

Please sign in to comment.