Kth Smallest Element in a Sorted Matrix

Patrick Leaton
Problem Description

Given an n x n matrix where each of the rows and columns is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

return 13.

Note:
You may assume k is always valid, 1 ≤ k ≤ n^2.
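A quick way to pin down what "kth smallest" means, including duplicates, is a brute-force baseline that ignores the sorted structure entirely. The function name `kth_smallest_by_sorting` is hypothetical. This is only a reference for checking answers, not one of the two solutions below.

```python
from typing import List


def kth_smallest_by_sorting(matrix: List[List[int]], k: int) -> int:
    # Brute-force reference: flatten every cell, sort, and take position k - 1.
    # Duplicates are kept, so this is the kth smallest in sorted order,
    # not the kth distinct value. O(n^2 log n) time, O(n^2) space.
    flat = sorted(value for row in matrix for value in row)
    return flat[k - 1]


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_by_sorting(matrix, 8))  # 13
```

Both k = 7 and k = 8 return 13 here, because 13 appears twice and each copy counts.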

The description was taken from https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/.

Problem Solution

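#Heap
#O(Min(K,N) + K Log(Min(K,N))) Time, O(Min(K,N)) Space
import heapq
from typing import List

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:

        # Seed the heap with the first cell of each of the first min(k, n) rows.
        # Column 0 is sorted, so this list is already a valid min-heap.
        heap = []

        for row in range(0, min(k, len(matrix))):
            heap.append((matrix[row][0], row, 0))

        # Pop the smallest cell k times, pushing its right-hand neighbor each time.
        while k:
            element, row, col = heapq.heappop(heap)
            if col < len(matrix) - 1:
                heapq.heappush(heap, (matrix[row][col+1], row, col+1))
            k -= 1

        # The kth pop is the kth smallest element.
        return element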
 
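#Binary Search
#O(N (Log (Max-Min Cell))) Time, O(1) Space
from typing import List

class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        n = len(matrix)
        # Binary search over the range of values, not over positions.
        start = matrix[0][0]
        end = matrix[n-1][n-1]

        while start < end:
            mid = (start + end) // 2
            if self.mid_count(mid, matrix) < k:
                # Fewer than k elements are <= mid, so the answer is above mid.
                start = mid + 1
            else:
                # At least k elements are <= mid, so the answer is mid or below.
                end = mid
        return start

    def mid_count(self, mid:int, matrix:list) -> int:
        # Count elements <= mid with a staircase walk from the top-right cell.
        n = len(matrix)
        count, row, col = 0, 0, n-1
        while row < n and col >= 0:
            if matrix[row][col] <= mid:
                # Every cell in this row up to col is <= mid.
                count += col + 1
                row += 1
            else:
                # Every cell below in this column is > mid, so step left.
                col -= 1
        return count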

Problem Explanation


Since the problem is asking for the kth smallest element, the usual first choice for this type of question is a min-heap.

However, we are also looking for a value within a sorted matrix, which tells us we could probably use binary search as well.

We will have both implementations here. The binary search uses O(1) extra space instead of O(min(k, n)), and it tends to be faster when k is large, but it is much trickier to implement.


Heap Solution

Okay, so what we will do for the heap implementation is almost like a modified breadth-first search driven by a heap: we always take the smallest cell seen so far and then add its right-hand neighbor.
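This walk is closely related to a k-way merge of the sorted rows. Python's standard library packages that idea as `heapq.merge`, so here is a compact sketch (the function name `kth_smallest_via_merge` is hypothetical) that makes the pattern visible. One difference from the solution: `heapq.merge` keeps one entry per row in its internal heap instead of limiting the seed to the first min(k, n) rows.

```python
import heapq
from itertools import islice
from typing import List


def kth_smallest_via_merge(matrix: List[List[int]], k: int) -> int:
    # heapq.merge lazily merges already-sorted iterables with a small internal
    # heap and yields values in ascending order. Every row is sorted, so this
    # is a k-way merge of the rows that we stop after k values.
    merged = heapq.merge(*matrix)
    # islice skips the first k - 1 merged values; next() returns the kth.
    return next(islice(merged, k - 1, None))


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_via_merge(matrix, 8))  # 13
```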

Let's initialize our heap first.

        heap = []

Now, we will add the first cell of each row into the heap, stopping once we have added min(k, n) rows, where n is the length of the matrix. Let's take this input matrix as an example, with k equal to seven. Since min(7, 3) is 3, the heap starts with the whole first column: 2, 8, and 20.

2   4   6
8   14  17
20  22  23

We need to keep track of not only the value but also the row and column it is in, so that we can later push its right-hand neighbor.

        for row in range(0, min(k, len(matrix))):
            heap.append((matrix[row][0], row, 0))
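To see what the heap holds right after this loop, here is the seeding step pulled into a small hypothetical helper, `seed_heap`, run on the example matrix. Limiting the seed to min(k, n) rows is safe: the first k cells of column 0 are already k values no larger than matrix[k-1][0], and every row below index k - 1 starts at a value at least that large, so those rows cannot change the answer.

```python
from typing import List, Tuple


def seed_heap(matrix: List[List[int]], k: int) -> List[Tuple[int, int, int]]:
    # Same seeding step as the solution: a (value, row, col) tuple for the
    # first cell of each of the first min(k, n) rows. Column 0 is sorted,
    # so the resulting list is already in ascending order, which is a
    # valid min-heap without calling heapq.heapify.
    heap = []
    for row in range(0, min(k, len(matrix))):
        heap.append((matrix[row][0], row, 0))
    return heap


matrix = [
    [2, 4, 6],
    [8, 14, 17],
    [20, 22, 23],
]
print(seed_heap(matrix, 7))  # [(2, 0, 0), (8, 1, 0), (20, 2, 0)]
print(seed_heap(matrix, 2))  # [(2, 0, 0), (8, 1, 0)]
```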

After we have added these initial cell tuples into the heap, we repeatedly pop the smallest tuple off the heap. If the popped cell is not in the last column (col < n - 1), we push its right-hand neighbor into the heap. Every pop decrements k.

2   4   6
8   14  17
20  22  23

        while k:
            element, row, col = heapq.heappop(heap)
            if col < len(matrix) - 1:
                heapq.heappush(heap, (matrix[row][col+1], row, col+1))
            k -= 1

Once we have decremented k to zero, the last element we popped off the heap is the kth smallest element.

        return element
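Tracing the loop on the example makes the order of pops concrete. The sketch below is a hypothetical instrumented copy of the solution, `pop_order`, that records each popped value instead of keeping only the last one.

```python
import heapq
from typing import List


def pop_order(matrix: List[List[int]], k: int) -> List[int]:
    # Same loop as the heap solution, but every popped value is recorded
    # so the order in which cells leave the heap is visible.
    # Column 0 is sorted, so this seed list is already a valid min-heap.
    heap = [(matrix[row][0], row, 0) for row in range(min(k, len(matrix)))]
    popped = []
    while k:
        element, row, col = heapq.heappop(heap)
        popped.append(element)
        if col < len(matrix) - 1:
            # Replace the popped cell with its right-hand neighbor.
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
        k -= 1
    return popped


matrix = [
    [2, 4, 6],
    [8, 14, 17],
    [20, 22, 23],
]
print(pop_order(matrix, 7))  # [2, 4, 6, 8, 14, 17, 20]
```

Row 0 runs out after 6 is popped, so 8 from row 1 comes next, and the seventh pop returns 20, the answer for k = 7.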


Each push or pop on a heap costs O(log H) time, where H is the heap's size. The heap starts with min(k, n) cells, and since each iteration pops one cell and pushes at most one, it never grows beyond that size. We perform k pops and at most k pushes.

This gives us a time complexity of O(Min(K,N) + K Log(Min(K,N))), where the first term is building the initial heap, and a space complexity of O(Min(K,N)).
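The space bound can be checked directly by instrumenting the loop. This is a hypothetical variant, `kth_smallest_with_peak`, that also returns the largest size the heap ever reaches.

```python
import heapq
from typing import List, Tuple


def kth_smallest_with_peak(matrix: List[List[int]], k: int) -> Tuple[int, int]:
    # The heap solution, instrumented to also report the largest size the
    # heap reaches. Each iteration pops one tuple and pushes at most one,
    # so the size never exceeds its starting value of min(k, n).
    heap = [(matrix[row][0], row, 0) for row in range(min(k, len(matrix)))]
    peak = len(heap)
    while k:
        element, row, col = heapq.heappop(heap)
        if col < len(matrix) - 1:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
        peak = max(peak, len(heap))
        k -= 1
    return element, peak


matrix = [
    [2, 4, 6],
    [8, 14, 17],
    [20, 22, 23],
]
print(kth_smallest_with_peak(matrix, 7))  # (20, 3)
print(kth_smallest_with_peak(matrix, 2))  # (4, 2)
```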



Binary Search Solution

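Since the matrix is sorted, we can use binary search to find the kth smallest element. Rather than searching over positions in the matrix, we search over the range of values it contains.

What we can do is set a starting pointer to the value in the first cell of the matrix, which is the smallest value, and an ending pointer to the value in the last cell, which is the largest. We will then calculate the midpoint between these two values.

We will then pass that midpoint to a helper function that returns the count of elements less than or equal to it. If the count is lower than k, we will move the starting pointer up past the midpoint. If it is greater than or equal to k, we will move the ending pointer down to the midpoint. Once the two pointers meet, they will fall on the smallest value that has at least k elements less than or equal to it, which is the kth smallest element. That value is always an element of the matrix: if it were not, the value one below it would have the same count, so it would not be the smallest such value.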
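Stripped of the matrix details, this is the general pattern "find the smallest value for which a monotone condition holds." The sketch below plugs in the condition "at least k cells are less than or equal to v". The helper names `smallest_value_satisfying` and `count_at_most` are hypothetical, and the count is brute force for clarity in place of the mid_count helper.

```python
from typing import Callable, List


def smallest_value_satisfying(lo: int, hi: int, ok: Callable[[int], bool]) -> int:
    # Smallest v in [lo, hi] with ok(v) True, assuming ok is monotone
    # (False ... False, True ... True) and ok(hi) is True.
    while lo < hi:
        mid = (lo + hi) // 2
        if ok(mid):
            hi = mid       # mid qualifies, so the answer is mid or smaller
        else:
            lo = mid + 1   # mid fails, so the answer is larger than mid
    return lo


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8


def count_at_most(v: int) -> int:
    # Brute-force count of cells <= v; mid_count does this in O(n).
    return sum(1 for row in matrix for x in row if x <= v)


answer = smallest_value_satisfying(matrix[0][0], matrix[-1][-1],
                                   lambda v: count_at_most(v) >= k)
print(answer)  # 13
```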


Let's start off by creating the mid_count helper function.

The function will take a mid value and a matrix as arguments.  It will return an integer.

    def mid_count(self, mid:int, matrix:list) -> int:

We'll calculate an n value, which is both the number of rows and the number of columns, since the matrix is n x n. We will want this for our traversal bounds.

        n = len(matrix)

We will then initialize our counter and our pointers for the row and column.  

        count, row, col = 0, 0, n-1

While our traversal is in the bounds of the matrix, we will continue calculating our count for the mid value that was passed into the function.

        while row < n and col >= 0:

We will start at the last cell of the first row and within each iteration, we will compare the current element to the mid value.  

If the value is less than or equal to mid, we can increment the count by the column value plus one. This works because the row is sorted: every element to the left of matrix[row][col] is no larger than it, so all col + 1 elements from matrix[row][0] through matrix[row][col] are less than or equal to mid. For example, if we are at matrix[0][2] and that element is less than or equal to mid, all three elements of the first row count, so we add two plus one and jump down to the next row. The plus one is due to array indices starting from zero.

            if matrix[row][col] <= mid:
                count += col + 1
                row += 1

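If the current element is greater than mid, then every element below it in the same column is also greater than mid, because the columns are sorted. We can skip that column for all remaining rows, so we take a step to the left by decrementing the col value.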

            else:
                col -= 1
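One way to convince yourself that col + 1 is the right amount to add is to compute the same per-row counts with Python's `bisect.bisect_right` and compare. The helper `count_at_most_by_rows` is a hypothetical cross-check, not part of the solution. For mid = 11 on the problem's example, the staircase walk adds 3 for the first row and 2 for the second, then steps off the left edge in the third row, which has nothing less than or equal to 11.

```python
from bisect import bisect_right
from typing import List


def count_at_most_by_rows(matrix: List[List[int]], mid: int) -> int:
    # In a sorted row, bisect_right(row, mid) is the number of entries <= mid,
    # the same quantity the staircase walk adds as col + 1 when it moves down
    # from a row. This version binary-searches every row: O(n log n), not O(n).
    return sum(bisect_right(row, mid) for row in matrix)


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print([bisect_right(row, 11) for row in matrix])  # [3, 2, 0]
print(count_at_most_by_rows(matrix, 11))  # 5
```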

Once the walk steps off the bottom or the left edge of the matrix, count holds the number of elements less than or equal to mid, and we return it for our binary search.

        return count
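To see the staircase in action, here is a hypothetical traced copy of the helper, `mid_count_trace`, that records each visited cell and the move taken from it.

```python
from typing import List, Tuple


def mid_count_trace(mid: int, matrix: List[List[int]]) -> Tuple[int, List[Tuple[int, int, str]]]:
    # Same staircase walk as mid_count, but each visited cell is recorded
    # with the move taken from it: "down" after counting, "left" otherwise.
    n = len(matrix)
    count, row, col = 0, 0, n - 1
    steps = []
    while row < n and col >= 0:
        if matrix[row][col] <= mid:
            count += col + 1
            steps.append((row, col, "down"))
            row += 1
        else:
            steps.append((row, col, "left"))
            col -= 1
    return count, steps


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
count, steps = mid_count_trace(11, matrix)
print(count)  # 5
for step in steps:
    print(step)
```

Every step moves either down or left, so the walk visits at most 2n - 1 cells, which is where the O(N) cost per call comes from. For mid = 11 it visits (0, 2), (1, 2), (1, 1), (2, 1), and (2, 0), the full 2n - 1 = 5 cells.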


The rest of the code is a traditional binary search, run over values instead of indices.

We will get our n value, then set the starting and ending pointers to the values in the first and last cells of the matrix.

        n = len(matrix)
        start = matrix[0][0]
        end = matrix[n-1][n-1]

While the start pointer hasn't reached the end pointer, we will calculate the midpoint of these two values and count how many elements in the matrix are less than or equal to it. The midpoint itself may not actually be an element of the matrix.

        while start < end:
            mid = (start + end) // 2

If the count is less than k, then fewer than k elements are less than or equal to mid, so the answer must be larger than mid. We move our start pointer to mid + 1.

            if self.mid_count(mid, matrix) < k:
                start = mid + 1

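Otherwise, at least k elements are less than or equal to mid, so the answer is mid or smaller. We move the ending pointer to mid, keeping mid in the search range.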

            else:
                end = mid

Moving these pointers is just our way of decreasing the bounds of our search by half in each iteration.

Once the two pointers have met, we can return either, as they both will be on the kth smallest element.

        return start
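To watch the pieces work together, the sketch below reruns the binary search on the problem's example (k = 8) and records each iteration. The standalone `mid_count` function and the `binary_search_trace` helper are hypothetical restatements of the class methods above, written as plain functions so the snippet runs on its own.

```python
from typing import List, Tuple


def mid_count(mid: int, matrix: List[List[int]]) -> int:
    # Staircase count of cells <= mid, as in the helper method above.
    n = len(matrix)
    count, row, col = 0, 0, n - 1
    while row < n and col >= 0:
        if matrix[row][col] <= mid:
            count += col + 1
            row += 1
        else:
            col -= 1
    return count


def binary_search_trace(matrix: List[List[int]], k: int) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    # Same loop as kthSmallest, recording (start, end, mid, count) per iteration.
    n = len(matrix)
    start, end = matrix[0][0], matrix[n - 1][n - 1]
    trace = []
    while start < end:
        mid = (start + end) // 2
        count = mid_count(mid, matrix)
        trace.append((start, end, mid, count))
        if count < k:
            start = mid + 1
        else:
            end = mid
    return start, trace


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
answer, trace = binary_search_trace(matrix, 8)
for step in trace:
    print(step)
print(answer)
```

The four recorded iterations are (1, 15, 8, 2), (9, 15, 12, 6), (13, 15, 14, 8), and (13, 14, 13, 8), and the search returns 13. Two of the midpoints, 8 and 14, are not in the matrix, yet the pointers still settle on 13, and four halvings are enough to cover the 15 candidate values from 1 to 15.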