Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] CSR/ CSC Elemwise #465

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
20 changes: 14 additions & 6 deletions sparse/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
can_store,
check_zero_fill_value,
check_compressed_axes,
_zero_of_dtype,
equivalent,
)
from .._coo.core import COO
Expand Down Expand Up @@ -143,7 +144,7 @@ def __init__(
shape=None,
compressed_axes=None,
prune=False,
fill_value=0,
fill_value=None,
idx_dtype=None,
):
if isinstance(arg, ss.spmatrix):
Expand All @@ -169,6 +170,10 @@ def __init__(
arg.fill_value,
)

self.data, self.indices, self.indptr = arg

if fill_value is None:
fill_value = _zero_of_dtype(self.data.dtype)
if shape is None:
raise ValueError("missing `shape` argument")

Expand All @@ -177,8 +182,6 @@ def __init__(
if len(shape) == 1:
compressed_axes = None

self.data, self.indices, self.indptr = arg

if self.data.ndim != 1:
raise ValueError("data must be a scalar or 1-dimensional.")

Expand Down Expand Up @@ -845,7 +848,12 @@ def _prune(self):

class _Compressed2d(GCXS):
def __init__(
self, arg, shape=None, compressed_axes=None, prune=False, fill_value=0
self,
arg,
shape=None,
compressed_axes=None,
prune=False,
fill_value=None,
):
if not hasattr(arg, "shape") and shape is None:
raise ValueError("missing `shape` argument")
Expand Down Expand Up @@ -888,7 +896,7 @@ class CSR(_Compressed2d):
Sparse supports 2-D CSR.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
def __init__(self, arg, shape=None, prune=False, fill_value=None):
super().__init__(arg, shape=shape, compressed_axes=(0,), fill_value=fill_value)

@classmethod
Expand All @@ -913,7 +921,7 @@ class CSC(_Compressed2d):
Sparse supports 2-D CSC.
"""

def __init__(self, arg, shape=None, prune=False, fill_value=0):
def __init__(self, arg, shape=None, prune=False, fill_value=None):
super().__init__(arg, shape=shape, compressed_axes=(1,), fill_value=fill_value)

@classmethod
Expand Down
155 changes: 155 additions & 0 deletions sparse/_compressed/elemwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from functools import lru_cache
from typing import Callable

import numpy as np
import scipy.sparse
from numba import njit

from .compressed import _Compressed2d


def op_unary(func, a):
res = a.copy()
res.data = func(a.data)
return res

Check warning on line 14 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L12-L14

Added lines #L12 - L14 were not covered by tests


@lru_cache(maxsize=None)
def _numba_d(func):
return njit(lambda *x: func(*x))


def binary_op(func, a, b):
func = _numba_d(func)
if isinstance(a, _Compressed2d) and isinstance(b, _Compressed2d):
return op_union_indices(func, a, b)
else:
raise NotImplementedError()

# From scipy._util
def _prune_array(array):
"""Return an array equivalent to the input array. If the input
array is a view of a much larger array, copy its contents to a
newly allocated array. Otherwise, return the input unchanged.
"""
if array.base is not None and array.size < array.base.size // 2:
return array.copy()
return array



def op_union_indices(
op: Callable, a: scipy.sparse.csr_matrix, b: scipy.sparse.csr_matrix, *, default_value=0
):
assert a.shape == b.shape

if type(a) != type(b):
b = type(a)(b)

Check warning on line 47 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L47

Added line #L47 was not covered by tests
# a.sort_indices()
# b.sort_indices()

# TODO: numpy is weird with bools here
out_dtype = np.array(op(a.data[0], b.data[0])).dtype
default_value = out_dtype.type(default_value)
out_indptr = np.zeros_like(a.indptr)
out_indices = np.zeros(len(a.indices) + len(b.indices), dtype=np.promote_types(a.indices.dtype, b.indices.dtype))
out_data = np.zeros(len(out_indices), dtype=out_dtype)

nnz = op_union_indices_csr_csr(
op,
a.indptr,
a.indices,
a.data,
b.indptr,
b.indices,
b.data,
out_indptr,
out_indices,
out_data,
out_dtype=out_dtype,
default_value=default_value,
)
out_data = _prune_array(out_data[:nnz])
out_indices = _prune_array(out_indices[:nnz])
return type(a)((out_data, out_indices, out_indptr), shape=a.shape)


@njit
def op_union_indices_csr_csr(
op: Callable,
a_indptr: np.ndarray,
a_indices: np.ndarray,
a_data: np.ndarray,
b_indptr: np.ndarray,
b_indices: np.ndarray,
b_data: np.ndarray,
out_indptr: np.ndarray,
out_indices: np.ndarray,
out_data: np.ndarray,
out_dtype,
default_value,
):
# out_indptr = np.zeros_like(a_indptr)
# out_indices = np.zeros(len(a_indices) + len(b_indices), dtype=a_indices.dtype)
# out_data = np.zeros(len(out_indices), dtype=out_dtype)

out_idx = 0

Check warning on line 96 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L96

Added line #L96 was not covered by tests

for i in range(len(a_indptr) - 1):

Check warning on line 98 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L98

Added line #L98 was not covered by tests

a_idx = a_indptr[i]
a_end = a_indptr[i + 1]
b_idx = b_indptr[i]
b_end = b_indptr[i + 1]

Check warning on line 103 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L100-L103

Added lines #L100 - L103 were not covered by tests

while (a_idx < a_end) and (b_idx < b_end):
a_j = a_indices[a_idx]
b_j = b_indices[b_idx]
if a_j < b_j:
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
elif b_j < a_j:
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_j
out_data[out_idx] = val
out_idx += 1
b_idx += 1

Check warning on line 121 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L105-L121

Added lines #L105 - L121 were not covered by tests
else:
val = op(a_data[a_idx], b_data[b_idx])
if val != default_value:
out_indices[out_idx] = a_j
out_data[out_idx] = val
out_idx += 1
a_idx += 1
b_idx += 1

Check warning on line 129 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L123-L129

Added lines #L123 - L129 were not covered by tests

# Catch up the other set
while a_idx < a_end:
val = op(a_data[a_idx], default_value)
if val != default_value:
out_indices[out_idx] = a_indices[a_idx]
out_data[out_idx] = val
out_idx += 1
a_idx += 1

Check warning on line 138 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L132-L138

Added lines #L132 - L138 were not covered by tests

while b_idx < b_end:
val = op(default_value, b_data[b_idx])
if val != default_value:
out_indices[out_idx] = b_indices[b_idx]
out_data[out_idx] = val
out_idx += 1
b_idx += 1

Check warning on line 146 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L140-L146

Added lines #L140 - L146 were not covered by tests

out_indptr[i + 1] = out_idx

Check warning on line 148 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L148

Added line #L148 was not covered by tests

# This may need to change to be "resize" to allow memory reallocation
# resize is currently not implemented in numba
out_indices = out_indices[: out_idx]
out_data = out_data[: out_idx]

Check warning on line 153 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L152-L153

Added lines #L152 - L153 were not covered by tests

return out_idx

Check warning on line 155 in sparse/_compressed/elemwise.py

View check run for this annotation

Codecov / codecov/patch

sparse/_compressed/elemwise.py#L155

Added line #L155 was not covered by tests
110 changes: 92 additions & 18 deletions sparse/_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,48 @@
)


# TODO: Figure out the right way to type this
# TODO: Figure out how to do 1d COO + CSR or CSC
def _resolve_result_type(args: "list[ArrayLike]") -> "Type":
from ._compressed import GCXS, CSR, CSC
from ._coo import COO
from ._dok import DOK
from ._sparse_array import SparseArray
from ._compressed.compressed import _Compressed2d

args = [arg for arg in args if isinstance(arg, SparseArray)]

if all(isinstance(arg, DOK) for arg in args):
out_type = DOK
elif all(isinstance(arg, CSR) for arg in args):
out_type = CSR
elif all(isinstance(arg, CSC) for arg in args):
out_type = CSC
elif all(isinstance(arg, _Compressed2d) for arg in args):
out_type = CSR
elif all(isinstance(arg, GCXS) for arg in args):
out_type = GCXS
else:
out_type = COO
return out_type


def _from_scipy_sparse(a):
from ._compressed import CSR, CSC
from ._coo import COO
from ._dok import DOK

assert isinstance(a, scipy.sparse.spmatrix)
if isinstance(a, scipy.sparse.csr_matrix):
return CSR(a)
elif isinstance(a, scipy.sparse.csc_matrix):
return CSC(a)
elif isinstance(a, scipy.sparse.dok_matrix):
return DOK(a.shape, data=dict(a))
else:
return COO(a)


class _Elemwise:
def __init__(self, func, *args, **kwargs):
"""
Expand All @@ -423,24 +465,26 @@
"""
from ._coo import COO
from ._sparse_array import SparseArray
from ._compressed import GCXS
from ._compressed import GCXS, CSR, CSC
from ._compressed.compressed import _Compressed2d
from ._dok import DOK

processed_args = []
out_type = GCXS

sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]
args = [
arg
if not isinstance(arg, scipy.sparse.spmatrix)
else _from_scipy_sparse(arg)
for arg in args
]

if all(isinstance(arg, DOK) for arg in sparse_args):
out_type = DOK
elif all(isinstance(arg, GCXS) for arg in sparse_args):
out_type = GCXS
else:
out_type = COO
processed_args = []

self.out_type = _resolve_result_type(args)
# Should this happen before dispatch?
# Hmm, this may need major major changes.
# Case to consider: CSR or CSC + 1d COO
for arg in args:
if isinstance(arg, scipy.sparse.spmatrix):
processed_args.append(COO.from_scipy_sparse(arg))
if self.out_type != COO and isinstance(arg, _Compressed2d):
processed_args.append(arg)
elif isscalar(arg) or isinstance(arg, np.ndarray):
# Faster and more reliable to pass ()-shaped ndarrays as scalars.
processed_args.append(np.asarray(arg))
Expand All @@ -454,7 +498,6 @@
self.args = None
return

self.out_type = out_type
self.args = tuple(processed_args)
self.func = func
self.dtype = kwargs.pop("dtype", None)
Expand All @@ -467,14 +510,19 @@

def get_result(self):
from ._coo import COO
from ._sparse_array import SparseArray
from ._compressed.compressed import _Compressed2d

if self.args is None:
return NotImplemented

if self._dense_result:
args = [a.todense() if isinstance(a, COO) else a for a in self.args]
args = [a.todense() if isinstance(a, SparseArray) else a for a in self.args]
return self.func(*args, **self.kwargs)

if issubclass(self.out_type, _Compressed2d):
return self._get_result_compressed_2d()

if any(s == 0 for s in self.shape):
data = np.empty((0,), dtype=self.fill_value.dtype)
coords = np.empty((0, len(self.shape)), dtype=np.intp)
Expand Down Expand Up @@ -521,6 +569,29 @@
fill_value=self.fill_value,
).asformat(self.out_type)

def _get_result_compressed_2d(self):
from ._compressed import elemwise as elemwise2d
from ._compressed.compressed import _Compressed2d

if len(self.args) == 1:
result = elemwise2d.op_unary(self.func, self.args[0])

Check warning on line 577 in sparse/_umath.py

View check run for this annotation

Codecov / codecov/patch

sparse/_umath.py#L577

Added line #L577 was not covered by tests

processed_args = []
for arg in self.args:
if isinstance(arg, self.out_type):
processed_args.append(arg)
elif isinstance(arg, _Compressed2d):
processed_args.append(self.out_type(arg))
elif isinstance(arg, np.ndarray):
processed_args.append(np.broadcast_to(arg, self.shape))
else:
raise NotImplementedError()

if len(processed_args) == 2:
result = elemwise2d.binary_op(self.func, *processed_args)

return result

def _get_fill_value(self):
"""
A function that finds and returns the fill-value.
Expand All @@ -530,10 +601,11 @@
ValueError
If the fill-value is inconsistent.
"""
from ._coo import COO
from ._sparse_array import SparseArray

zero_args = tuple(
arg.fill_value[...] if isinstance(arg, COO) else arg for arg in self.args
arg.fill_value[...] if isinstance(arg, SparseArray) else arg
for arg in self.args
)

# Some elemwise functions require a dtype argument, some abhorr it.
Expand All @@ -550,7 +622,9 @@
fill_value = fill_value_array[(0,) * fill_value_array.ndim]
except IndexError:
zero_args = tuple(
arg.fill_value if isinstance(arg, COO) else _zero_of_dtype(arg.dtype)
arg.fill_value
if isinstance(arg, SparseArray)
else _zero_of_dtype(arg.dtype)
for arg in self.args
)
fill_value = self.func(*zero_args, **self.kwargs)[()]
Expand Down
Loading