Quick Select

Introduction

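Quick select is an algorithm for finding the Kth largest (or smallest) element in an unsorted array without sorting the whole array. It is closely related to quick sort: it uses the same partition step, but after each partition it continues into only one of the two parts instead of both.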
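To pin down what "Kth largest" means here (K starts at 1, and duplicates are counted individually), the sketch below computes the answer by brute force, sorting the array in O(n log n) time. That sorting cost is exactly what quick select tries to avoid. The function name kth_largest_by_sorting is hypothetical and only serves as a reference for checking answers.

```python
from typing import List


# Hypothetical helper name, used here only as a reference answer.
def kth_largest_by_sorting(nums: List[int], k: int) -> int:
    """Sort in descending order and take position k - 1.

    k is 1-based, so k = 1 is the maximum. Duplicates count individually,
    e.g. in [3, 3, 1] both the 1st and the 2nd largest are 3.
    """
    if not 1 <= k <= len(nums):
        raise ValueError("k must be between 1 and len(nums)")
    return sorted(nums, reverse=True)[k - 1]


print(kth_largest_by_sorting([3, 2, 1, 5, 6, 4], 2))  # 5
```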

Algorithm

Let’s say we want to find the Kth largest element in an array. In each round, we select a pivot element and divide the array into two parts: the first part holds elements >= pivot and the second part holds elements <= pivot. Let’s say the first part has n elements. If K <= n, we recurse into the first part to find the Kth largest element, since the answer cannot be in the second part. If K > n, we recurse into the second part to find its (K - n)th largest element.
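The recursion just described can be sketched without in-place swapping by building new lists. This is a simplified variant, not the in-place version this article implements: it splits into three groups (greater than, equal to, and less than the pivot) so that copies of the pivot can never stop a part from shrinking, and it uses O(n) extra memory per level. The function name quick_select_simple and the middle-element pivot choice are assumptions for illustration.

```python
from typing import List


# Hypothetical helper name; assumes 1 <= k <= len(nums).
def quick_select_simple(nums: List[int], k: int) -> int:
    """Return the kth largest element (k is 1-based) by recursing into one part."""
    pivot = nums[len(nums) // 2]
    greater = [x for x in nums if x > pivot]  # first part: strictly larger
    equal = [x for x in nums if x == pivot]   # copies of the pivot value
    less = [x for x in nums if x < pivot]     # second part: strictly smaller

    n = len(greater)
    if k <= n:
        # The answer is among the n largest elements.
        return quick_select_simple(greater, k)
    if k <= n + len(equal):
        # Ranks n + 1 .. n + len(equal) are all taken by the pivot value.
        return pivot
    # Skip everything >= pivot and look for the remaining rank in the smaller part.
    return quick_select_simple(less, k - n - len(equal))


print(quick_select_simple([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
```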

from typing import List


class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        if not nums:
            return -1
        return self.quick_select(nums, 0, len(nums) - 1, k)

    def quick_select(self, nums: List[int], start_index: int, end_index: int, k: int) -> int:
        # Only one element left in range: it must be the answer.
        if start_index == end_index:
            return nums[start_index]

        pivot = nums[(start_index + end_index) // 2]

        # Partition in descending order: after the loop, nums[start_index..right]
        # are >= pivot and nums[left..end_index] are <= pivot.
        left, right = start_index, end_index
        while left <= right:
            while left <= right and nums[left] > pivot:
                left += 1
            while left <= right and nums[right] < pivot:
                right -= 1
            if left <= right:
                nums[left], nums[right] = nums[right], nums[left]
                left += 1
                right -= 1

        # The kth largest element belongs at index start_index + k - 1.
        if start_index + k - 1 <= right:
            return self.quick_select(nums, start_index, right, k)

        if start_index + k - 1 >= left:
            return self.quick_select(nums, left, end_index, k - (left - start_index))

        # Otherwise the target is the single pivot-valued element between right and left.
        return nums[right + 1]

Let’s take a closer look at the end of quick_select, right after the partition loop. What do the two if conditions mean, and why do we return nums[right + 1] when neither of them holds?

Let’s look at the following example:

[x, x, x, z, y, y, y], pivot = z
          ^^
          LR

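Here each x is greater than the pivot z and each y is smaller than it, so L skips past the x's and R skips past the y's until both stop at z. Swapping z with itself and then moving L one step right and R one step left gives: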

[x, x, x, z, y, y, y]
 ^     ^     ^     ^
start  R     L     end

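We have now partitioned the array into three subarrays: nums[start..R], whose elements are >= pivot; the single element nums[R + 1] between R and L; and nums[L..end], whose elements are <= pivot. (Depending on the input, there may be only two subarrays, with L = R + 1 and nothing in between.) Since k is 1-based, the Kth largest element belongs at index start + k - 1. If that index is <= R, we search the left partition with the same k. If it is >= L, we search the right partition; the L - start elements in front of it are each at least as large as anything in it, so we look for the (k - (L - start))th largest element there. If neither condition holds, the index must be R + 1, the single element between R and L, which equals the pivot, so we simply return nums[R + 1].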
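The partition loop and the three-way decision at the end of quick_select can be pulled apart to check the example above. The sketch below is a hypothetical refactoring, not the original code: partition_desc runs the same two-pointer loop and returns the final (right, left) pair, and find_kth_largest_iterative replaces the recursion with a loop, which works because each step continues into only one side. Both names are made up for illustration; the sketch assumes 1 <= k <= len(nums) and reorders nums in place.

```python
from typing import List, Tuple


# Hypothetical helper: same two-pointer partition as quick_select.
def partition_desc(nums: List[int], start: int, end: int) -> Tuple[int, int]:
    """Partition nums[start..end] around its middle element, larger values first.

    Afterwards nums[start..right] >= pivot and nums[left..end] <= pivot;
    if left == right + 2, nums[right + 1] equals the pivot.
    """
    pivot = nums[(start + end) // 2]
    left, right = start, end
    while left <= right:
        while left <= right and nums[left] > pivot:
            left += 1
        while left <= right and nums[right] < pivot:
            right -= 1
        if left <= right:
            nums[left], nums[right] = nums[right], nums[left]
            left += 1
            right -= 1
    return right, left


# Hypothetical loop version of quick_select; k is 1-based and must be in range.
def find_kth_largest_iterative(nums: List[int], k: int) -> int:
    start, end = 0, len(nums) - 1
    while start < end:
        right, left = partition_desc(nums, start, end)
        target = start + k - 1  # index the kth largest would occupy
        if target <= right:
            end = right  # answer is in the left (larger) part
        elif target >= left:
            k -= left - start  # skip the elements in front of the right part
            start = left
        else:
            return nums[right + 1]  # the single pivot-valued element in between
    return nums[start]


print(find_kth_largest_iterative([3, 2, 1, 5, 6, 4], 2))  # 5
```

On an input shaped like the diagram, such as [9, 8, 7, 4, 3, 2, 1], partition_desc(nums, 0, 6) returns (2, 4): R ends at index 2, L at index 4, and the pivot 4 sits alone at index 3.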

Time Complexity

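On average, each partition pass takes O(n) time and shrinks the range that still has to be searched from n to about n/2. The following passes therefore cost about n/2, n/4, and so on, and this geometric series is bounded by 2n, so the average time complexity is O(n). In the worst case, for example when the pivot is always the largest or smallest remaining element, each pass removes only one element and the time complexity degrades to O(n^2).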
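Written as recurrences, with c standing for the per-element cost of a partition scan and the simplifying assumption that an average pass halves the range exactly (in practice it does so only roughly):

```latex
\begin{aligned}
T(n) &= cn + T\left(\tfrac{n}{2}\right) = cn\left(1 + \tfrac{1}{2} + \tfrac{1}{4} + \cdots\right) \le 2cn = O(n) \\
T_{\text{worst}}(n) &= cn + T_{\text{worst}}(n-1) = c\left(n + (n-1) + \cdots + 1\right) = \tfrac{c\,n(n+1)}{2} = O(n^2)
\end{aligned}
```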
