Skip to content

Commit

Permalink
add quicksort example
Browse files Browse the repository at this point in the history
  • Loading branch information
Akuli committed Dec 2, 2024
1 parent 465af42 commit 1333959
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions examples/quicksort.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Sorting is currently not in standard library, because it's not possible to
# make it work generically without hard-coding the type of the array. So you
# could make a sort_ints() function, but not a sort() function.
#
# Maybe this will some day become a part of the standard library :)

import "stdlib/io.jou"


def swap(a: int*, b: int*) -> None:
temp = *a
*a = *b
*b = temp


def print_array(prefix: byte*, arr: int*, length: int) -> None:
if prefix != NULL:
printf("%s=", prefix)
printf("[", prefix)
for i = 0; i < length; i++:
if i != 0:
printf(",")
printf("%d", arr[i])
printf("]")


def quicksort(arr: int*, length: int, depth: int) -> None:
if length < 2:
return

for i = 0; i < depth; i++:
printf(" ")
printf("sorting ")
print_array(NULL, arr, length)
printf(": ")

pivot = arr[length / 2]

# Delete pivot itself for now. This prevents getting stuck in corner cases
# where pivot is largest or smallest.
#
# For example, consider sorting [3,2,9,1,4]. Instead of this:
#
# [3,2,9,1,4] <=pivot
# [] ==pivot
# [] >=pivot
#
# we get this:
#
# [3,2,1,4] <=pivot
# [] ==pivot
# [] >=pivot
#
# Once we add back the pivot later, this becomes:
#
# [3,2,1,4,9] <=pivot
# [9] ==pivot
# [9] >=pivot
arr[length / 2] = arr[length - 1]
length--

# Rearrange elements so that:
# - arr[0..end_of_small] <= pivot
# - arr[start_of_big..pivot_temp] >= pivot
# - the above two ranges cover the entire array (they may overlap)
#
# This looks like:
#
# end_of_small
# ↓
# [..........] <=pivot
# [...] ==pivot
# [......] >=pivot
# ↑
# start_of_big
#
# We start with empty arrays and expand:
#
# [] <=pivot
# [] >=pivot
end_of_small = 0
start_of_big = length

while True:
# extend ranges without moving elements, if possible
while end_of_small < length and arr[end_of_small] <= pivot:
end_of_small++
while start_of_big > 0 and arr[start_of_big - 1] >= pivot:
start_of_big--

if end_of_small >= start_of_big:
# whole array covered
break

# neither range can expand because of >pivot and <pivot elements
assert arr[end_of_small] > pivot
assert arr[start_of_big - 1] < pivot
swap(&arr[end_of_small], &arr[start_of_big - 1])

# Add back the removed pivot. It becomes a part of the overlap.
arr[length++] = arr[end_of_small]
arr[end_of_small++] = pivot

print_array("smaller", &arr[0], start_of_big)
printf(" ")
print_array("pivot", &arr[start_of_big], end_of_small - start_of_big)
printf(" ")
print_array("bigger", &arr[end_of_small], length - end_of_small)
printf("\n")

# Sort subarrays recursively. The overlapping part is equal to pivot and
# doesn't need sorting. Arrays are now smaller because overlap is no longer
# empty (it now contains the pivot).
quicksort(&arr[0], start_of_big, depth + 1)
quicksort(&arr[end_of_small], length - end_of_small, depth + 1)


def main() -> int:
arr = [16, 19, 9, 1, 19, 10, 8, 1, 0, 14]

# Output: sorting [16,19,9,1,19,10,8,1,0,14]: smaller=[0,1,9,1,8] pivot=[10] bigger=[19,19,16,14]
# Output: sorting [0,1,9,1,8]: smaller=[0,1,8,1] pivot=[9] bigger=[]
# Output: sorting [0,1,8,1]: smaller=[0,1,1] pivot=[8] bigger=[]
# Output: sorting [0,1,1]: smaller=[0] pivot=[1,1] bigger=[]
# Output: sorting [19,19,16,14]: smaller=[14] pivot=[16] bigger=[19,19]
# Output: sorting [19,19]: smaller=[] pivot=[19,19] bigger=[]
quicksort(arr, (sizeof(arr) / sizeof(arr[0])) as int, 0)

# Output: [0,1,1,8,9,10,14,16,19,19]
print_array(NULL, arr, (sizeof(arr) / sizeof(arr[0])) as int)
printf("\n")

return 0

0 comments on commit 1333959

Please sign in to comment.