diff --git a/piquasso/_math/combinatorics.py b/piquasso/_math/combinatorics.py index cf4e5772..84fe8881 100644 --- a/piquasso/_math/combinatorics.py +++ b/piquasso/_math/combinatorics.py @@ -46,36 +46,6 @@ def comb(n, k): return prod -@nb.njit -def nb_combinations(arr, r): - n = arr.shape[0] - indices = np.arange(r) - result_size = comb(n, r) - result = np.empty((result_size, r), dtype=arr.dtype) - - def advance(indices, n, r): - for i in range(r - 1, -1, -1): - if indices[i] != i + n - r: - break - else: - return False - - indices[i] += 1 - for j in range(i + 1, r): - indices[j] = indices[j - 1] + 1 - - return True - - result[0, :] = arr[indices] - k = 1 - - while advance(indices, n, r): - result[k, :] = arr[indices] - k += 1 - - return result - - _T = TypeVar("_T") @@ -84,37 +54,6 @@ def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: return chain.from_iterable(combinations(iterable, r) for r in range(len(s) + 1)) -@nb.njit -def partitions(boxes, particles): - size = boxes + particles - 1 - - if size == -1 or boxes == 0: - return np.empty((1, 0), dtype=np.int32) - - index_matrix = nb_combinations( - np.array(list(range(size)), dtype=np.int32), boxes - 1 - ) - index_matrix = np.flipud(index_matrix) - - starts = np.concatenate( - ( - np.zeros(shape=(index_matrix.shape[0], 1), dtype=np.int32), - np.add(index_matrix, 1), - ), - axis=1, - ) - - stops = np.concatenate( - ( - index_matrix, - np.full(shape=(index_matrix.shape[0], 1), fill_value=size, dtype=np.int32), - ), - axis=1, - ) - - return (stops - starts).astype(np.int32) - - def is_even_permutation(permutation): permutation = list(permutation) length = len(permutation) @@ -129,3 +68,38 @@ def is_even_permutation(permutation): elements_seen[current] = True current = permutation[current] return (length - cycles) % 2 == 0 + + +@nb.njit +def partitions(boxes, particles): + positions = particles + boxes - 1 + + if positions == -1 or boxes == 0: + return np.empty((1, 0), dtype=np.int32) + + size = comb(positions, boxes - 1) + result = np.empty((size, boxes), dtype=np.int32) + separators = np.arange(boxes - 1, dtype=np.int32) + index = size - 1 + + while True: + prev = -1 + for i in range(boxes - 1): + result[index, i] = separators[i] - prev - 1 + prev = separators[i] + + result[index, boxes - 1] = positions - prev - 1 + index -= 1 + + if index < 0: + break + + i = boxes - 2 + while separators[i] == positions - (boxes - 1 - i): + i -= 1 + + separators[i] += 1 + for j in range(i + 1, boxes - 1): + separators[j] = separators[j - 1] + 1 + + return result