diff --git a/graphblas/core/ss/matrix.py b/graphblas/core/ss/matrix.py index 721e66e34..245c6c276 100644 --- a/graphblas/core/ss/matrix.py +++ b/graphblas/core/ss/matrix.py @@ -680,13 +680,18 @@ def iterkeys(self, seek=0): next_func = lib.GxB_Matrix_Iterator_next row_ptr = ffi_new("GrB_Index*") col_ptr = ffi_new("GrB_Index*") - while info == success: - key_func(it, row_ptr, col_ptr) - yield (row_ptr[0], col_ptr[0]) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Matrix iterator failed") + try: + while info == success: + key_func(it, row_ptr, col_ptr) + yield (row_ptr[0], col_ptr[0]) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Matrix iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def itervalues(self, seek=0): """Iterate over all the values of a Matrix. @@ -710,12 +715,17 @@ def itervalues(self, seek=0): info = success = lib.GrB_SUCCESS val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}") next_func = lib.GxB_Matrix_Iterator_next - while info == success: - yield val_func(it) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Matrix iterator failed") + try: + while info == success: + yield val_func(it) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Matrix iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def iteritems(self, seek=0): """Iterate over all the row, column, and value triples of a Matrix. @@ -742,13 +752,18 @@ def iteritems(self, seek=0): next_func = lib.GxB_Matrix_Iterator_next row_ptr = ffi_new("GrB_Index*") col_ptr = ffi_new("GrB_Index*") - while info == success: - key_func(it, row_ptr, col_ptr) - yield (row_ptr[0], col_ptr[0], val_func(it)) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Matrix iterator failed") + try: + while info == success: + key_func(it, row_ptr, col_ptr) + yield (row_ptr[0], col_ptr[0], val_func(it)) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Matrix iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def export(self, format=None, *, sort=False, give_ownership=False, raw=False): """ diff --git a/graphblas/core/ss/vector.py b/graphblas/core/ss/vector.py index c6162d782..79da58c97 100644 --- a/graphblas/core/ss/vector.py +++ b/graphblas/core/ss/vector.py @@ -376,12 +376,17 @@ def iterkeys(self, seek=0): info = success = lib.GrB_SUCCESS key_func = lib.GxB_Vector_Iterator_getIndex next_func = lib.GxB_Vector_Iterator_next - while info == success: - yield key_func(it) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Vector iterator failed") + try: + while info == success: + yield key_func(it) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Vector iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def itervalues(self, seek=0): """Iterate over all the values of a Vector. @@ -405,12 +410,17 @@ def itervalues(self, seek=0): info = success = lib.GrB_SUCCESS val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}") next_func = lib.GxB_Vector_Iterator_next - while info == success: - yield val_func(it) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Vector iterator failed") + try: + while info == success: + yield val_func(it) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Vector iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def iteritems(self, seek=0): """Iterate over all the indices and values of a Vector. @@ -435,12 +445,17 @@ def iteritems(self, seek=0): key_func = lib.GxB_Vector_Iterator_getIndex val_func = getattr(lib, f"GxB_Iterator_get_{self._parent.dtype.name}") next_func = lib.GxB_Vector_Iterator_next - while info == success: - yield (key_func(it), val_func(it)) - info = next_func(it) - lib.GxB_Iterator_free(it_ptr) - if info != lib.GxB_EXHAUSTED: # pragma: no cover - raise _error_code_lookup[info]("Vector iterator failed") + try: + while info == success: + yield (key_func(it), val_func(it)) + info = next_func(it) + except GeneratorExit: + pass + else: + if info != lib.GxB_EXHAUSTED: # pragma: no cover + raise _error_code_lookup[info]("Vector iterator failed") + finally: + lib.GxB_Iterator_free(it_ptr) def export(self, format=None, *, sort=False, give_ownership=False, raw=False): """ diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 940658ebc..da9cbdfc5 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -3415,6 +3415,9 @@ def test_iteration(A): assert len(list(A.ss.iterkeys(N + 2))) == 0 assert len(list(A.ss.iterkeys(-N))) == N assert len(list(A.ss.itervalues(-N - 1))) == N + assert next(A.ss.iterkeys()) in A + assert next(A.ss.itervalues()) is not None + assert next(A.ss.iteritems()) is not None def test_udt(): diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 8e010fe0b..a85347a05 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -2101,6 +2101,9 @@ def test_iteration(v): assert len(list(v.ss.iterkeys(2))) == 2 assert len(list(v.ss.itervalues(N))) == 0 assert len(list(v.ss.iteritems(N + 1))) == 0 + assert next(v.ss.iterkeys()) in v + assert next(v.ss.itervalues()) is not None + assert next(v.ss.iteritems()) is not None def test_broadcasting(A, v):