A clear explanation of finding the kth smallest value in a row-sorted and column-sorted matrix using binary search on values.
Problem Restatement
We are given an n x n matrix.
Each row is sorted in non-decreasing order.
Each column is also sorted in non-decreasing order.
We need to return the kth smallest element in the whole matrix. Duplicates count separately, so this means the kth element in the fully sorted list, not the kth distinct value. The problem also asks for memory better than O(n^2).
Input and Output
| Item | Meaning |
|---|---|
| Input | A sorted n x n matrix and an integer k |
| Output | The kth smallest value |
| Matrix property | Rows and columns are sorted |
| Duplicate values | Count separately |
| Constraint | 1 <= n <= 300 |
| Constraint | -10^9 <= matrix[i][j] <= 10^9 |
| Constraint | 1 <= k <= n^2 |
Example function shape:
```python
def kthSmallest(matrix: list[list[int]], k: int) -> int:
    ...
```
Examples
Example 1:
```python
matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8
```
If we flatten and sort the matrix, we get:
```python
[1, 5, 9, 10, 11, 12, 13, 13, 15]
```
The 8th smallest value is 13.
Example 2:
```python
matrix = [[-5]]
k = 1
```
There is only one value, so the answer is -5.
First Thought: Flatten and Sort
The simplest idea is to put every value into one array, sort it, and return index k - 1.
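```python
class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        # Copy every value into one flat list.
        values = []
        for row in matrix:
            for x in row:
                values.append(x)
        # Sort ascending; duplicates keep their own positions.
        values.sort()
        # The kth smallest value sits at 0-based index k - 1.
        return values[k - 1]
```
This works, but it takes O(n^2 log n) time and stores all n^2 values.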
The problem asks for memory better than O(n^2), so we should use the sorted structure of the matrix.
Key Insight
The matrix is sorted by rows and columns.
Instead of binary searching over positions, we binary search over values.
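The smallest possible answer is matrix[0][0]. Every cell can be reached from the top-left corner by moving right or down, and those moves never decrease the value.
The largest possible answer is matrix[n - 1][n - 1]. The bottom-right corner can be reached from every cell by moving right or down.
For any value mid, we can count how many numbers in the matrix are less than or equal to mid.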
If at least k numbers are <= mid, then the answer is at most mid.
If fewer than k numbers are <= mid, then the answer is larger than mid.
So the problem becomes:
Find the smallest value x such that at least k matrix elements are <= x.
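The following brute-force sketch makes the reformulation concrete on Example 1. The helper name `count_le_brute` is hypothetical. It scans every cell, which is fine for illustration. The count stays below k at 12 and first reaches k at 13.

```python
def count_le_brute(matrix: list[list[int]], x: int) -> int:
    # Count every cell <= x by scanning the whole matrix (O(n^2), illustration only).
    return sum(1 for row in matrix for value in row if value <= x)


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8

# count(12) is 6, fewer than k, so the answer is larger than 12.
# count(13) is 8, which reaches k, so 13 is the smallest x with count(x) >= k.
print(count_le_brute(example, 12), count_le_brute(example, 13))  # 6 8
```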
Counting Values Less Than or Equal to mid
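We can count in O(n) time by starting at the bottom-left corner:
```python
row = n - 1
col = 0
```
At each step:
- If matrix[row][col] <= mid, then every value above it in the same column is also <= mid, because the column is sorted. That column segment holds row + 1 values, so we add row + 1 to the count and move right.
- Otherwise, the current value is too large. Every value to its right in the same row is at least as large. The values to its left were already counted when we passed those columns. This row is therefore finished, and we move up.
Each step moves either up or right, so the walk takes at most 2n - 1 steps before it leaves the matrix.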
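The sketch below instruments the same counting rule so the path is visible. The helper name `count_with_trace` and its move log format are hypothetical. It is not a separate algorithm; it is the staircase walk with a record of each move. For mid = 12 on Example 1, the walk makes five moves and counts 6 values.

```python
def count_with_trace(matrix: list[list[int]], x: int) -> tuple[int, list[str]]:
    """Count cells <= x with the bottom-left walk and record every move."""
    n = len(matrix)
    row, col = n - 1, 0
    count = 0
    moves = []
    while row >= 0 and col < n:
        value = matrix[row][col]
        if value <= x:
            # The whole column segment matrix[0..row][col] is <= x.
            count += row + 1
            moves.append(f"({row},{col})={value} <= {x}: add {row + 1}, move right")
            col += 1
        else:
            # Cells to the right are >= value > x; cells to the left were
            # already counted, so this row is done.
            moves.append(f"({row},{col})={value} > {x}: move up")
            row -= 1
    return count, moves


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
total, moves = count_with_trace(example, 12)
for move in moves:
    print(move)
print("count =", total)
```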
Algorithm
Set:
```python
left = matrix[0][0]
right = matrix[n - 1][n - 1]
```
Then binary search while left < right.
For each mid:
- Count how many values in the matrix are <= mid.
- If the count is at least k, move right to mid.
- Otherwise, move left to mid + 1.
At the end, left is the smallest value with at least k values less than or equal to it.
That value is the answer.
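The sketch below runs these steps on Example 1 and logs every iteration, so you can watch the search converge. The function name `kth_smallest_traced` and the `(left, right, mid, count)` log tuples are illustrative only. The midpoints visited are 8, 12, 14, and 13, and the search stops at 13.

```python
def kth_smallest_traced(matrix: list[list[int]], k: int) -> tuple[int, list[tuple[int, int, int, int]]]:
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        # Bottom-left staircase walk, O(n).
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    log = []  # (left, right, mid, count) for each iteration
    while left < right:
        mid = (left + right) // 2
        c = count_less_equal(mid)
        log.append((left, right, mid, c))
        if c >= k:
            right = mid      # mid may be the answer, so keep it
        else:
            left = mid + 1   # mid is too small, so discard it
    return left, log


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
answer, log = kth_smallest_traced(example, 8)
for lo, hi, mid, c in log:
    print(f"left={lo:>2} right={hi:>2} mid={mid:>2} count={c}")
print("answer =", answer)
```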
Correctness
Define count(x) as the number of matrix elements less than or equal to x.
As x increases, count(x) never decreases.
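So the condition count(x) >= k is monotonic: once it becomes true, it stays true for every larger x.
For x below matrix[0][0], count(x) = 0, so the condition is false.
For x = matrix[n - 1][n - 1], count(x) = n^2 >= k, so the condition is true.
The kth smallest value v is exactly the first value where this condition becomes true. In the fully sorted list, v sits at position k, so at least k elements are <= v and count(v) >= k. Every element strictly smaller than v must appear before position k, so at most k - 1 elements are < v. Since the values are integers, those are exactly the elements <= v - 1, so count(v - 1) < k.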
The binary search keeps this invariant:
| Bound | Meaning |
|---|---|
| left | The answer is at least this value |
| right | The answer is at most this value |
When count(mid) >= k, there are already enough values less than or equal to mid, so the answer is mid or smaller.
When count(mid) < k, not enough values are less than or equal to mid, so the answer must be larger.
When the search ends, left == right, and this value is the smallest value whose count is at least k.
Therefore it is the kth smallest matrix element.
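The correctness argument can also be checked mechanically on small inputs. This sketch uses hypothetical helper names and a brute-force count instead of the staircase walk. For every k, it verifies two things over the whole value range: that count(x) never decreases, and that the first x with count(x) >= k equals the kth entry of the sorted list.

```python
def count_le(matrix: list[list[int]], x: int) -> int:
    # Brute-force count of cells <= x (illustration only).
    return sum(1 for row in matrix for value in row if value <= x)


def check_first_true_is_kth(matrix: list[list[int]]) -> bool:
    flat = sorted(value for row in matrix for value in row)
    lo, hi = flat[0], flat[-1]
    counts = [count_le(matrix, x) for x in range(lo, hi + 1)]
    # Monotonicity: count(x) never decreases as x grows.
    if any(a > b for a, b in zip(counts, counts[1:])):
        return False
    for k in range(1, len(flat) + 1):
        # First x in [lo, hi] where count(x) >= k; it exists because count(hi) = n^2.
        first = next(x for x, c in zip(range(lo, hi + 1), counts) if c >= k)
        if first != flat[k - 1]:
            return False
    return True


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(check_first_true_is_kth(example))  # True
```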
Complexity
Let n be the matrix size.
Let R = matrix[n - 1][n - 1] - matrix[0][0].
| Metric | Value | Why |
|---|---|---|
| Time | O(n log R) | O(log R) binary search steps, each counting in O(n) |
| Space | O(1) | Only a few variables are used |
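Here is a rough worked bound under the stated constraints (values in [-10^9, 10^9], n <= 300). The sketch simulates the worst-case number of halvings and multiplies it by the longest possible staircase walk. The function names are hypothetical. The numbers are upper bounds derived from the loop structure, not measurements.

```python
def worst_case_iterations(lo: int, hi: int) -> int:
    # Each iteration keeps either [left, mid] or [mid + 1, right].
    # The larger of the two holds ceil(size / 2) values, so simulate that.
    size = hi - lo + 1
    iterations = 0
    while size > 1:
        size = (size + 1) // 2
        iterations += 1
    return iterations


def max_walk_steps(n: int) -> int:
    # Each step moves up or right. The walk ends once row passes 0 or col passes
    # n - 1, so it makes at most (n - 1) moves one way plus n the other: 2n - 1.
    return 2 * n - 1


iters = worst_case_iterations(-10**9, 10**9)
steps = max_walk_steps(300)
print(iters, steps, iters * steps)  # 31 599 18569
```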
Implementation
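```python
class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        n = len(matrix)

        def count_less_equal(x: int) -> int:
            # Walk from the bottom-left corner, counting cells <= x.
            row = n - 1
            col = 0
            count = 0
            while row >= 0 and col < n:
                if matrix[row][col] <= x:
                    # matrix[0..row][col] are all <= x.
                    count += row + 1
                    col += 1
                else:
                    # matrix[row][col] > x, so this row contributes nothing more.
                    row -= 1
            return count

        # Binary search over values, not indices.
        left = matrix[0][0]
        right = matrix[n - 1][n - 1]
        while left < right:
            # Floor division keeps left <= mid < right, even for negative values.
            mid = (left + right) // 2
            if count_less_equal(mid) >= k:
                right = mid  # mid is still a candidate
            else:
                left = mid + 1  # the answer is larger than mid
        return left
```
Code Explanation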
The helper function counts how many values are less than or equal to x.
```python
def count_less_equal(x: int) -> int:
```
We start from the bottom-left corner:
```python
row = n - 1
col = 0
```
If the current value is small enough:
```python
if matrix[row][col] <= x:
```
then all values above it in the same column are also small enough.
So we add:
```python
count += row + 1
```
and move to the next column:
```python
col += 1
```
If the current value is too large, we move upward:
```python
row -= 1
```
The binary search checks values, not indices:
```python
while left < right:
```
If there are at least k values less than or equal to mid, we keep the left half, including mid itself:
```python
right = mid
```
Otherwise, the answer must be larger than mid:
```python
left = mid + 1
```
Finally, when left == right, that value is the answer:
```python
return left
```
Testing
```python
def test_solution():
    s = Solution()
    assert s.kthSmallest(
        [
            [1, 5, 9],
            [10, 11, 13],
            [12, 13, 15],
        ],
        8,
    ) == 13
    assert s.kthSmallest([[-5]], 1) == -5
    assert s.kthSmallest(
        [
            [1, 2],
            [1, 3],
        ],
        2,
    ) == 1
    assert s.kthSmallest(
        [
            [1, 2],
            [3, 4],
        ],
        4,
    ) == 4
    assert s.kthSmallest(
        [
            [-5, -4],
            [-3, -1],
        ],
        3,
    ) == -3
    print("all tests passed")


test_solution()
```
Test meaning:
| Test | Why |
|---|---|
| Main 3 x 3 example | Checks that the duplicate 13 is counted separately |
| Single element | Minimum matrix size |
| Duplicate values | Confirms kth means sorted order, not distinct order |
| k = n^2 | Returns the largest value |
| Negative values | Confirms binary search works across negative ranges |
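Beyond the hand-picked cases, a randomized comparison against the flatten-and-sort baseline can catch off-by-one mistakes. This is a sketch with hypothetical helper names. It restates the binary-search solution as a plain function so the block runs on its own. The generator is an assumption too: each cell is built from its top and left neighbours, which guarantees sorted rows and columns with many duplicates and negative values.

```python
import random


def kth_smallest(matrix: list[list[int]], k: int) -> int:
    # Same binary search on values as the Solution class, as a plain function.
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        # Bottom-left staircase walk.
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) // 2
        if count_less_equal(mid) >= k:
            right = mid
        else:
            left = mid + 1
    return left


def random_sorted_matrix(n: int, rng: random.Random) -> list[list[int]]:
    # Each cell is its larger neighbour (top or left) plus 0..2, so rows and
    # columns are non-decreasing; the small steps produce many duplicates.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            neighbours = []
            if i > 0:
                neighbours.append(m[i - 1][j])
            if j > 0:
                neighbours.append(m[i][j - 1])
            base = max(neighbours) if neighbours else rng.randint(-10, 0)
            m[i][j] = base + rng.randint(0, 2)
    return m


def cross_check(trials: int = 500, seed: int = 0) -> int:
    # Compare every k against flatten-and-sort on random small matrices.
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        matrix = random_sorted_matrix(n, rng)
        flat = sorted(value for row in matrix for value in row)
        for k in range(1, n * n + 1):
            assert kth_smallest(matrix, k) == flat[k - 1], (matrix, k)
    return trials


print(cross_check(), "random matrices checked")
```

A fixed seed keeps any failure reproducible. The assertion message carries the offending matrix and k.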