Skip to content

Commit

Permalink
fix ranges, use numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
MehmedGIT committed Dec 4, 2023
1 parent 0ea71fd commit 8171afb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
24 changes: 7 additions & 17 deletions ocrd_utils/ocrd_utils/str.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,7 @@
from .constants import REGEX_FILE_ID
from .deprecate import deprecation_warning
from warnings import warn
from math import ceil
import sys
from itertools import islice

if sys.version_info >= (3, 12):
    from itertools import batched
else:
    def batched(iterable, chunk_size):
        """
        Fallback for ``itertools.batched`` on Python versions before 3.12.

        Split ``iterable`` into consecutive tuples of at most ``chunk_size``
        items; only the final tuple may be shorter.

        Arguments:
            iterable: any iterable to consume
            chunk_size (int): maximum number of items per tuple, at least 1

        Returns:
            an iterator over tuples of items from ``iterable``

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1
        """
        # Agree with itertools.batched, which rejects sizes below one; without
        # this check a size of 0 would silently produce no batches at all.
        if chunk_size < 1:
            raise ValueError('n must be at least one')
        iterator = iter(iterable)
        # Two-argument iter(): keep slicing off tuples until the empty tuple
        # signals that the underlying iterator is exhausted.
        return iter(lambda: tuple(islice(iterator, chunk_size)), ())
from numpy import array_split

__all__ = [
'assert_file_grp_cardinality',
Expand Down Expand Up @@ -224,6 +210,7 @@ def generate_range(start, end):
ret.append(start.replace(start_num, str(i).zfill(len(start_num))))
return ret


def partition_list(lst, chunks, chunk_index=None):
"""
Partition a list into roughly equally-sized chunks
Expand All @@ -240,8 +227,11 @@ def partition_list(lst, chunks, chunk_index=None):
"""
if not lst:
return []
items_per_chunk = ceil(len(lst) / chunks)
ret = list(map(list, batched(lst, items_per_chunk)))
# numpy.array_split returns empty chunks when asked for more chunks than
# there are items; empty chunks are problematic in the OCR-D scope
if chunks > len(lst):
raise ValueError("Amount of chunks bigger than list size")
ret = [x.tolist() for x in array_split(lst, chunks)]
if chunk_index is not None:
return [ret[chunk_index]]
return ret
17 changes: 14 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,22 @@ def test_partition_list():
assert partition_list(None, 1) == []
assert partition_list([], 1) == []
assert partition_list(lst_10, 1) == [lst_10]
assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
assert partition_list(lst_10, 3, 1) == [[5, 6, 7, 8]]
assert partition_list(lst_10, 3) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]
assert partition_list(lst_10, 3, 1) == [[5, 6, 7]]
assert partition_list(lst_10, 3, 0) == [[1, 2, 3, 4]]
with raises(IndexError):
partition_list(lst_10, 5, 5)
partition_list(lst_10, chunks=4, chunk_index=5)
partition_list(lst_10, chunks=5, chunk_index=5)
partition_list(lst_10, chunks=5, chunk_index=6)
with raises(ValueError):
partition_list(lst_10, chunks=11)
# odd prime number tests
lst_13 = list(range(1, 14))
assert partition_list(lst_13, chunks=2) == [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13]]
assert partition_list(lst_13, chunks=3) == [[1, 2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13]]
assert partition_list(lst_13, chunks=4) == [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 13]]
assert partition_list(lst_13, chunks=4, chunk_index=1) == [[5, 6, 7]]


if __name__ == '__main__':
main(__file__)

0 comments on commit 8171afb

Please sign in to comment.