# LeetCode 378: Kth Smallest Element in a Sorted Matrix

## Problem Restatement

We are given an `n x n` matrix.

Each row is sorted in non-decreasing order.

Each column is also sorted in non-decreasing order.

We need to return the `k`th smallest element in the whole matrix. Duplicates count separately, so this means the `k`th element in the fully sorted list, not the `k`th distinct value. The problem also asks for memory better than `O(n^2)`.

## Input and Output

| Item | Meaning |
|---|---|
| Input | A sorted `n x n` matrix and an integer `k` |
| Output | The `k`th smallest value |
| Matrix property | Rows and columns are sorted |
| Duplicate values | Count separately |
| Constraint | `1 <= n <= 300` |
| Constraint | `-10^9 <= matrix[i][j] <= 10^9` |
| Constraint | `1 <= k <= n^2` |

Example function shape:

```python
def kthSmallest(matrix: list[list[int]], k: int) -> int:
    ...
```

## Examples

Example 1:

```python
matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8
```

If we flatten and sort the matrix, we get:

```python
[1, 5, 9, 10, 11, 12, 13, 13, 15]
```

The 8th smallest value is:

```python
13
```

Example 2:

```python
matrix = [[-5]]
k = 1
```

There is only one value, so the answer is:

```python
-5
```

## First Thought: Flatten and Sort

The simplest idea is to put every value into one array, sort it, and return index `k - 1`.

```python
class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        values = []

        for row in matrix:
            for x in row:
                values.append(x)

        values.sort()
        return values[k - 1]
```

This works, but it takes `O(n^2 log n)` time and stores all `n^2` values.

The problem asks for memory better than `O(n^2)`, so we should use the sorted structure of the matrix.

## Key Insight

The matrix is sorted by rows and columns.

Instead of binary searching over positions, we binary search over values.

The smallest possible answer is:

```python
matrix[0][0]
```

The largest possible answer is:

```python
matrix[n - 1][n - 1]
```

For any value `mid`, we can count how many numbers in the matrix are less than or equal to `mid`.

If at least `k` numbers are `<= mid`, then the answer is at most `mid`.

If fewer than `k` numbers are `<= mid`, then the answer is larger than `mid`.

So the problem becomes:

Find the smallest value `x` such that at least `k` matrix elements are `<= x`.
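As a sanity check of this reformulation, the brute-force sketch below scans every integer between the smallest and largest entries of Example 1 and stops at the first one whose count reaches `k`. The helper name `count_le` is hypothetical. The scan is deliberately naive, so it only illustrates the predicate and is not the intended algorithm.

```python
# Brute-force illustration of the reformulated problem on Example 1.
# It scans every candidate value and every cell, so it ignores the
# time and memory goals on purpose.
matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
k = 8


def count_le(x: int) -> int:
    """Hypothetical helper: number of matrix elements <= x (naive scan)."""
    return sum(1 for row in matrix for v in row if v <= x)


# Candidate answers run from the smallest to the largest element.
lo, hi = matrix[0][0], matrix[-1][-1]
answer = next(x for x in range(lo, hi + 1) if count_le(x) >= k)

print(count_le(12), count_le(13), answer)  # 6 8 13
```

Only six values are `<= 12`, while eight are `<= 13`, so `13` is the first value whose count reaches `k = 8`.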

## Counting Values Less Than or Equal to `mid`

We can count in `O(n)` time using the bottom-left corner.

Start at:

```python
row = n - 1
col = 0
```

At each step:

If:

```python
matrix[row][col] <= mid
```

then every value above it in the same column is also `<= mid`, because the column is sorted.

So we add:

```python
row + 1
```

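and move right. This counts the current entry together with the `row` entries above it, so the whole column is done.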

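Otherwise, the current value is greater than `mid`. So is every value to its right in the same row, and the values to its left were already counted with their columns, so this row is finished and we move up.

Each step moves either up or right, so the walk takes at most `2n - 1` steps. That is where the `O(n)` bound comes from.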
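The staircase walk can be sketched as a standalone function that takes the matrix as a parameter. The `trace` flag is a hypothetical addition for illustration. It prints each move so the path through Example 1 is visible.

```python
def count_less_equal(matrix: list[list[int]], x: int, trace: bool = False) -> int:
    """Count entries <= x by walking from the bottom-left corner."""
    n = len(matrix)
    row, col = n - 1, 0
    count = 0
    while row >= 0 and col < n:
        if matrix[row][col] <= x:
            # This entry and the `row` entries above it are all <= x.
            count += row + 1
            if trace:
                print(f"({row}, {col}) = {matrix[row][col]} <= {x}: count = {count}, move right")
            col += 1
        else:
            # This entry, and everything to its right in the row, is > x.
            if trace:
                print(f"({row}, {col}) = {matrix[row][col]} > {x}: move up")
            row -= 1
    return count


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(count_less_equal(matrix, 13, trace=True))
# (2, 0) = 12 <= 13: count = 3, move right
# (2, 1) = 13 <= 13: count = 6, move right
# (2, 2) = 15 > 13: move up
# (1, 2) = 13 <= 13: count = 8, move right
# 8
```

The walk visits only four cells here, but it still counts all eight values `<= 13`.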

## Algorithm

Set:

```python
left = matrix[0][0]
right = matrix[n - 1][n - 1]
```

Then binary search while `left < right`.

For each `mid`:

1. Count how many values in the matrix are `<= mid`.
2. If the count is at least `k`, move `right` to `mid`.
3. Otherwise, move `left` to `mid + 1`.

At the end, `left` is the smallest value with at least `k` values less than or equal to it.

That value is the answer.
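To make the steps concrete, here is a sketch that runs the same search on Example 1 and prints `mid`, its count, and the updated bounds after each iteration. The function name `kth_smallest_traced` is hypothetical, and the printing exists only for illustration.

```python
def kth_smallest_traced(matrix: list[list[int]], k: int) -> int:
    """Value-space binary search that prints each step (illustration only)."""
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        # Staircase walk from the bottom-left corner.
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) // 2
        c = count_less_equal(mid)
        if c >= k:
            right = mid  # mid might be the answer, so keep it in range
        else:
            left = mid + 1  # too few values <= mid, so the answer is larger
        print(f"mid={mid}, count={c} -> left={left}, right={right}")
    return left


example = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_traced(example, 8))
# mid=8, count=2 -> left=9, right=15
# mid=12, count=6 -> left=13, right=15
# mid=14, count=8 -> left=13, right=14
# mid=13, count=8 -> left=13, right=13
# 13
```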

## Correctness

Define `count(x)` as the number of matrix elements less than or equal to `x`.

As `x` increases, `count(x)` never decreases.

So the condition:

```python
count(x) >= k
```

is monotonic.

For any `x` below `matrix[0][0]`, the condition is false, because `count(x) = 0 < k`.

For `x = matrix[n - 1][n - 1]`, the condition is true, because `count(x) = n^2 >= k`.

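The `k`th smallest value `v` is exactly the first value where this condition becomes true. `count(v) >= k` holds because `v` and the `k - 1` values before it in sorted order are all `<= v`. `count(v - 1) < k` holds because only values strictly smaller than `v` are counted, and at most `k - 1` of those exist. In particular, the first such value is always an element of the matrix, even though the search runs over every integer in the range.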

The binary search keeps this invariant:

| Bound | Meaning |
|---|---|
| `left` | The answer is at least `left` |
| `right` | The answer is at most `right` |

When `count(mid) >= k`, there are already enough values less than or equal to `mid`, so the answer is `mid` or smaller.

When `count(mid) < k`, not enough values are less than or equal to `mid`, so the answer must be larger.

When the search ends, `left == right`, and this value is the smallest value whose count is at least `k`.

Therefore it is the `k`th smallest matrix element.
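The invariant can also be checked on a concrete input. The sketch below instruments the search with assertions. The name `check_invariant` is hypothetical. It computes the expected answer by flattening and sorting, which is acceptable for a check but defeats the memory goal, so it is not a replacement for the solution.

```python
def check_invariant(matrix: list[list[int]], k: int) -> int:
    """Run the value-space binary search, asserting the invariant each step.

    Illustration only: `expected` comes from flatten-and-sort, which uses
    O(n^2) memory, and `count` is a naive O(n^2) scan.
    """
    n = len(matrix)
    expected = sorted(v for row in matrix for v in row)[k - 1]

    def count(x: int) -> int:
        return sum(1 for row in matrix for v in row if v <= x)

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        # Invariant: the answer is at least `left` and at most `right`.
        assert left <= expected <= right
        mid = (left + right) // 2
        if count(mid) >= k:
            right = mid
        else:
            left = mid + 1

    assert left == expected
    # First value reaching k: its count is >= k, the previous integer's is not.
    assert count(left) >= k and count(left - 1) < k
    return left


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print([check_invariant(matrix, k) for k in range(1, 10)])
# [1, 5, 9, 10, 11, 12, 13, 13, 15]
```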

## Complexity

Let `n` be the number of rows, which equals the number of columns.

Let `R = matrix[n - 1][n - 1] - matrix[0][0]`.

| Metric | Value | Why |
|---|---|---|
| Time | `O(n log R)` | `O(log R)` binary search steps, each with an `O(n)` count |
| Space | `O(1)` | Only a few variables are used |
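The arithmetic below puts numbers on `log R` under the stated constraints. It assumes the widest allowed range, `-10^9` to `10^9`, and the largest `n = 300`. It uses `int.bit_length` to compute `ceil(log2(size))` exactly, without floating point.

```python
# Widest value range allowed by the constraints: -10^9 .. 10^9.
lo, hi = -10**9, 10**9
size = hi - lo + 1  # number of candidate integers

# Each iteration shrinks the candidate range to at most ceil(size / 2),
# so the loop runs at most ceil(log2(size)) times.
# For size >= 1, (size - 1).bit_length() equals ceil(log2(size)).
max_iterations = (size - 1).bit_length()

n = 300
steps_per_count = 2 * n - 1  # each staircase step moves up or right

print(max_iterations)                    # 31
print(max_iterations * steps_per_count)  # 18569
```

So even in the worst case, the search performs about 31 counts of at most 599 steps each.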

## Implementation

```python
class Solution:
    def kthSmallest(self, matrix: list[list[int]], k: int) -> int:
        n = len(matrix)

        def count_less_equal(x: int) -> int:
            row = n - 1
            col = 0
            count = 0

            while row >= 0 and col < n:
                if matrix[row][col] <= x:
                    count += row + 1
                    col += 1
                else:
                    row -= 1

            return count

        left = matrix[0][0]
        right = matrix[n - 1][n - 1]

        while left < right:
            mid = (left + right) // 2

            if count_less_equal(mid) >= k:
                right = mid
            else:
                left = mid + 1

        return left
```

## Code Explanation

The helper function counts how many values are less than or equal to `x`.

```python
def count_less_equal(x: int) -> int:
```

We start from the bottom-left corner:

```python
row = n - 1
col = 0
```

If the current value is small enough:

```python
if matrix[row][col] <= x:
```

then all values above it in the same column are also small enough.

So we add:

```python
count += row + 1
```

and move to the next column:

```python
col += 1
```

If the current value is too large, we move upward:

```python
row -= 1
```

The binary search checks values, not indices:

```python
while left < right:
```

If there are at least `k` values less than or equal to `mid`, we keep the lower half, including `mid` itself:

```python
right = mid
```

Otherwise, the answer must be larger:

```python
left = mid + 1
```

Finally:

```python
return left
```

## Testing

```python
def test_solution():
    s = Solution()

    assert s.kthSmallest(
        [
            [1, 5, 9],
            [10, 11, 13],
            [12, 13, 15],
        ],
        8,
    ) == 13

    assert s.kthSmallest([[-5]], 1) == -5

    assert s.kthSmallest(
        [
            [1, 2],
            [1, 3],
        ],
        2,
    ) == 1

    assert s.kthSmallest(
        [
            [1, 2],
            [3, 4],
        ],
        4,
    ) == 4

    assert s.kthSmallest(
        [
            [-5, -4],
            [-3, -1],
        ],
        3,
    ) == -3

    print("all tests passed")

test_solution()
```

Test meaning:

| Test | Why |
|---|---|
| Main `3 x 3` example | Checks that the duplicate `13` is counted separately |
| Single element | Minimum matrix size |
| Duplicate values | Confirms that `k`th means position in sorted order, not `k`th distinct value |
| `k = n^2` | Returns the largest value |
| Negative values | Confirms binary search works across negative ranges |
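The hand-picked tests can be complemented by a randomized cross-check against the flatten-and-sort baseline. This sketch is self-contained, so it repeats the solution as a plain function `kth_smallest`. `random_sorted_matrix` and `cross_check` are hypothetical helper names. The generator is one possible way to produce valid inputs with duplicates and negative values: each cell is the max of its top and left neighbours plus a random step in `0..3`.

```python
import random


def kth_smallest(matrix: list[list[int]], k: int) -> int:
    """Value-space binary search with a bottom-left staircase count."""
    n = len(matrix)

    def count_less_equal(x: int) -> int:
        row, col, count = n - 1, 0, 0
        while row >= 0 and col < n:
            if matrix[row][col] <= x:
                count += row + 1
                col += 1
            else:
                row -= 1
        return count

    left, right = matrix[0][0], matrix[n - 1][n - 1]
    while left < right:
        mid = (left + right) // 2
        if count_less_equal(mid) >= k:
            right = mid
        else:
            left = mid + 1
    return left


def random_sorted_matrix(n: int, rng: random.Random) -> list[list[int]]:
    """Build a matrix whose rows and columns are non-decreasing.

    Each cell is the max of its top and left neighbours plus a random
    step in 0..3; zero steps produce duplicates, and the start may be negative.
    """
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == 0 and j == 0:
                base = rng.randint(-20, 0)
            elif i == 0:
                base = m[i][j - 1]
            elif j == 0:
                base = m[i - 1][j]
            else:
                base = max(m[i - 1][j], m[i][j - 1])
            m[i][j] = base + rng.randint(0, 3)
    return m


def cross_check(trials: int = 300, seed: int = 0) -> None:
    """Compare every k against flatten-and-sort on random matrices."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 6)
        matrix = random_sorted_matrix(n, rng)
        flat = sorted(v for row in matrix for v in row)
        for k in range(1, n * n + 1):
            assert kth_smallest(matrix, k) == flat[k - 1], (matrix, k)
    print("random cross-check passed")


cross_check()
```

Fixing the seed keeps any failure reproducible. The assertion message carries the offending matrix and `k`.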

