Given an n x n matrix where each of the rows and columns is sorted in ascending order, find the kth smallest element in the matrix.
Note that it is the kth smallest element in the sorted order, not the kth distinct element.
Example:
matrix = [ [ 1, 5, 9], [10, 11, 13], [12, 13, 15] ], k = 8, return 13.
Note:
You may assume k is always valid, 1 ≤ k ≤ n^2.
The description was taken from https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix/.
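To make the "not the kth distinct element" rule concrete, here is a brute-force reference. It flattens and sorts the matrix. It is a sketch for checking answers only: it runs in O(n^2 log n) time and is not one of the two solutions below. The name `kth_smallest_brute_force` is hypothetical.

```python
from typing import List


def kth_smallest_brute_force(matrix: List[List[int]], k: int) -> int:
    """Hypothetical reference helper: flatten every row, sort, and index (k is 1-based)."""
    flat = sorted(value for row in matrix for value in row)
    return flat[k - 1]


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
# Sorted order: 1, 5, 9, 10, 11, 12, 13, 13, 15
print(kth_smallest_brute_force(matrix, 7))  # 13
print(kth_smallest_brute_force(matrix, 8))  # 13 (the duplicate 13 is counted again)
```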
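# Heap
# O(Min(K,N) + K Log(Min(K,N))) Time, O(Min(K,N)) Space
import heapq
from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        # Seed the heap with the first cell of the first min(k, n) rows.
        # The first column is sorted, so this list is already a valid min-heap.
        heap = []
        for row in range(0, min(k, len(matrix))):
            heap.append((matrix[row][0], row, 0))
        # Pop k times; each pop yields the next smallest value overall.
        while k:
            element, row, col = heapq.heappop(heap)
            # Push the right-hand neighbor of the popped cell, if there is one.
            if col < len(matrix) - 1:
                heapq.heappush(heap, (matrix[row][col+1], row, col+1))
            k -= 1
        return element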
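# Binary Search
# O(N Log(Max Cell - Min Cell)) Time, O(1) Space
from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        n = len(matrix)
        # Search over the range of values, not over cell positions.
        start = matrix[0][0]
        end = matrix[n-1][n-1]
        while start < end:
            mid = (start + end) // 2
            if self.mid_count(mid, matrix) < k:
                # Fewer than k cells are <= mid, so the answer is larger than mid.
                start = mid + 1
            else:
                # At least k cells are <= mid, so the answer is mid or smaller.
                end = mid
        return start

    def mid_count(self, mid: int, matrix: list) -> int:
        # Count the cells <= mid with a staircase walk from the top-right corner.
        n = len(matrix)
        count, row, col = 0, 0, n-1
        while row < n and col >= 0:
            if matrix[row][col] <= mid:
                count += col + 1
                row += 1
            else:
                col -= 1
        return count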
Since the problem asks for the kth smallest element, the usual approach for this type of question is a min-heap.
However, we are also searching for a value within a sorted matrix, which tells us we could probably use binary search as well.
We will cover both implementations here. The binary search uses O(1) extra space instead of O(min(k, n)). Its running time depends on the range of values rather than on k, which can make it faster when k is large. The trade-off is that it is much trickier to implement.
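Okay, so the heap implementation is almost like a modified Breadth-First Search that uses a heap. Each row is a sorted list. We repeatedly take the smallest cell on the frontier and advance to its right-hand neighbor, much like merging n sorted lists.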
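The same k-way merge idea can be expressed with the standard library's `heapq.merge`, which lazily merges already-sorted iterables using a heap. This is a compact alternative sketch, not the implementation walked through below. One difference: it seeds its heap with all n rows rather than min(k, n), so it uses O(n) space. The name `kth_smallest_merge` is hypothetical.

```python
import heapq
from itertools import islice
from typing import List


def kth_smallest_merge(matrix: List[List[int]], k: int) -> int:
    # heapq.merge lazily yields the values of all sorted rows in ascending order.
    merged = heapq.merge(*matrix)
    # islice skips the first k - 1 values; next() returns the kth smallest.
    return next(islice(merged, k - 1, None))


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print(kth_smallest_merge(matrix, 8))  # 13
```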
Let's initialize our heap first.
heap = []
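Now, we will add the first cell of each of the first min(k, n) rows into the heap. Rows past the first k can be skipped. Every cell in those later rows is at least as large as the first cells of the first k rows, which already supply k candidates. Let's take this input matrix as an example, with k = 7: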
 2 |  4 |  6
 8 | 14 | 17
20 | 22 | 23
We need to keep track of not only the value but also the row and column it sits in, so that we can later push its right-hand neighbor. Python compares tuples element by element, so (value, row, col) tuples are ordered by value first.
for row in range(0, min(k, len(matrix))):
    heap.append((matrix[row][0], row, 0))
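The code appends to a plain list and never calls `heapq.heapify`. That works because the first column is sorted, and a sorted list already satisfies the min-heap property. The short check below makes that explicit for the k = 7 example. The `is_heap` check is an illustrative addition, not part of the solution.

```python
matrix = [[2, 4, 6], [8, 14, 17], [20, 22, 23]]
k = 7

heap = []
for row in range(0, min(k, len(matrix))):
    heap.append((matrix[row][0], row, 0))

print(heap)  # [(2, 0, 0), (8, 1, 0), (20, 2, 0)]

# Min-heap invariant: every parent heap[(i - 1) // 2] is <= its child heap[i].
is_heap = all(heap[(i - 1) // 2] <= heap[i] for i in range(1, len(heap)))
print(is_heap)  # True
```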
After adding these initial cell tuples, we repeatedly pop the smallest tuple off the heap. If the popped cell is not in the last column (the right edge of the matrix), we push its right-hand neighbor onto the heap. Each pop gives us the next smallest element of the matrix.
while k:
    element, row, col = heapq.heappop(heap)
    if col < len(matrix) - 1:
        heapq.heappush(heap, (matrix[row][col+1], row, col+1))
    k -= 1
Once we have decremented k to zero, the last element we popped off the heap is the kth smallest element.
return element
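To see the whole walk on the k = 7 example, here is the same loop with one extra list that records every popped value. The function name `kth_smallest_with_trace` is hypothetical, and the recording is only for illustration.

```python
import heapq
from typing import List, Tuple


def kth_smallest_with_trace(matrix: List[List[int]], k: int) -> Tuple[int, List[int]]:
    """Same heap walk as above, but also records every value popped off the heap."""
    heap = []
    for row in range(0, min(k, len(matrix))):
        heap.append((matrix[row][0], row, 0))
    popped = []
    while k:
        element, row, col = heapq.heappop(heap)
        popped.append(element)
        if col < len(matrix) - 1:
            # Push the right-hand neighbor of the cell we just popped.
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
        k -= 1
    return element, popped


matrix = [[2, 4, 6], [8, 14, 17], [20, 22, 23]]
answer, popped = kth_smallest_with_trace(matrix, 7)
print(popped)  # [2, 4, 6, 8, 14, 17, 20]
print(answer)  # 20
```

The values come out in sorted order. Popping 6 and 17 pushes nothing, because both sit in the last column.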
The heap starts with min(k, n) tuples, and every pop is followed by at most one push, so it never grows beyond min(k, n) tuples. Each push or pop therefore costs O(Log(Min(K,N))) time. Building the initial heap from the first column costs O(Min(K,N)). We then perform k pops, each followed by at most one push, until we reach the kth smallest element.
This gives us a time complexity of O(Min(K,N) + K Log(Min(K,N))), and a space complexity of O(Min(K,N)).
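The space bound can be checked empirically. The sketch below reruns the heap walk and records the largest size the heap ever reaches. That size should equal min(k, n) for every valid k. The helper name `max_heap_size` is hypothetical.

```python
import heapq
from typing import List


def max_heap_size(matrix: List[List[int]], k: int) -> int:
    """Runs the heap walk and reports the largest size the heap ever reached."""
    heap = [(matrix[row][0], row, 0) for row in range(min(k, len(matrix)))]
    largest = len(heap)
    while k:
        _, row, col = heapq.heappop(heap)
        if col < len(matrix) - 1:
            heapq.heappush(heap, (matrix[row][col + 1], row, col + 1))
        # A pop followed by at most one push never grows the heap.
        largest = max(largest, len(heap))
        k -= 1
    return largest


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
print([max_heap_size(matrix, k) for k in range(1, 10)])  # [1, 2, 3, 3, 3, 3, 3, 3, 3]
```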
Since the rows and columns are sorted, we can also perform a binary search to find the kth smallest element. Note that we search over the range of values, not over cell positions.
We set a starting pointer to the value of the first cell of the matrix (the smallest value) and an ending pointer to the value of the last cell (the largest value). We then calculate the midpoint between these two values.
We pass that midpoint to a helper function, which returns the number of elements less than or equal to it. If the count is lower than k, the answer must be larger than the midpoint, so we move the starting pointer up to mid + 1. If the count is greater than or equal to k, the answer is at most the midpoint, so we move the ending pointer down to mid. Once the two pointers meet, they will fall on the kth smallest element.
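Put differently, the binary search finds the smallest value v whose count of elements less than or equal to v is at least k. The sketch below finds that same v with a slow linear scan over the value range and a brute-force count. It shows what the search converges to, not how to implement it efficiently. The helper `count_at_most` is hypothetical.

```python
from typing import List


def count_at_most(matrix: List[List[int]], value: int) -> int:
    """Brute-force count of cells <= value (O(n^2), for illustration only)."""
    return sum(1 for row in matrix for cell in row if cell <= value)


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8

# Smallest v in [min cell, max cell] with at least k cells <= v.
low, high = matrix[0][0], matrix[-1][-1]
answer = next(v for v in range(low, high + 1) if count_at_most(matrix, v) >= k)
print(answer)  # 13
```

That v is always an element of the matrix. The count only increases at values that actually appear, so if v were missing, the count at v - 1 would be the same and v would not be the smallest such value.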
Let's start off by creating the mid_count helper function.
The function will take a mid value and a matrix as arguments. It will return an integer.
def mid_count(self, mid:int, matrix:list) -> int:
We'll calculate an n value, which will be both the length and the width of the matrix since it is given as an n x n matrix. We will want this for our traversal bounds.
n = len(matrix)
We will then initialize our counter and our pointers for the row and column.
count, row, col = 0, 0, n-1
While our traversal is within the bounds of the matrix, we will keep counting the elements that are less than or equal to the mid value passed into the function.
while row < n and col >= 0:
We will start at the last cell of the first row (the top-right corner), and in each iteration we will compare the current element to the mid value.
If the value is less than or equal to mid, we can increment the count by the column index plus one. Because the row is sorted, every element to the left of the current cell is also less than or equal to mid. That gives col + 1 such elements in this row, counting the current one; the plus one is due to array indices starting from zero. For example, suppose we start at matrix[0][2] and that element is less than or equal to mid. Then all three elements of the first row count, so we add two plus one and jump down to the next row.
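One way to see what col + 1 represents: for a single sorted row, it equals `bisect.bisect_right(row, mid)`, the number of elements in that row that are less than or equal to mid. The sketch below computes those per-row counts directly with `bisect_right`. This is an alternative O(n log n) count shown only to make the idea concrete. The staircase walk gets the same total in O(n).

```python
from bisect import bisect_right

matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
mid = 12

# bisect_right(row, mid) = number of elements in a sorted row that are <= mid.
per_row = [bisect_right(row, mid) for row in matrix]
print(per_row)       # [3, 2, 1]
print(sum(per_row))  # 6
```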
if matrix[row][col] <= mid:
    count += col + 1
    row += 1
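If the current element is greater than mid, we have overshot our target. Because columns are sorted, every element below it in the same column is also greater than mid. We can therefore discard that column for good and take a step to the left by decrementing the col value.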
else:
    col -= 1
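Once the walk leaves the matrix, count holds the number of elements less than or equal to mid, and we return it for our binary search. Every step moves either down a row or left a column, so the walk takes at most 2n - 1 steps, which is O(n).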
return count
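Here is the full staircase walk for mid = 12 on the example matrix from the problem statement. The function records every (row, col) it visits. `mid_count_with_path` is a hypothetical instrumented copy of the helper above.

```python
from typing import List, Tuple


def mid_count_with_path(matrix: List[List[int]], mid: int) -> Tuple[int, List[Tuple[int, int]]]:
    """Same staircase walk as mid_count, but also records each (row, col) visited."""
    n = len(matrix)
    count, row, col = 0, 0, n - 1
    path = []
    while row < n and col >= 0:
        path.append((row, col))
        if matrix[row][col] <= mid:
            # Everything from column 0 through col in this row is <= mid.
            count += col + 1
            row += 1
        else:
            # This cell, and everything below it in this column, is > mid.
            col -= 1
    return count, path


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
count, path = mid_count_with_path(matrix, 12)
print(count)  # 6
print(path)   # [(0, 2), (1, 2), (1, 1), (2, 1), (2, 0)]
```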
The rest of the code is just a traditional binary search.
We will get our n value, then set the starting and ending pointers to the values of the first and last cells of the matrix.
n = len(matrix)
start = matrix[0][0]
end = matrix[n-1][n-1]
While the start pointer hasn't reached the end pointer, we will calculate the midpoint of these two values and count how many elements are less than or equal to it. Note that this midpoint may not actually be an element of the matrix.
while start < end:
    mid = (start + end) // 2
If the count is less than k, we will move our start pointer to one past mid.
if self.mid_count(mid, matrix) < k:
    start = mid + 1
Otherwise, we will move the ending pointer to mid.
else:
    end = mid
Moving these pointers is just our way of decreasing the bounds of our search by half in each iteration.
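To watch the range shrink, this sketch runs the same search on the problem's example (k = 8) and records (start, end, mid, count) for each iteration. `mid_count` is redeclared here as a plain function so the block runs on its own. `trace_binary_search` is a hypothetical name.

```python
from typing import List, Tuple


def mid_count(mid: int, matrix: List[List[int]]) -> int:
    # Staircase count of cells <= mid, starting from the top-right corner.
    n = len(matrix)
    count, row, col = 0, 0, n - 1
    while row < n and col >= 0:
        if matrix[row][col] <= mid:
            count += col + 1
            row += 1
        else:
            col -= 1
    return count


def trace_binary_search(matrix: List[List[int]], k: int) -> Tuple[int, List[Tuple[int, int, int, int]]]:
    """Runs the value-range binary search, recording (start, end, mid, count) per iteration."""
    start, end = matrix[0][0], matrix[-1][-1]
    steps = []
    while start < end:
        mid = (start + end) // 2
        count = mid_count(mid, matrix)
        steps.append((start, end, mid, count))
        if count < k:
            start = mid + 1  # the answer is above mid
        else:
            end = mid  # the answer is mid or below
    return start, steps


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
answer, steps = trace_binary_search(matrix, 8)
for step in steps:
    print(step)
# (1, 15, 8, 2)
# (9, 15, 12, 6)
# (13, 15, 14, 8)
# (13, 14, 13, 8)
print(answer)  # 13
```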
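Once the two pointers have met, we can return either one. Both hold the smallest value with at least k elements less than or equal to it. That value must be an element of the matrix, because the count only increases at values that appear in it. So both pointers sit on the kth smallest element.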
return start
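As a final sanity check, the sketch below compares the binary search against a sorted flattening for every valid k on a few small matrices. It includes one matrix with negative values and duplicates. The negative case matters because Python's `//` floors toward negative infinity. That keeps mid strictly below end, so the loop always makes progress. Division that truncates toward zero could leave mid equal to end and loop forever. The test matrices are made up for illustration.

```python
from typing import List


class Solution:
    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        n = len(matrix)
        start, end = matrix[0][0], matrix[n - 1][n - 1]
        while start < end:
            mid = (start + end) // 2  # floor division: mid < end always holds
            if self.mid_count(mid, matrix) < k:
                start = mid + 1
            else:
                end = mid
        return start

    def mid_count(self, mid: int, matrix: List[List[int]]) -> int:
        n = len(matrix)
        count, row, col = 0, 0, n - 1
        while row < n and col >= 0:
            if matrix[row][col] <= mid:
                count += col + 1
                row += 1
            else:
                col -= 1
        return count


matrices = [
    [[1, 5, 9], [10, 11, 13], [12, 13, 15]],
    [[2, 4, 6], [8, 14, 17], [20, 22, 23]],
    [[-5, -4], [-5, -1]],  # negative values and a duplicate
]
for matrix in matrices:
    expected = sorted(v for row in matrix for v in row)
    for k in range(1, len(expected) + 1):
        assert Solution().kthSmallest(matrix, k) == expected[k - 1]
print("all k checked")
```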