From 8171afb95703bfd97da7efa701479d57f5dcccf3 Mon Sep 17 00:00:00 2001
From: mehmedGIT <mehmed.n.mustafa@gmail.com>
Date: Mon, 4 Dec 2023 16:51:46 +0100
Subject: [PATCH] fix ranges, use numpy

---
 ocrd_utils/ocrd_utils/str.py | 24 +++++++-----------------
 tests/test_utils.py          | 17 ++++++++++++++---
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/ocrd_utils/ocrd_utils/str.py b/ocrd_utils/ocrd_utils/str.py
index 8979e28bb7..f5b9242d35 100644
--- a/ocrd_utils/ocrd_utils/str.py
+++ b/ocrd_utils/ocrd_utils/str.py
@@ -7,21 +7,7 @@
 from .constants import REGEX_FILE_ID
 from .deprecate import deprecation_warning
 from warnings import warn
-from math import ceil
-import sys
-from itertools import islice
-
-if sys.version_info >= (3, 12):
-    from itertools import batched
-else:
-    def batched(iterable, chunk_size):
-        iterator = iter(iterable)
-        chunk = None
-        while True:
-            chunk = tuple(islice(iterator, chunk_size))
-            if not chunk:
-                break
-            yield chunk
+from numpy import array_split
 
 __all__ = [
     'assert_file_grp_cardinality',
@@ -224,6 +210,7 @@ def generate_range(start, end):
         ret.append(start.replace(start_num, str(i).zfill(len(start_num))))
     return ret
 
+
 def partition_list(lst, chunks, chunk_index=None):
     """
     Partition a list into roughly equally-sized chunks
@@ -240,8 +227,15 @@ def partition_list(lst, chunks, chunk_index=None):
     """
     if not lst:
         return []
-    items_per_chunk = ceil(len(lst) / chunks)
-    ret = list(map(list, batched(lst, items_per_chunk)))
+    # numpy.array_split returns empty chunks when asked for more chunks
+    #  than there are items; empty ranges are problematic in the OCR-D scope
+    if chunks > len(lst):
+        raise ValueError("Amount of chunks bigger than list size")
+    # Split the index positions rather than the items themselves, so the
+    #  elements never pass through a numpy array (which would coerce mixed
+    #  types to a common dtype and turn tuple items into lists)
+    ret = [list(lst[idx[0]:idx[-1] + 1])
+           for idx in array_split(list(range(len(lst))), chunks)]
     if chunk_index is not None:
         return [ret[chunk_index]]
     return ret
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 3441ec52fb..d2093c465d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -309,11 +309,23 @@ def test_partition_list():
     assert partition_list(None, 1) == []
     assert partition_list([], 1) == []
     assert partition_list(lst_10, 1) == [lst_10]
-    assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
-    assert partition_list(lst_10, 3, 1) == [[5, 6, 7, 8]]
+    assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]
+    assert partition_list(lst_10, 3, 1) == [[5, 6, 7]]
     assert partition_list(lst_10, 3, 0) == [[1, 2, 3, 4]]
-    with raises(IndexError):
-        partition_list(lst_10, 5, 5)
+    # one raises() block per call: the first exception ends the block, so
+    # any further calls placed inside the same block would never be executed
+    for chunks, chunk_index in [(4, 5), (5, 5), (5, 6)]:
+        with raises(IndexError):
+            partition_list(lst_10, chunks=chunks, chunk_index=chunk_index)
+    with raises(ValueError):
+        partition_list(lst_10, chunks=11)
+    # odd prime number tests
+    lst_13 = list(range(1, 14))
+    assert partition_list(lst_13, chunks=2) == [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13]]
+    assert partition_list(lst_13, chunks=3) == [[1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]]
+    assert partition_list(lst_13, chunks=4) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 13]]
+    assert partition_list(lst_13, chunks=4, chunk_index=1) == [[5, 6, 7]]
+
 
 if __name__ == '__main__':
     main(__file__)