Skip to content

Commit

Permalink
perf: Rewrite partitions
Browse files Browse the repository at this point in the history
The former `partitions` function from `piquasso._math.combinatorics` got
reimplemented with a faster algorithm.
  • Loading branch information
Kolarovszki committed Oct 15, 2024
1 parent 3e07b26 commit 78403e8
Showing 1 changed file with 35 additions and 61 deletions.
96 changes: 35 additions & 61 deletions piquasso/_math/combinatorics.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,6 @@ def comb(n, k):
return prod


@nb.njit
def nb_combinations(arr, r):
n = arr.shape[0]
indices = np.arange(r)
result_size = comb(n, r)
result = np.empty((result_size, r), dtype=arr.dtype)

def advance(indices, n, r):
for i in range(r - 1, -1, -1):
if indices[i] != i + n - r:
break
else:
return False

indices[i] += 1
for j in range(i + 1, r):
indices[j] = indices[j - 1] + 1

return True

result[0, :] = arr[indices]
k = 1

while advance(indices, n, r):
result[k, :] = arr[indices]
k += 1

return result


_T = TypeVar("_T")


Expand All @@ -84,37 +54,6 @@ def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]:
return chain.from_iterable(combinations(iterable, r) for r in range(len(s) + 1))


@nb.njit
def partitions(boxes, particles):
size = boxes + particles - 1

if size == -1 or boxes == 0:
return np.empty((1, 0), dtype=np.int32)

index_matrix = nb_combinations(
np.array(list(range(size)), dtype=np.int32), boxes - 1
)
index_matrix = np.flipud(index_matrix)

starts = np.concatenate(
(
np.zeros(shape=(index_matrix.shape[0], 1), dtype=np.int32),
np.add(index_matrix, 1),
),
axis=1,
)

stops = np.concatenate(
(
index_matrix,
np.full(shape=(index_matrix.shape[0], 1), fill_value=size, dtype=np.int32),
),
axis=1,
)

return (stops - starts).astype(np.int32)


def is_even_permutation(permutation):
permutation = list(permutation)
length = len(permutation)
Expand All @@ -129,3 +68,38 @@ def is_even_permutation(permutation):
elements_seen[current] = True
current = permutation[current]
return (length - cycles) % 2 == 0


@nb.njit
def partitions(boxes, particles):
positions = particles + boxes - 1

if positions == -1 or boxes == 0:
return np.empty((1, 0), dtype=np.int32)

size = comb(positions, boxes - 1)
result = np.empty((size, boxes), dtype=np.int32)
separators = np.arange(boxes - 1, dtype=np.int32)
index = size - 1

while True:
prev = -1
for i in range(boxes - 1):
result[index, i] = separators[i] - prev - 1
prev = separators[i]

result[index, boxes - 1] = positions - prev - 1
index -= 1

if index < 0:
break

i = boxes - 2
while separators[i] == positions - (boxes - 1 - i):
i -= 1

separators[i] += 1
for j in range(i + 1, boxes - 1):
separators[j] = separators[j - 1] + 1

return result

0 comments on commit 78403e8

Please sign in to comment.