Leetcode 4: Median of Two Sorted Arrays

The original question can be found here.

The question asks us to return the median of 2 sorted arrays. Let’s say array nums1’s length is m and nums2’s length is n. Counting elements from 1 and using integer division, if m + n is odd, we return the ((m + n) // 2 + 1)th element of the merged array. If m + n is even, we return the average of the ((m + n) // 2)th and ((m + n) // 2 + 1)th elements.
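To make the index arithmetic concrete, here is a brute-force reference sketch. The name `median_by_sorting` is hypothetical and not part of the original solutions. It simply sorts the combined array and applies the rule above, so it costs O((m + n) log(m + n)) time and is only useful as a correctness check.

```python
from typing import List


def median_by_sorting(nums1: List[int], nums2: List[int]) -> float:
    """Brute-force median: sort the combined array and apply the 1-indexed rule."""
    merged = sorted(nums1 + nums2)
    total = len(merged)
    half = total // 2
    if total % 2 == 1:
        # The ((m + n) // 2 + 1)th element (1-indexed) sits at 0-based index half.
        return float(merged[half])
    # Average of the ((m + n) // 2)th and ((m + n) // 2 + 1)th elements.
    return (merged[half - 1] + merged[half]) / 2


print(median_by_sorting([1, 3], [2]))     # 2.0
print(median_by_sorting([1, 2], [3, 4]))  # 2.5
```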

Merge

We can create a new array of length m + n and run the merge step of merge sort until we reach the median element(s). But we don’t actually need to create the array: we simply keep track of the current index in each array.

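from typing import List


class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        nums1_index, nums2_index = 0, 0
        # left and right hold the last two elements produced by the merge
        left, right = 0, 0
        # Merge only up to the ((m + n) // 2 + 1)th element (1-indexed)
        for _ in range((len(nums1) + len(nums2)) // 2 + 1):
            left = right
            if nums1_index >= len(nums1):
                # nums1 is exhausted: take the next element from nums2
                right = nums2[nums2_index]
                nums2_index += 1
            elif nums2_index < len(nums2) and nums2[nums2_index] <= nums1[nums1_index]:
                # nums2's current element is the smaller one
                right = nums2[nums2_index]
                nums2_index += 1
            else:
                # nums2 is exhausted, or nums1's current element is smaller
                right = nums1[nums1_index]
                nums1_index += 1

        if (len(nums1) + len(nums2)) % 2 == 1:
            return right
        else:
            return (left + right) / 2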

Here we keep track of the current indexes into nums1 and nums2. We use 2 variables, left and right, to hold the last two merged elements, so that when m + n is even we can return their average. Time complexity: O(m + n). Space complexity: O(1).
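Python’s standard library can do the same pointer bookkeeping: `heapq.merge` lazily merges already-sorted iterables, and `itertools.islice` lets us stop after (m + n) // 2 + 1 elements, just like the loop above. This is a sketch assuming a standard-library variant is acceptable; the name `median_with_heapq_merge` is hypothetical. It keeps O(m + n) time, and with only two inputs `heapq.merge` holds a constant amount of extra state.

```python
import heapq
from itertools import islice
from typing import List


def median_with_heapq_merge(nums1: List[int], nums2: List[int]) -> float:
    total = len(nums1) + len(nums2)
    left = right = 0
    # heapq.merge yields the combined sorted sequence lazily; islice stops
    # after the ((m + n) // 2 + 1)th element, so the rest is never produced.
    for value in islice(heapq.merge(nums1, nums2), total // 2 + 1):
        left, right = right, value
    if total % 2 == 1:
        return float(right)
    return (left + right) / 2


print(median_with_heapq_merge([1, 2], [3, 4]))  # 2.5
```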

Binary Search

This question expects a solution in O(log(m + n)) time. If we infer the algorithm from the required time complexity, binary search comes to mind, since its time complexity is O(log(n)), especially because both arrays are sorted. But how do we do binary search on 2 arrays?

Since we need to find the kth element of the two arrays combined, if each O(1) step can discard k / 2 elements (halving k), we achieve a time complexity of O(log(k)). In this question k is about (m + n) / 2, so the time complexity is O(log(m + n)). Our goal is to find the ((m + n) // 2)th or ((m + n) // 2 + 1)th element. Let’s take a look at an example:

k = 5, so k // 2 = 2: we look at the 2nd element (index k // 2 - 1 = 1) of each array:
nums1:        [1, 3, 5, 7, 9]
                  ^
nums2:        [2, 4, 6, 8, 10]
                  ^
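Since 3 is less than 4, at most k // 2 - 1 = 1 element of nums2 (the 2) can come before 3 in the merged array, so 3 is at most the 3rd element. Neither 1 nor 3 can be the 5th element, so we discard them and find the (5 - 2) = 3rd element of the following 2 arrays: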
k = 3
nums1:        [5, 7, 9]
nums2:        [2, 4, 6, 8, 10]
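The walkthrough can be reproduced with a small tracing sketch. `trace_kth` is a hypothetical helper, not part of the original solution: it keeps a start index into each array, probes the (k // 2)th remaining element of each (treating a missing one as infinity), logs which elements are discarded, and stops on the same conditions listed below (an exhausted array, or k equal to 1).

```python
import math
from typing import List, Tuple


def trace_kth(nums1: List[int], nums2: List[int], k: int) -> Tuple[int, List[str]]:
    """Find the kth (1-indexed) element of two sorted lists and log every discard."""
    i, j = 0, 0  # start of the remaining part of each array
    log: List[str] = []
    while True:
        if i >= len(nums1):
            return nums2[j + k - 1], log
        if j >= len(nums2):
            return nums1[i + k - 1], log
        if k == 1:
            return min(nums1[i], nums2[j]), log
        half = k // 2
        # Probe the (k // 2)th remaining element; a missing one counts as infinity.
        a = nums1[i + half - 1] if i + half - 1 < len(nums1) else math.inf
        b = nums2[j + half - 1] if j + half - 1 < len(nums2) else math.inf
        if a <= b:
            log.append(f"k={k}: {a} <= {b}, drop {nums1[i:i + half]} from nums1")
            i += half
        else:
            log.append(f"k={k}: {a} > {b}, drop {nums2[j:j + half]} from nums2")
            j += half
        k -= half


value, steps = trace_kth([1, 3, 5, 7, 9], [2, 4, 6, 8, 10], 5)
for line in steps:
    print(line)
print(value)
# k=5: 3 <= 4, drop [1, 3] from nums1
# k=3: 5 > 2, drop [2] from nums2
# k=2: 5 > 4, drop [4] from nums2
# 5
```

The first logged step matches the example above; the next two steps continue the same rule on the shrunken arrays until k reaches 1.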

We repeat this process until one of the following conditions is met:

  • One array has run out of numbers
  • k is 1

In the first case, we return the kth remaining element of the array that still has numbers left. In the second case, we compare the first remaining number of each array and return the smaller one.
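Because every step replaces k with k - k // 2 (roughly halving it), the "k is 1" condition is reached after about log2(k) steps. The hypothetical helper `k_sequence` only replays that update to make the step count visible; it ignores the arrays entirely and assumes neither array runs out first.

```python
from typing import List


def k_sequence(k: int) -> List[int]:
    """Values of k visited when every step discards k // 2 candidates."""
    steps = [k]
    while k > 1:
        k -= k // 2  # keep the ceiling of k / 2
        steps.append(k)
    return steps


print(k_sequence(5))   # [5, 3, 2, 1]
print(k_sequence(16))  # [16, 8, 4, 2, 1]
```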

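import math
from typing import List


class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        m, n = len(nums1), len(nums2)
        if (m + n) % 2 == 1:
            return self.findKthElement(nums1, 0, nums2, 0, (m + n) // 2 + 1)
        else:
            return (self.findKthElement(nums1, 0, nums2, 0, (m + n) // 2)
                    + self.findKthElement(nums1, 0, nums2, 0, (m + n) // 2 + 1)) / 2

    def findKthElement(self, nums1, nums1_index, nums2, nums2_index, k):
        # Find the kth (1-indexed) smallest element of nums1[nums1_index:] and nums2[nums2_index:]
        if nums1_index >= len(nums1):
            # nums1 has run out: the answer is the kth remaining element of nums2
            return nums2[nums2_index + k - 1]

        if nums2_index >= len(nums2):
            # nums2 has run out: the answer is the kth remaining element of nums1
            return nums1[nums1_index + k - 1]

        if k == 1:
            return min(nums1[nums1_index], nums2[nums2_index])

        # Probe the (k // 2)th remaining element of each array; if an array
        # has fewer than k // 2 elements left, treat its probe as infinity
        num1, num2 = math.inf, math.inf

        if nums1_index + k // 2 - 1 < len(nums1):
            num1 = nums1[nums1_index + k // 2 - 1]

        if nums2_index + k // 2 - 1 < len(nums2):
            num2 = nums2[nums2_index + k // 2 - 1]

        # The first k // 2 remaining elements on the side with the smaller probe
        # cannot be the kth element, so discard them and look for the (k - k // 2)th
        if num1 <= num2:
            return self.findKthElement(nums1, nums1_index + k // 2, nums2, nums2_index, k - k // 2)
        else:
            return self.findKthElement(nums1, nums1_index, nums2, nums2_index + k // 2, k - k // 2)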

Time complexity: O(log(m + n)). Space complexity: O(log(m + n)), because findKthElement is recursive and its call stack grows to a depth of O(log(m + n)).
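Since the recursive call is the last thing findKthElement does, it can be rewritten as a loop that needs only O(1) extra space. The following is a sketch of that variant, not the original author’s code; the names `find_kth` and `find_median_sorted_arrays` are hypothetical. It keeps the same base cases and discard rule.

```python
import math
from typing import List


def find_kth(nums1: List[int], nums2: List[int], k: int) -> int:
    """Return the kth smallest (1-indexed) element of two sorted lists, iteratively."""
    i, j = 0, 0  # start of the remaining part of each array
    while True:
        if i >= len(nums1):
            return nums2[j + k - 1]
        if j >= len(nums2):
            return nums1[i + k - 1]
        if k == 1:
            return min(nums1[i], nums2[j])
        half = k // 2
        a = nums1[i + half - 1] if i + half - 1 < len(nums1) else math.inf
        b = nums2[j + half - 1] if j + half - 1 < len(nums2) else math.inf
        # None of the first `half` remaining values on the side with the
        # smaller probe can be the kth element, so skip past them.
        if a <= b:
            i += half
        else:
            j += half
        k -= half


def find_median_sorted_arrays(nums1: List[int], nums2: List[int]) -> float:
    total = len(nums1) + len(nums2)
    if total % 2 == 1:
        return float(find_kth(nums1, nums2, total // 2 + 1))
    return (find_kth(nums1, nums2, total // 2) + find_kth(nums1, nums2, total // 2 + 1)) / 2


print(find_median_sorted_arrays([1, 2], [3, 4]))  # 2.5
```

The time complexity is unchanged at O(log(m + n)); only the stack usage drops to O(1).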
