
LeetCode 378: Kth Smallest Element in a Sorted Matrix

A clear explanation of finding the kth smallest value in a row-sorted and column-sorted matrix using binary search on values.

Problem Restatement

We are given an n x n matrix.

Each row is sorted in non-decreasing order.

Each column is also sorted in non-decreasing order.

We need to return the kth smallest element in the whole matrix. Duplicates count separately, so this means the kth element in the fully sorted list, not the kth distinct value. The problem also asks for memory better than O(n^2).

Input and Output

Item | Meaning
Input | A sorted n x n matrix and an integer k
Output | The kth smallest value
Matrix property | Rows and columns are sorted in non-decreasing order
Duplicate values | Counted separately
Constraint | 1 <= n <= 300
Constraint | -10^9 <= matrix[i][j] <= 10^9
Constraint | 1 <= k <= n^2

Example function shape:

def kthSmallest(matrix: list[list[int]], k: int) -> int:
    ...

Examples

Example 1:

matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8

If we flatten and sort the matrix, we get:

[1, 5, 9, 10, 11, 12, 13, 13, 15]

The 8th smallest value is:

13
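The short Python snippet below shows why duplicates matter. It computes both possible readings of "kth smallest" for Example 1. The variable names are illustrative. Only the first result matches the problem's definition.

```python
# Example 1 from above.
matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8

# Flatten and sort, keeping duplicates: the kth smallest sits at index k - 1.
all_values = sorted(x for row in matrix for x in row)
kth_with_duplicates = all_values[k - 1]

# For contrast only: the kth *distinct* value, which the problem does NOT ask for.
distinct_values = sorted(set(all_values))
kth_distinct = distinct_values[k - 1]

print(kth_with_duplicates)  # 13
print(kth_distinct)         # 15
```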

Example 2:

matrix = [[-5]]
k = 1

There is only one value, so the answer is:

-5

First Thought: Flatten and Sort

The simplest idea is to put every value into one array, sort it, and return index k - 1.

class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        values = []

        for row in matrix:
            for x in row:
                values.append(x)

        values.sort()
        return values[k - 1]

This works in O(n^2 log n) time, but it stores all n^2 values.

The problem asks for memory better than O(n^2), so we should use the sorted structure of the matrix.

Key Insight

The matrix is sorted by rows and columns.

Instead of binary searching over positions, we binary search over values.

The smallest possible answer is:

matrix[0][0]

The largest possible answer is:

matrix[n - 1][n - 1]

For any value mid, we can count how many numbers in the matrix are less than or equal to mid.

If at least k numbers are <= mid, then the answer is at most mid.

If fewer than k numbers are <= mid, then the answer is larger than mid.

So the problem becomes:

Find the smallest value x such that at least k matrix elements are <= x.
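The reduction can be sketched on its own, before any clever counting is involved. The sketch below assumes a hypothetical generic helper `first_true` that returns the smallest x in a range where a monotone predicate holds. It also uses a brute-force O(n^2) `count_brute`, also hypothetical, so the reduction is visible in isolation. The bottom-left counting described next would replace `count_brute`.

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi] with pred(x) True.

    Assumes pred is monotone (False ... False True ... True) and pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid      # mid satisfies the condition, so the answer is mid or smaller
        else:
            lo = mid + 1  # mid fails, so the answer is larger
    return lo


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8


def count_brute(x: int) -> int:
    # O(n^2) reference count of elements <= x, for illustration only.
    return sum(1 for row in matrix for v in row if v <= x)


answer = first_true(matrix[0][0], matrix[-1][-1], lambda x: count_brute(x) >= k)
print(answer)  # 13
```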

Counting Values Less Than or Equal to mid

We can count in O(n) time by walking through the matrix from the bottom-left corner.

Start at:

row = n - 1
col = 0

At each step:

If:

matrix[row][col] <= mid

then every value above it in the same column is also <= mid, because the column is sorted.

So we add:

row + 1

and move right.

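Otherwise, the current value is greater than mid. Every value to its right in the same row is at least as large, so no remaining value in this row can be counted, and we move up.

Each step moves up one row or right one column, so the walk ends after at most 2n - 1 steps.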
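Here is a minimal Python sketch of this walk. It also records the cells it visits, so the staircase shape is visible. The function name `count_with_path` is hypothetical. The counting logic follows the rules above.

```python
def count_with_path(matrix: list[list[int]], mid: int) -> tuple[int, list[tuple[int, int]]]:
    """Count elements <= mid with the bottom-left walk, recording visited cells."""
    n = len(matrix)
    row, col = n - 1, 0
    count = 0
    path = []
    while row >= 0 and col < n:
        path.append((row, col))
        if matrix[row][col] <= mid:
            # matrix[0..row][col] are all <= mid because the column is sorted.
            count += row + 1
            col += 1
        else:
            # matrix[row][col..n-1] are all > mid because the row is sorted.
            row -= 1
    return count, path


matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
count, path = count_with_path(matrix, 12)
print(count)  # 6
print(path)   # [(2, 0), (2, 1), (1, 1), (1, 2), (0, 2)]
```

For mid = 12 the walk visits 5 cells and finds the 6 values 1, 5, 9, 10, 11, and 12.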

Algorithm

Set:

left = matrix[0][0]
right = matrix[n - 1][n - 1]

Then binary search while left < right.

For each mid:

  1. Count how many values in the matrix are <= mid.
  2. If the count is at least k, move right to mid.
  3. Otherwise, move left to mid + 1.

At the end, left is the smallest value x for which at least k matrix elements are less than or equal to x.

That value is the answer.
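Tracing the search on Example 1 (k = 8) makes these steps concrete. The sketch below is a standalone variant of the algorithm. It logs (left, right, mid, count) before each update. The function name `kth_smallest_trace` is hypothetical.

```python
def kth_smallest_trace(matrix: list[list[int]], k: int) -> tuple[int, list[tuple[int, int, int, int]]]:
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        # Bottom-left walk: count elements <= x in O(n).
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    steps = []  # (left, right, mid, count) recorded before each update
    while left < right:
        mid = (left + right) // 2
        c = count_less_equal(mid)
        steps.append((left, right, mid, c))
        if c >= k:
            right = mid
        else:
            left = mid + 1
    return left, steps


answer, steps = kth_smallest_trace([[1, 5, 9], [10, 11, 13], [12, 13, 15]], 8)
for step in steps:
    print(step)
# (1, 15, 8, 2)
# (9, 15, 12, 6)
# (13, 15, 14, 8)
# (13, 14, 13, 8)
print(answer)  # 13
```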

Correctness

Define count(x) as the number of matrix elements less than or equal to x.

As x increases, count(x) never decreases.

So the condition:

count(x) >= k

is monotonic.

For small values, the condition may be false.

For large values, the condition becomes true.

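The kth smallest value is exactly the first value where this condition becomes true. To see why, let v be the kth smallest element. At least k elements are <= v, so count(v) >= k. Every element <= v - 1 is strictly smaller than v, and at most k - 1 elements are strictly smaller than the kth smallest, so count(v - 1) < k. Because the values are integers, v is the smallest integer x with count(x) >= k. This also means the search can never stop at an integer that does not appear in the matrix.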

The binary search keeps this invariant:

Bound | Meaning
left | The answer is never smaller than this value
right | The answer is never larger than this value

When count(mid) >= k, there are already enough values less than or equal to mid, so the answer is mid or smaller.

When count(mid) < k, not enough values are less than or equal to mid, so the answer must be larger.

When the search ends, left == right, and this value is the smallest value whose count is at least k.

Therefore it is the kth smallest matrix element.
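The key step of this argument is that the kth smallest element v satisfies count(v - 1) < k <= count(v). A brute-force check on Example 1 can confirm this for every k. The sketch assumes integer entries, as the constraints guarantee, so v - 1 is the next candidate value below v.

```python
matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
values = sorted(x for row in matrix for x in row)


def count(x: int) -> int:
    # Brute-force count(x): how many matrix elements are <= x.
    return sum(1 for v in values if v <= x)


for k in range(1, len(values) + 1):
    v = values[k - 1]        # the kth smallest element, duplicates included
    assert count(v) >= k     # the condition already holds at v
    assert count(v - 1) < k  # but it fails at the integer just below v

print("count(v - 1) < k <= count(v) for every k")
```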

Complexity

Let n be the number of rows (and columns) of the matrix.

Let R = matrix[n - 1][n - 1] - matrix[0][0].

Metric | Value | Why
Time | O(n log R) | Each of the O(log R) binary search steps counts in O(n)
Space | O(1) | Only a few variables are used
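The bound can be made concrete with the largest inputs the constraints allow. The arithmetic below computes a worst-case upper bound, not a measured count. It relies on two facts: each iteration keeps at most half (rounded up) of the candidate values, and each bottom-left walk visits at most 2n - 1 cells.

```python
n = 300                   # largest n allowed by the constraints
R = 10**9 - (-10**9)      # widest possible value range: 2 * 10^9

# The search interval holds R + 1 integers and each iteration keeps at most
# ceil(size / 2) of them, so there are at most ceil(log2(R + 1)) iterations,
# which equals R.bit_length() for R >= 1.
max_iterations = R.bit_length()

# Each bottom-left walk moves up or right, so it visits at most 2n - 1 cells.
max_cells_per_count = 2 * n - 1

print(max_iterations)                        # 31
print(max_cells_per_count)                   # 599
print(max_iterations * max_cells_per_count)  # 18569
```

So even in the worst case the algorithm reads fewer than 20,000 cells, while flattening stores 90,000 values.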

Implementation

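class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        n = len(matrix)

        def count_less_equal(x: int) -> int:
            # Walk from the bottom-left corner, moving only up or right.
            row = n - 1
            col = 0
            count = 0

            while row >= 0 and col < n:
                if matrix[row][col] <= x:
                    # This cell and everything above it in the column are <= x.
                    count += row + 1
                    col += 1
                else:
                    # This cell and everything to its right in the row are > x.
                    row -= 1

            return count

        # The answer lies between the smallest and largest matrix values.
        left = matrix[0][0]
        right = matrix[n - 1][n - 1]

        while left < right:
            # Floor division keeps left <= mid < right, even for negative values.
            mid = (left + right) // 2

            if count_less_equal(mid) >= k:
                right = mid
            else:
                left = mid + 1

        return left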

Code Explanation

The helper function counts how many values are less than or equal to x.

def count_less_equal(x: int) -> int:

We start from the bottom-left corner:

row = n - 1
col = 0

If the current value is small enough:

if matrix[row][col] <= x:

then all values above it in the same column are also small enough.

So we add:

count += row + 1

and move to the next column:

col += 1

If the current value is too large, we move upward:

row -= 1

The binary search checks values, not indices:

while left < right:

If there are at least k values less than or equal to mid, we keep the left half:

right = mid

Otherwise, the answer must be larger:

left = mid + 1

Finally:

return left

Testing

def test_solution():
    s = Solution()

    assert s.kthSmallest(
        [
            [1, 5, 9],
            [10, 11, 13],
            [12, 13, 15],
        ],
        8,
    ) == 13

    assert s.kthSmallest([[-5]], 1) == -5

    assert s.kthSmallest(
        [
            [1, 2],
            [1, 3],
        ],
        2,
    ) == 1

    assert s.kthSmallest(
        [
            [1, 2],
            [3, 4],
        ],
        4,
    ) == 4

    assert s.kthSmallest(
        [
            [-5, -4],
            [-3, -1],
        ],
        3,
    ) == -3

    print("all tests passed")

test_solution()

Test meaning:

Test | Why
Main 3 x 3 example | Checks duplicate 13 is counted separately
Single element | Minimum matrix size
Duplicate values | Confirms kth means sorted order, not distinct order
k = n^2 | Returns the largest value
Negative values | Confirms binary search works across negative ranges
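The hand-written cases above can be complemented with a randomized cross-check against the flatten-and-sort baseline. The sketch below restates the implementation as a standalone function, `kth_smallest`, so the block runs on its own. The generator `random_sorted_matrix` and the driver `stress_test` are hypothetical helpers. The generator makes each cell at least as large as its top and left neighbours, which keeps rows and columns sorted, and its small random steps make duplicates and negative values common.

```python
import random


def kth_smallest(matrix: list[list[int]], k: int) -> int:
    # Same value-range binary search and bottom-left count as the implementation above.
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) // 2
        if count_less_equal(mid) >= k:
            right = mid
        else:
            left = mid + 1
    return left


def random_sorted_matrix(n: int, rng: random.Random) -> list[list[int]]:
    # Each cell is at least its top and left neighbours, so rows and columns
    # are non-decreasing. Values start near -20 and grow in steps of 0..3.
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            top = m[i - 1][j] if i > 0 else -20
            left = m[i][j - 1] if j > 0 else -20
            m[i][j] = max(top, left) + rng.randint(0, 3)
    return m


def stress_test(trials: int = 500, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        matrix = random_sorted_matrix(n, rng)
        expected = sorted(x for row in matrix for x in row)
        for k in range(1, n * n + 1):
            assert kth_smallest(matrix, k) == expected[k - 1]
    print("stress test passed")


stress_test()
```

Fixing the seed keeps any failure reproducible. Changing the seed or the range of n explores other shapes.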