Skip to content

Commit

Permalink
Add GxB sort methods (python-graphblas#333)
Browse files Browse the repository at this point in the history
* Add GxB sort methods

* Add docs and tests; also, change to `A.ss.sort(order=order)`

* Better

* order is not keyword only

* Return values then permutations
  • Loading branch information
eriknw authored Nov 30, 2022
1 parent 68e8006 commit f43786f
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 18 deletions.
3 changes: 1 addition & 2 deletions graphblas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from .. import replace as replace_singleton
from ..dtypes import BOOL
from ..exceptions import check_status
from . import NULL, ffi
from . import NULL
from .descriptor import lookup as descriptor_lookup
from .expr import AmbiguousAssignOrExtract, Updater
from .mask import Mask
from .operator import UNKNOWN_OPCLASS, binary_from_string, find_opclass, get_typed_op
from .utils import _Pointer, libget, output_type

CData = ffi.CData
_recorder = ContextVar("recorder")
_prev_recorder = None

Expand Down
8 changes: 6 additions & 2 deletions graphblas/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def lookup(
mask_structure=False,
transpose_first=False,
transpose_second=False,
create=False,
):
key = (
output_replace,
Expand All @@ -80,7 +81,7 @@ def lookup(
transpose_first,
transpose_second,
)
if key not in _desc_map: # pragma: no cover (unnecessary)
if create or key not in _desc_map:
# Unless `create=True` is requested, we currently don't need this block of code!
# All 32 possible descriptors are currently already added to _desc_map.
# Nevertheless, this code may be useful some day, because we will want
Expand All @@ -102,5 +103,8 @@ def lookup(
check_status_carg(
lib.GrB_Descriptor_set(desc[0], field, val), "Descriptor", desc[0]
)
_desc_map[key] = Descriptor(desc[0], "custom_descriptor", *key)
rv = Descriptor(desc[0], "custom_descriptor", *key)
if not create: # pragma: no cover (unnecessary)
_desc_map[key] = rv
return rv
return _desc_map[key]
4 changes: 2 additions & 2 deletions graphblas/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,15 +1454,15 @@ def _initialize(cls):
op = getattr(indexunary, name)
typed_op = op._typed_ops[BOOL]
output_type = op.types[BOOL]
if UINT64 not in op.types:
if UINT64 not in op.types: # pragma: no branch (safety)
op.types[UINT64] = output_type
op._typed_ops[UINT64] = typed_op
op.coercions[UINT64] = BOOL
for name in ("rowindex", "colindex"):
op = getattr(indexunary, name)
typed_op = op._typed_ops[INT64]
output_type = op.types[INT64]
if UINT64 not in op.types:
if UINT64 not in op.types: # pragma: no branch (safety)
op.types[UINT64] = output_type
op._typed_ops[UINT64] = typed_op
op.coercions[UINT64] = INT64
Expand Down
20 changes: 12 additions & 8 deletions graphblas/core/ss/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,7 @@
}


def get_nthreads_descriptor(nthreads, _cache=True):
nthreads = max(0, int(nthreads))
key = ("nthreads", nthreads)
if _cache and key in _desc_map:
return _desc_map[key]
desc_obj = ffi.new("GrB_Descriptor*")
lib.GrB_Descriptor_new(desc_obj)
desc = Descriptor(desc_obj[0], f"nthreads{nthreads}")
def set_nthreads(desc, nthreads):
check_status(
lib.GxB_Desc_set(
desc._carg,
Expand All @@ -29,6 +22,17 @@ def get_nthreads_descriptor(nthreads, _cache=True):
),
desc,
)


def get_nthreads_descriptor(nthreads, _cache=True):
    """Return a descriptor configured with the requested number of threads.

    Parameters
    ----------
    nthreads : int-like
        Requested thread count; it is converted with ``int`` and negative
        values are clamped to 0.
    _cache : bool, default=True
        If True, reuse a descriptor previously stored in ``_desc_map`` for the
        same thread count, and store a newly created one there.
        If False, always build a new descriptor and do not store it.

    Returns
    -------
    Descriptor
        Descriptor named ``nthreads{N}`` on which ``set_nthreads`` was applied.
    """
    thread_count = max(int(nthreads), 0)
    cache_key = ("nthreads", thread_count)
    if _cache and cache_key in _desc_map:
        return _desc_map[cache_key]
    # Allocate a new GrB_Descriptor handle and wrap it before configuring it
    handle = ffi.new("GrB_Descriptor*")
    lib.GrB_Descriptor_new(handle)
    rv = Descriptor(handle[0], f"nthreads{thread_count}")
    set_nthreads(rv, thread_count)
    if _cache:
        _desc_map[cache_key] = rv
    return rv
Expand Down
90 changes: 87 additions & 3 deletions graphblas/core/ss/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import graphblas as gb

from ... import monoid
from ...dtypes import _INDEX, BOOL, INT64, _string_to_dtype, lookup_dtype
from ... import binary, monoid
from ...dtypes import _INDEX, BOOL, INT64, UINT64, _string_to_dtype, lookup_dtype
from ...exceptions import _error_code_lookup, check_status, check_status_carg
from .. import NULL, ffi, lib
from ..base import call
from ..descriptor import lookup as descriptor_lookup
from ..operator import get_typed_op
from ..scalar import Scalar, _as_scalar, _scalar_index
from ..utils import (
_CArray,
Expand All @@ -29,7 +31,7 @@
wrapdoc,
)
from .config import BaseConfig
from .descriptor import get_compression_descriptor, get_nthreads_descriptor
from .descriptor import get_compression_descriptor, get_nthreads_descriptor, set_nthreads

ffi_new = ffi.new

Expand Down Expand Up @@ -3729,6 +3731,9 @@ def compactify_rowwise(
**THIS API IS EXPERIMENTAL AND MAY CHANGE**
See Also
--------
Matrix.ss.sort
"""
return self._compactify(
how, reverse, asindex, "ncols", ncols, "hypercsr", "col_indices", name
Expand Down Expand Up @@ -3763,6 +3768,9 @@ def compactify_columnwise(
**THIS API IS EXPERIMENTAL AND MAY CHANGE**
See Also
--------
Matrix.ss.sort
"""
return self._compactify(
how, reverse, asindex, "nrows", nrows, "hypercsc", "row_indices", name
Expand Down Expand Up @@ -3836,6 +3844,82 @@ def _compactify(self, how, reverse, asindex, nkey, nval, fmt, indices_name, name
name=name,
)

def sort(self, op=binary.lt, order="rowwise", *, values=True, permutation=True, nthreads=None):
    """Sort the values within each row (default) or column via GxB_Matrix_sort.

    Sorted elements are packed to the left (if rowwise) or to the top (if
    columnwise), just like `compactify`. The returned matrices have the same
    shape as the input Matrix.

    Parameters
    ----------
    op : :class:`~graphblas.core.operator.BinaryOp`, optional
        Binary operator with a bool return type used to sort the values.
        For example, `binary.lt` (the default) sorts the smallest elements first.
        Ties are broken according to indices (smaller first).
    order : {"rowwise", "columnwise"}, optional
        Whether to sort rowwise or columnwise. Rowwise shifts all values to the
        left, and columnwise shifts all values to the top.
        The default is "rowwise".
    values : bool, default=True
        Whether to return values; will return `None` for values if `False`.
    permutation : bool, default=True
        Whether to compute the permutation Matrix that has the original column
        indices (if rowwise) or row indices (if columnwise) of the sorted values.
        Will return None if `False`.
    nthreads : int, optional
        The maximum number of threads to use for this operation.
        None, 0 or negative nthreads means to use the default number of threads.

    Returns
    -------
    Matrix : Values
    Matrix[dtype=UINT64] : Permutation

    See Also
    --------
    Matrix.ss.compactify
    """
    from ..matrix import Matrix

    order = get_order(order)
    parent = self._parent
    op = get_typed_op(op, parent.dtype, kind="binary")
    if op.opclass == "Monoid":
        # A monoid is accepted; sort with its underlying binary operator
        op = op.binaryop
    else:
        parent._expect_op(op, "BinaryOp", within="sort", argname="op")
    if not (values or permutation):
        # Nothing was requested, so skip the call into the library entirely
        return None, None

    nrows, ncols = parent._nrows, parent._ncols
    C = Matrix(parent.dtype, nrows, ncols, name="Values") if values else None
    P = Matrix(UINT64, nrows, ncols, name="Permutation") if permutation else None

    # TODO: clean this up once we expose backend descriptors
    rowwise = order == "rowwise"
    if nthreads is None:
        desc = None if rowwise else descriptor_lookup(transpose_first=True)
    elif rowwise:
        desc = get_nthreads_descriptor(nthreads)
    else:
        # Request a newly created transpose descriptor (`create=True`) so that
        # setting nthreads on it does not modify the descriptor held in the cache
        desc = descriptor_lookup(transpose_first=True, create=True)
        set_nthreads(desc, nthreads)

    # Outputs that were not requested (and a missing descriptor) are passed as NULL
    values_arg = NULL if C is None else C._carg
    permutation_arg = NULL if P is None else P._carg
    desc_arg = NULL if desc is None else desc._carg
    status = lib.GxB_Matrix_sort(values_arg, permutation_arg, op._carg, parent._carg, desc_arg)
    check_status(status, parent)
    return C, P

def serialize(self, compression="default", level=None, *, nthreads=None):
"""Serialize a Matrix to bytes (as numpy array) using SuiteSparse GxB_Matrix_serialize.
Expand Down
67 changes: 66 additions & 1 deletion graphblas/core/ss/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

import graphblas as gb

from ... import monoid
from ... import binary, monoid
from ...dtypes import _INDEX, INT64, UINT64, _string_to_dtype, lookup_dtype
from ...exceptions import _error_code_lookup, check_status, check_status_carg
from .. import NULL, ffi, lib
from ..base import call
from ..operator import get_typed_op
from ..scalar import Scalar, _as_scalar
from ..utils import (
_CArray,
Expand Down Expand Up @@ -1490,6 +1491,70 @@ def compactify(self, how="first", size=None, *, reverse=False, asindex=False, na
name=name,
)

def sort(self, op=binary.lt, *, values=True, permutation=True, nthreads=None):
    """Sort the values of the Vector via GxB_Vector_sort.

    Sorted elements are packed to the left just like `compactify`.
    The returned vectors have the same size as the input Vector.

    Parameters
    ----------
    op : :class:`~graphblas.core.operator.BinaryOp`, optional
        Binary operator with a bool return type used to sort the values.
        For example, `binary.lt` (the default) sorts the smallest elements first.
        Ties are broken according to indices (smaller first).
    values : bool, default=True
        Whether to return values; will return `None` for values if `False`.
    permutation : bool, default=True
        Whether to compute the permutation Vector that has the original indices
        of the sorted values. Will return None if `False`.
    nthreads : int, optional
        The maximum number of threads to use for this operation.
        None, 0 or negative nthreads means to use the default number of threads.

    Returns
    -------
    Vector : values
    Vector[dtype=UINT64] : permutation

    See Also
    --------
    Vector.ss.compactify
    """
    from ..vector import Vector

    parent = self._parent
    op = get_typed_op(op, parent.dtype, kind="binary")
    if op.opclass == "Monoid":
        # A monoid is accepted; sort with its underlying binary operator
        op = op.binaryop
    else:
        parent._expect_op(op, "BinaryOp", within="sort", argname="op")
    if not (values or permutation):
        # Nothing was requested, so skip the call into the library entirely
        return None, None

    size = parent._size
    w = Vector(parent.dtype, size, name="values") if values else None
    p = Vector(UINT64, size, name="permutation") if permutation else None
    desc = None if nthreads is None else get_nthreads_descriptor(nthreads)

    # Outputs that were not requested (and a missing descriptor) are passed as NULL
    values_arg = NULL if w is None else w._carg
    permutation_arg = NULL if p is None else p._carg
    desc_arg = NULL if desc is None else desc._carg
    status = lib.GxB_Vector_sort(values_arg, permutation_arg, op._carg, parent._carg, desc_arg)
    check_status(status, parent)
    return w, p

def serialize(self, compression="default", level=None, *, nthreads=None):
"""Serialize a Vector to bytes (as numpy array) using SuiteSparse GxB_Vector_serialize.
Expand Down
1 change: 1 addition & 0 deletions graphblas/monoid/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@

# To increase import speed, only call njit when `_config.get("mapnumpy")` is False
if _config.get("mapnumpy") or type(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2)) is not float:
# Incorrect behavior was introduced in numba 0.56.2 and numpy 1.23
# See: https://github.com/numba/numba/issues/8478
_monoid_identities["fmax"].update(
{
Expand Down
Loading

0 comments on commit f43786f

Please sign in to comment.