Select the k-th smallest element from a matrix whose rows and columns are sorted.
Selection in Sorted Matrix finds the k-th smallest value in a matrix where every row and every column is sorted in nondecreasing order.
The sorted order makes it cheap to count how many entries are less than or equal to a candidate value, and that count never decreases as the candidate grows. This allows binary search over the value range.
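The reduction can be sketched as a generic "first value where a monotone predicate holds" search, with the predicate "at least $k$ entries are at most $x$". This is a minimal sketch: `first_true` and `count_le` are hypothetical helper names, the matrix is illustrative, and `count_le` uses a plain scan rather than the faster corner walk from the Algorithm section.

```python
from typing import Callable, List


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi] with pred(x) True.

    Assumes pred is monotone (False ... False, True ... True) and pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2  # floor midpoint, also correct for negative bounds
        if pred(mid):
            hi = mid          # mid qualifies: the answer is mid or smaller
        else:
            lo = mid + 1      # mid fails: the answer is larger
    return lo


def count_le(matrix: List[List[int]], x: int) -> int:
    """Hypothetical O(mn) counter; the corner walk computes the same value in O(m + n)."""
    return sum(1 for row in matrix for v in row if v <= x)


A = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]  # illustrative sorted matrix
k = 8
print(first_true(A[0][0], A[-1][-1], lambda x: count_le(A, x) >= k))
```

Separating the search from the predicate makes the monotonicity requirement explicit: any counter that is correct for every $x$ can be plugged in.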
Problem
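Given an $m \times n$ matrix $A$ such that

$$
A[i][j] \le A[i][j+1] \quad \text{and} \quad A[i][j] \le A[i+1][j]
$$

for all valid indices $i, j$, and a rank $k$ with $1 \le k \le mn$, find the $k$-th smallest value, using 1-based rank.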
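A brute-force reference makes the 1-based rank explicit and is handy for testing faster versions. This sketch assumes a non-empty rectangular list of lists; `is_sorted_matrix` and `kth_smallest_bruteforce` are hypothetical names.

```python
def is_sorted_matrix(A):
    """Check that every row and every column is nondecreasing."""
    m, n = len(A), len(A[0])
    rows_ok = all(A[i][j] <= A[i][j + 1] for i in range(m) for j in range(n - 1))
    cols_ok = all(A[i][j] <= A[i + 1][j] for i in range(m - 1) for j in range(n))
    return rows_ok and cols_ok


def kth_smallest_bruteforce(A, k):
    """Reference answer: flatten, sort, and take rank k (1-based)."""
    if not 1 <= k <= len(A) * len(A[0]):
        raise ValueError("k must satisfy 1 <= k <= m * n")
    values = sorted(v for row in A for v in row)
    return values[k - 1]  # 1-based rank -> 0-based index
```

Duplicates keep separate ranks, so a value that appears twice occupies two consecutive ranks.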
Algorithm
Binary search over values between the smallest entry $A[0][0]$ and the largest entry $A[m-1][n-1]$; sortedness guarantees that these corners hold the extremes. For each midpoint, count how many entries are less than or equal to it.
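```text
selection_in_sorted_matrix(A, k):
    # m = number of rows, n = number of columns
    low = A[0][0]
    high = A[m - 1][n - 1]
    while low < high:
        mid = floor((low + high) / 2)
        count = count_less_equal(A, mid)
        if count < k:
            low = mid + 1
        else:
            high = mid
    return low
```

To count efficiently, start from the bottom-left corner. If the current entry is at most $x$, every entry above it in the same column is too, so add those `row + 1` entries and move right; otherwise, move up.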
```text
count_less_equal(A, x):
    row = m - 1
    col = 0
    count = 0
    while row >= 0 and col < n:
        if A[row][col] <= x:
            count = count + row + 1
            col = col + 1
        else:
            row = row - 1
    return count
```

Example
Let:
For $k = 8$, the sorted order is:
The 8th smallest value is:
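To make the search concrete, the sketch below traces it on an illustrative 3-by-3 matrix assumed for this purpose. `trace_selection` is a hypothetical helper that records `(low, high, mid, count)` for every probe; the counting function mirrors the pseudocode.

```python
def count_less_equal(matrix, x):
    # Staircase walk from the bottom-left corner.
    row, col, count = len(matrix) - 1, 0, 0
    while row >= 0 and col < len(matrix[0]):
        if matrix[row][col] <= x:
            count += row + 1
            col += 1
        else:
            row -= 1
    return count


def trace_selection(matrix, k):
    """Run the value binary search and record (low, high, mid, count) per step."""
    low, high = matrix[0][0], matrix[-1][-1]
    steps = []
    while low < high:
        mid = (low + high) // 2
        c = count_less_equal(matrix, mid)
        steps.append((low, high, mid, c))
        if c < k:
            low = mid + 1
        else:
            high = mid
    return low, steps


A = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
answer, steps = trace_selection(A, 8)
for low, high, mid, c in steps:
    print(f"low={low:2d} high={high:2d} mid={mid:2d} count={c}")
print("answer:", answer)
```

For this matrix and $k = 8$, the probes are mid = 8, 12, 14, 13 with counts 2, 6, 8, 8, and the search returns 13. The value 13 appears twice and occupies ranks 7 and 8.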
Correctness
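For any candidate value $x$, `count_less_equal` returns the number of matrix entries at most $x$. Because rows and columns are sorted, the entries less than or equal to $x$ form a monotone staircase region anchored at the top-left corner: if $A[i][j] \le x$, then every $A[i'][j']$ with $i' \le i$ and $j' \le j$ is also at most $x$. The bottom-left walk traces the boundary of this region, moving only up or right, so it visits at most $m + n$ cells.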
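The staircase claim can be checked directly. Each row is sorted, so the entries at most $x$ in row $i$ form a prefix; each column is sorted, so those prefix lengths cannot grow from one row to the next. This minimal sketch uses the standard library's `bisect.bisect_right` for the per-row prefix lengths; `row_boundaries` and `staircase_count` are hypothetical names, and the matrix is illustrative.

```python
from bisect import bisect_right


def row_boundaries(matrix, x):
    """Per-row count of entries <= x; each is a prefix length because rows are sorted."""
    return [bisect_right(row, x) for row in matrix]


def staircase_count(matrix, x):
    """Bottom-left corner walk: O(m + n) steps."""
    row, col, count = len(matrix) - 1, 0, 0
    while row >= 0 and col < len(matrix[0]):
        if matrix[row][col] <= x:
            count += row + 1  # column is sorted, so rows 0..row here are all <= x
            col += 1
        else:
            row -= 1
    return count


A = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
for x in range(0, 17):
    b = row_boundaries(A, x)
    # Sorted columns force the prefix lengths to shrink (or stay equal) going down.
    assert all(b[i] >= b[i + 1] for i in range(len(b) - 1))
    assert sum(b) == staircase_count(A, x)
print(row_boundaries(A, 12))
```

Summing the per-row `bisect_right` results also gives a correct count, in $O(m \log n)$ time instead of $O(m + n)$.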
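If fewer than $k$ values are at most $x$, then the answer must be greater than $x$. Otherwise, the answer is at most $x$. Binary search maintains the invariant that the answer lies in $[\text{low}, \text{high}]$ and stops when $\text{low} = \text{high}$, which is the smallest $x$ with at least $k$ entries at most $x$. For integer entries, that value is always a matrix entry: if $x$ were not an entry, then $x - 1$ would have the same count, contradicting minimality.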
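A randomized cross-check exercises both the invariant and the "answer is an entry" argument against a brute-force sort. This is a sketch under stated assumptions: integer entries, and a hypothetical generator `random_sorted_matrix` that sets each cell to the larger of its upper and left neighbours plus a step in 0..3 (a step of 0 creates duplicates), starting from a possibly negative value to exercise floor division.

```python
import random


def count_less_equal(matrix, x):
    # Staircase walk from the bottom-left corner.
    row, col, count = len(matrix) - 1, 0, 0
    while row >= 0 and col < len(matrix[0]):
        if matrix[row][col] <= x:
            count += row + 1
            col += 1
        else:
            row -= 1
    return count


def selection_in_sorted_matrix(matrix, k):
    low, high = matrix[0][0], matrix[-1][-1]
    while low < high:
        mid = (low + high) // 2
        if count_less_equal(matrix, mid) < k:
            low = mid + 1
        else:
            high = mid
    return low


def random_sorted_matrix(m, n, rng):
    """Hypothetical generator: each cell is max(up, left) plus a step in 0..3."""
    A = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            if i == 0 and j == 0:
                base = rng.randint(-10, 10)  # negative starts exercise floor division
            elif i == 0:
                base = A[i][j - 1]
            elif j == 0:
                base = A[i - 1][j]
            else:
                base = max(A[i - 1][j], A[i][j - 1])
            A[i][j] = base + rng.randint(0, 3)  # a step of 0 creates duplicates
    return A


def check(trials=300, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        m, n = rng.randint(1, 5), rng.randint(1, 5)
        A = random_sorted_matrix(m, n, rng)
        flat = sorted(v for row in A for v in row)
        for k in range(1, m * n + 1):
            ans = selection_in_sorted_matrix(A, k)
            assert ans == flat[k - 1]                # matches the definition of rank k
            assert count_less_equal(A, ans) >= k     # ans is feasible
            assert count_less_equal(A, ans - 1) < k  # ans is the smallest feasible value
            assert ans in flat                       # ans is an actual entry
    return True


print(check())
```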
Complexity
| part | cost |
|---|---|
| counting pass | $O(m + n)$ |
| value binary search | $O(\log R)$ iterations |
| total time | $O((m + n)\log R)$ |
| space | $O(1)$ |

Here $R = A[m-1][n-1] - A[0][0]$ is the numeric value range.
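Both factors of the total can be observed by instrumenting the search. This sketch uses hypothetical helpers `count_with_steps` and `select_with_stats` on an illustrative matrix. It checks that each counting pass inspects at most $m + n$ cells, since every step moves up one row or right one column, and that the number of probes stays within $\lceil \log_2(R + 1) \rceil$, where $R = A[m-1][n-1] - A[0][0]$, since each probe shrinks the candidate interval to at most half its size, rounded up.

```python
import math


def count_with_steps(matrix, x):
    """Staircase count that also reports how many cells it inspected."""
    row, col, count, steps = len(matrix) - 1, 0, 0, 0
    while row >= 0 and col < len(matrix[0]):
        steps += 1
        if matrix[row][col] <= x:
            count += row + 1
            col += 1
        else:
            row -= 1
    return count, steps


def select_with_stats(matrix, k):
    """Value binary search that counts probes and the longest single counting pass."""
    low, high = matrix[0][0], matrix[-1][-1]
    probes, max_steps = 0, 0
    while low < high:
        mid = (low + high) // 2
        c, s = count_with_steps(matrix, mid)
        probes += 1
        max_steps = max(max_steps, s)
        if c < k:
            low = mid + 1
        else:
            high = mid
    return low, probes, max_steps


A = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
m, n = len(A), len(A[0])
span = A[-1][-1] - A[0][0] + 1  # number of candidate values, R + 1
for k in range(1, m * n + 1):
    _, p, s = select_with_stats(A, k)
    assert s <= m + n                         # one counting pass is O(m + n)
    assert p <= math.ceil(math.log2(span))    # O(log R) probes
print(select_with_stats(A, 8))
```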
When to Use
Use this method when:
- rows and columns are sorted
- values are integers within a bounded range
- duplicates are allowed
- you need the k-th value, not the sorted list
For very large value ranges or non-integer keys, a heap-based k-way merge of the rows may be preferable.
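The heap-based alternative is a different technique from the value binary search: it merges the sorted rows with Python's standard `heapq` module and pops the $k - 1$ smallest entries. This is a minimal sketch; `kth_smallest_heap` is a hypothetical name. Only the first $\min(k, m)$ rows are seeded, because sorted columns guarantee that rows below row $k - 1$ cannot hold anything needed to reach rank $k$.

```python
import heapq


def kth_smallest_heap(matrix, k):
    """k-th smallest (1-based) by merging rows with a min-heap.

    Uses only comparisons, so float or other ordered keys work.
    Time O(k log min(k, m)), extra space O(min(k, m)).
    """
    m, n = len(matrix), len(matrix[0])
    if not 1 <= k <= m * n:
        raise ValueError("k must satisfy 1 <= k <= m * n")
    # Seed with the first entry of each of the first min(m, k) rows.
    heap = [(matrix[i][0], i, 0) for i in range(min(m, k))]
    heapq.heapify(heap)
    for _ in range(k - 1):
        _, i, j = heapq.heappop(heap)  # remove the current smallest entry
        if j + 1 < n:
            heapq.heappush(heap, (matrix[i][j + 1], i, j + 1))  # advance in that row
    return heap[0][0]  # the k-th smallest is now at the top of the heap


F = [[0.5, 1.25, 2.0], [1.0, 1.5, 3.75]]  # illustrative float matrix
print(kth_smallest_heap(F, 4))
```

The cost depends on $k$ rather than on the value range, which suits huge ranges or non-integer keys. When $k$ is close to $mn$ and the range is small, the value binary search is usually cheaper.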
Implementation
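```python
def count_less_equal(matrix, x):
    """Count entries <= x with a staircase walk from the bottom-left corner."""
    m = len(matrix)
    n = len(matrix[0])
    row = m - 1
    col = 0
    count = 0
    while row >= 0 and col < n:
        if matrix[row][col] <= x:
            # Column is sorted, so rows 0..row in this column are all <= x.
            count += row + 1
            col += 1
        else:
            # Entry is too large; move up toward smaller values.
            row -= 1
    return count


def selection_in_sorted_matrix(matrix, k):
    """Return the k-th smallest value (1-based) of a row- and column-sorted integer matrix."""
    low = matrix[0][0]
    high = matrix[-1][-1]
    while low < high:
        mid = (low + high) // 2  # floor division stays correct for negative values
        if count_less_equal(matrix, mid) < k:
            low = mid + 1  # fewer than k entries <= mid: the answer is larger
        else:
            high = mid     # at least k entries <= mid: the answer is mid or smaller
    return low
```

```go
package sortedmatrix // package name is illustrative

// countLessEqual counts entries <= x by walking from the bottom-left corner.
func countLessEqual(matrix [][]int, x int) int {
	m := len(matrix)
	n := len(matrix[0])
	row := m - 1
	col := 0
	count := 0
	for row >= 0 && col < n {
		if matrix[row][col] <= x {
			// Column is sorted, so rows 0..row in this column are all <= x.
			count += row + 1
			col++
		} else {
			row--
		}
	}
	return count
}

// SelectionInSortedMatrix returns the k-th smallest value (1-based rank).
func SelectionInSortedMatrix(matrix [][]int, k int) int {
	low := matrix[0][0]
	high := matrix[len(matrix)-1][len(matrix[0])-1]
	for low < high {
		// Midpoint without forming low+high; high > low, so the division rounds down.
		mid := low + (high-low)/2
		if countLessEqual(matrix, mid) < k {
			low = mid + 1
		} else {
			high = mid
		}
	}
	return low
}
```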