From 8171afb95703bfd97da7efa701479d57f5dcccf3 Mon Sep 17 00:00:00 2001 From: mehmedGIT <mehmed.n.mustafa@gmail.com> Date: Mon, 4 Dec 2023 16:51:46 +0100 Subject: [PATCH] fix ranges, use numpy --- ocrd_utils/ocrd_utils/str.py | 24 +++++++----------------- tests/test_utils.py | 17 ++++++++++++++--- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/ocrd_utils/ocrd_utils/str.py b/ocrd_utils/ocrd_utils/str.py index 8979e28bb7..f5b9242d35 100644 --- a/ocrd_utils/ocrd_utils/str.py +++ b/ocrd_utils/ocrd_utils/str.py @@ -7,21 +7,7 @@ from .constants import REGEX_FILE_ID from .deprecate import deprecation_warning from warnings import warn -from math import ceil -import sys -from itertools import islice - -if sys.version_info >= (3, 12): - from itertools import batched -else: - def batched(iterable, chunk_size): - iterator = iter(iterable) - chunk = None - while True: - chunk = tuple(islice(iterator, chunk_size)) - if not chunk: - break - yield chunk +from numpy import array_split __all__ = [ 'assert_file_grp_cardinality', @@ -224,6 +210,7 @@ def generate_range(start, end): ret.append(start.replace(start_num, str(i).zfill(len(start_num)))) return ret + def partition_list(lst, chunks, chunk_index=None): """ Partition a list into roughly equally-sized chunks @@ -240,8 +227,11 @@ def partition_list(lst, chunks, chunk_index=None): """ if not lst: return [] - items_per_chunk = ceil(len(lst) / chunks) - ret = list(map(list, batched(lst, items_per_chunk))) + # Catch potential empty ranges returned by numpy.array_split + # which are problematic in the ocr-d scope + if chunks > len(lst): + raise ValueError("Amount of chunks bigger than list size") + ret = [x.tolist() for x in array_split(lst, chunks)] if chunk_index is not None: return [ret[chunk_index]] return ret diff --git a/tests/test_utils.py b/tests/test_utils.py index 3441ec52fb..d2093c465d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -309,11 +309,22 @@ def test_partition_list(): assert partition_list(None, 1) == [] assert partition_list([], 1) == [] assert partition_list(lst_10, 1) == [lst_10] - assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]] - assert partition_list(lst_10, 3, 1) == [[5, 6, 7, 8]] + assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]] + assert partition_list(lst_10, 3, 1) == [[5, 6, 7]] assert partition_list(lst_10, 3, 0) == [[1, 2, 3, 4]] with raises(IndexError): - partition_list(lst_10, 5, 5) + partition_list(lst_10, chunks=4, chunk_index=5) + partition_list(lst_10, chunks=5, chunk_index=5) + partition_list(lst_10, chunks=5, chunk_index=6) + with raises(ValueError): + partition_list(lst_10, chunks=11) + # odd prime number tests + lst_13 = list(range(1, 14)) + assert partition_list(lst_13, chunks=2) == [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13]] + assert partition_list(lst_13, chunks=3) == [[1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]] + assert partition_list(lst_13, chunks=4) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 13]] + assert partition_list(lst_13, chunks=4, chunk_index=1) == [[5, 6, 7]] + if __name__ == '__main__': main(__file__)