Quick Sort

Introduction

Quick sort is one of the most popular sorting algorithms. In interviews, it’s common to see either quick sort itself or other algorithms built on the same principle (e.g. partition).

Algorithm

The idea behind quick sort is divide and conquer. First we make the whole array “ordered” at a high level, then we make each of the subarrays that constitute the whole array “ordered” in the same way. Each subarray in turn contains smaller sub-subarrays. We keep going until the whole array is strictly sorted.

Here is how the algorithm works in detail. First we pick a pivot element, then we partition the array into 2 sections. In the first section, all elements are smaller than or equal to the pivot. In the second section, all elements are greater than or equal to the pivot.

[x, x, x, x, x, x, x | y, y, y, y, y, y, y]
        <= pivot             >= pivot

So “globally” speaking, the array is roughly sorted, because all elements on the left-hand side are <= pivot, and all elements on the right-hand side are >= pivot. Then we recurse into the left part and the right part to partition them further.
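The partition step described above can be sketched on its own. This is only an illustration: the function name partition, its return values, and the choice of the middle element as pivot are assumptions made for this sketch.

```python
def partition(nums, start, end):
    """Rearrange nums[start..end] in place around a pivot value.

    Returns (right, left, pivot). Afterwards every element in
    nums[start..right] is <= pivot and every element in
    nums[left..end] is >= pivot, with right < left.
    """
    left, right = start, end
    pivot = nums[(start + end) // 2]  # assumption: middle element as pivot
    while left <= right:
        # Advance past elements that already sit on the correct side.
        while left <= right and nums[left] < pivot:
            left += 1
        while left <= right and nums[right] > pivot:
            right -= 1
        if left <= right:
            # Both pointers stopped on out-of-place elements: swap them.
            nums[left], nums[right] = nums[right], nums[left]
            left += 1
            right -= 1
    return right, left, pivot


nums = [3, 1, 2, 5, 4]
right, left, pivot = partition(nums, 0, len(nums) - 1)
print(nums[:right + 1], "<=", pivot, "<=", nums[left:])
```

The two printed slices are the “<= pivot” and “>= pivot” sections from the diagram. Neither slice is sorted internally yet, which is why we still need to recurse into each of them.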

The question is: why don’t we make the left part strictly less than the pivot, or the right part strictly greater than the pivot? Let’s look at an edge case:

[1, 1, 1, 1, 1, 1, 2], pivot = 1
Let's say we want only elements less than 1 on the left side. After partitioning, we get:
[| 1, 1, 1, 1, 1, 1, 2]

As we can see, all elements stay in the right part. We’re not reducing the problem’s input size, which is a nightmare for recursion: the same call repeats until we get a stack overflow (a RecursionError in Python). We want to distribute elements as evenly as possible between the left and right parts, so an element equal to the pivot is allowed to go to either the left part or the right part. Here’s my solution to LeetCode question 912, implemented with quick sort:

from typing import List


class Solution:
    def sortArray(self, nums: List[int]) -> List[int]:
        if not nums:
            return nums
        self.quick_sort(nums, 0, len(nums) - 1)
        return nums

    def quick_sort(self, nums, start, end):
        # A range with 0 or 1 elements is already sorted.
        if start >= end:
            return
        left, right = start, end
        # Take the pivot's value, not its index.
        pivot = nums[(left + right) // 2]
        while left <= right:
            # Skip elements that are already on the correct side.
            while left <= right and nums[left] < pivot:
                left += 1
            while left <= right and nums[right] > pivot:
                right -= 1
            if left <= right:
                # Swap the out-of-place pair and move both pointers inward.
                nums[left], nums[right] = nums[right], nums[left]
                left += 1
                right -= 1

        # Here right < left: sort [start, right] and [left, end].
        self.quick_sort(nums, start, right)
        self.quick_sort(nums, left, end)
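
One way to gain some confidence in a sort like this is to compare it against Python’s built-in sorted() on many small random arrays. The helper below is a rough sketch: the name check_sort and its parameters are made up for illustration, and it accepts any function that takes a list and returns the sorted list.

```python
import random


def check_sort(sort_fn, trials=200, max_len=30, seed=0):
    """Return True if sort_fn agrees with sorted() on random small arrays."""
    rng = random.Random(seed)
    for _ in range(trials):
        # A narrow value range forces many duplicates, which is the
        # case where a partition is most likely to go wrong.
        nums = [rng.randint(0, 5) for _ in range(rng.randint(0, max_len))]
        expected = sorted(nums)
        # Pass a copy so an in-place sort cannot change `expected`.
        if list(sort_fn(list(nums))) != expected:
            return False
    return True
```

For the class above, the call would be check_sort(Solution().sortArray). A passing run does not prove the code correct, but a failing run points at a bug right away.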

You may ask why I choose nums[(left + right) // 2] as the pivot. There are several other options: nums[left], nums[right], or an element at a random index. If we choose nums[left] and the array is already in ascending order, we hit the worst case: every partition splits the array into 2 parts, one with size 1 and the other with size n - 1, so the time complexity becomes O(n^2). The same goes for nums[right]. A random index avoids this, but we pay for generating a random number on every call, so nums[(left + right) // 2] seems to be the best choice here. Keep in mind that the middle element only avoids the sorted-input worst case; specially constructed inputs can still force O(n^2).
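
To see the effect of the pivot choice, we can count how deep the recursion goes on an already sorted array. This is just a sketch for experimenting: max_depth and pick_pivot_index are hypothetical names, and the partition loop is the same one used in the solution above.

```python
def max_depth(nums, pick_pivot_index):
    """Quick sort a copy of nums and return the deepest recursion level.

    pick_pivot_index(start, end) returns the index of the pivot to use
    for the range [start, end].
    """
    nums = list(nums)
    deepest = 0

    def quick_sort(start, end, depth):
        nonlocal deepest
        if start >= end:
            return
        deepest = max(deepest, depth)
        left, right = start, end
        pivot = nums[pick_pivot_index(start, end)]
        while left <= right:
            while left <= right and nums[left] < pivot:
                left += 1
            while left <= right and nums[right] > pivot:
                right -= 1
            if left <= right:
                nums[left], nums[right] = nums[right], nums[left]
                left += 1
                right -= 1
        quick_sort(start, right, depth + 1)
        quick_sort(left, end, depth + 1)

    quick_sort(0, len(nums) - 1, 1)
    return deepest


ascending = list(range(200))
first = max_depth(ascending, lambda start, end: start)
middle = max_depth(ascending, lambda start, end: (start + end) // 2)
print("pivot = nums[left]:", first)
print("pivot = nums[(left + right) // 2]:", middle)
```

On ascending input, the depth with nums[left] grows roughly linearly with n, while the depth with the middle element stays close to log2(n). The input is kept at 200 elements so the linear case stays under Python’s default recursion limit.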

The next question you may have is why we use left <= right instead of left < right. Let’s take a look at this example:

[1, 2] pivot = 1
 ^  ^
 L  R

During partitioning, we first move the L pointer right to find an element that is >= 1, so L stays at its original position. Then we move the R pointer left to find an element that is <= 1, so R ends up on 1 as well:

[1,    2]
 ^^
 LR

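With left < right, the loop stops here with L = R = 0. Then we call self.quick_sort(nums, start, right), which is (0, 0), and then self.quick_sort(nums, left, end), which is (0, 1). The second call covers the same range as the original call, so the problem’s size never shrinks and we end up with a stack overflow (a RecursionError in Python).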
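
The failure is easy to reproduce. Below is a deliberately broken variant that uses left < right everywhere; the names quick_sort_strict and recursion_blows_up are made up for this demonstration.

```python
def quick_sort_strict(nums, start, end):
    """Broken on purpose: uses `<` where the real code uses `<=`."""
    if start >= end:
        return
    left, right = start, end
    pivot = nums[(left + right) // 2]
    while left < right:
        while left < right and nums[left] < pivot:
            left += 1
        while left < right and nums[right] > pivot:
            right -= 1
        if left < right:
            nums[left], nums[right] = nums[right], nums[left]
            left += 1
            right -= 1
    # For [1, 2] we arrive here with left == right == 0, so the second
    # call receives the same range (0, 1) as the current one.
    quick_sort_strict(nums, start, right)
    quick_sort_strict(nums, left, end)


def recursion_blows_up(nums):
    """Return True if sorting a copy of nums hits Python's recursion limit."""
    try:
        quick_sort_strict(list(nums), 0, len(nums) - 1)
    except RecursionError:
        return True
    return False


print(recursion_blows_up([1, 2]))
```

With left <= right, the pointers are forced to cross, so right ends up below left and both recursive calls get strictly smaller ranges.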

Summary

Several key points:

  • When picking the pivot, we need the value, not the index: pivot = nums[(left + right) // 2].
  • Use while left <= right, not while left < right.
  • Compare strictly: nums[left] < pivot and nums[right] > pivot, so elements equal to the pivot can end up on either side.
