A clear guide to searching a sorted 2D matrix using binary search over a virtual one-dimensional array.
Problem Restatement
We are given an m x n integer matrix and an integer target.
The matrix has two important properties:
- Each row is sorted in non-decreasing order.
- The first integer of each row is greater than the last integer of the previous row.
We need to return true if target exists in the matrix. Otherwise, return false.
The required time complexity is O(log(m * n)). The official constraints are 1 <= m, n <= 100 and -10^4 <= matrix[i][j], target <= 10^4.
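The solution relies on the two ordering rules and does not check them. When you build test inputs, though, it can help to verify them. The sketch below is a hypothetical helper, is_valid_matrix, which is not part of the problem statement. It assumes every row is non-empty, which the constraint n >= 1 guarantees.

```python
def is_valid_matrix(matrix: list[list[int]]) -> bool:
    """Hypothetical helper: check the two ordering rules of the problem."""
    for r, row in enumerate(matrix):
        # Rule 1: each row is sorted in non-decreasing order.
        if any(row[c] > row[c + 1] for c in range(len(row) - 1)):
            return False
        # Rule 2: a row's first value exceeds the previous row's last value.
        if r > 0 and row[0] <= matrix[r - 1][-1]:
            return False
    return True


print(is_valid_matrix([[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]))  # True
print(is_valid_matrix([[1, 3, 5, 7], [7, 11, 16, 20]]))  # False: 7 is not > 7
```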
Input and Output
| Item | Meaning |
|---|---|
| Input | A sorted 2D matrix and an integer target |
| Output | true if target exists, otherwise false |
| Row rule | Each row is sorted |
| Matrix rule | First value of each row is greater than the last value of the previous row |
| Required time | O(log(m * n)) |
Function shape:
```python
def searchMatrix(matrix: list[list[int]], target: int) -> bool:
    ...
```

Examples
For:

```python
matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]
target = 3
```

The answer is:

```python
True
```

The value 3 appears in the first row.
For:

```python
matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]
target = 13
```

The answer is:

```python
False
```

The value 13 does not appear anywhere in the matrix.
First Thought: Scan Every Cell
The simplest solution is to check every cell.
```python
class Solution:
    def searchMatrix(
        self,
        matrix: list[list[int]],
        target: int,
    ) -> bool:
        # Check every cell, row by row.
        for row in matrix:
            for value in row:
                if value == target:
                    return True
        return False
```

This is correct, but it may inspect every cell, so it takes O(m * n) time. The problem asks for O(log(m * n)), so we need binary search.
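The same linear scan can be written more compactly with any and the in operator. It has the same O(m * n) cost, but a short version like this is obviously correct, which makes it a handy reference when testing a faster solution. The name linear_search is illustrative only.

```python
def linear_search(matrix: list[list[int]], target: int) -> bool:
    # `target in row` scans the row; any() stops at the first row that matches.
    return any(target in row for row in matrix)


matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
print(linear_search(matrix, 3))   # True
print(linear_search(matrix, 13))  # False
```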
Key Insight
The matrix behaves like one sorted array if we read it row by row.
For example:
```python
[
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]
```

is logically the same as:

```python
[1, 3, 5, 7, 10, 11, 16, 20, 23, 30, 34, 60]
```

We do not need to build this flattened array.
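To see the row-major view concretely, the sketch below does build the flattened list, but only to confirm that it is sorted. It uses itertools.chain.from_iterable for the flattening. The solution itself never materializes this list.

```python
from itertools import chain

matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]

# Read the matrix row by row into a single list.
flattened = list(chain.from_iterable(matrix))
print(flattened)  # [1, 3, 5, 7, 10, 11, 16, 20, 23, 30, 34, 60]

# Row-major order is globally sorted, so binary search applies.
print(flattened == sorted(flattened))  # True
```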
We can binary search over virtual indices from 0 to m * n - 1. For a virtual index mid, convert it back to matrix coordinates:

```python
row = mid // n
col = mid % n
```

This works because each row has exactly n columns: the first n indices belong to row 0, the next n to row 1, and so on.
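One way to make the virtual array explicit in Python is a small read-only sequence class whose __getitem__ applies the same mapping via divmod(index, n). In practice, the standard bisect module only needs len() and integer indexing on its input, so bisect_left can search the matrix without copying it. RowMajorView and search_with_bisect are hypothetical names, and this is an alternative sketch rather than the solution presented later.

```python
from bisect import bisect_left


class RowMajorView:
    """Hypothetical read-only view of a matrix as one row-major sequence."""

    def __init__(self, matrix: list[list[int]]) -> None:
        self.matrix = matrix
        self.n = len(matrix[0])

    def __len__(self) -> int:
        return len(self.matrix) * self.n

    def __getitem__(self, index: int) -> int:
        # Same mapping as row = index // n, col = index % n.
        row, col = divmod(index, self.n)
        return self.matrix[row][col]


def search_with_bisect(matrix: list[list[int]], target: int) -> bool:
    view = RowMajorView(matrix)
    # Leftmost position where target could be inserted while keeping order.
    i = bisect_left(view, target)
    return i < len(view) and view[i] == target


matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
print(search_with_bisect(matrix, 3))   # True
print(search_with_bisect(matrix, 13))  # False
```

The lookups still cost O(1) each, so this keeps the O(log(m * n)) bound while delegating the loop to the standard library.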
Index Mapping
Suppose n = 4. Then the virtual indices map like this:
| Virtual index | Matrix coordinate |
|---|---|
| 0 | (0, 0) |
| 1 | (0, 1) |
| 2 | (0, 2) |
| 3 | (0, 3) |
| 4 | (1, 0) |
| 5 | (1, 1) |
| 8 | (2, 0) |
The formula is:

```python
row = index // n
col = index % n
```

So we can run a normal binary search without changing the matrix.
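The mapping is also reversible: the cell (row, col) sits at virtual index row * n + col. The sketch below checks that round trip for every index of a 3 x 4 matrix and rebuilds the mapping shown in the table above. The inverse formula is included only as a sanity check; the search itself only needs the forward direction.

```python
m, n = 3, 4  # the example matrix has 3 rows and 4 columns

for index in range(m * n):
    row, col = index // n, index % n
    # Going back: skip `row` full rows of n cells, then move `col` cells right.
    assert row * n + col == index
    assert 0 <= row < m and 0 <= col < n

mapping = {index: (index // n, index % n) for index in range(m * n)}
print(mapping[5])  # (1, 1)
print(mapping[8])  # (2, 0)
```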
Algorithm
Let:

```python
m = len(matrix)
n = len(matrix[0])
```

Set:

```python
left = 0
right = m * n - 1
```

While left <= right:
- Compute mid.
- Convert mid to (row, col).
- Compare matrix[row][col] with target.
- If equal, return True.
- If smaller than target, search the right half by setting left = mid + 1.
- If larger than target, search the left half by setting right = mid - 1.
If the loop ends, return False.
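To follow these steps on the second example (target = 13), the sketch below runs the same loop but records (left, right, mid, row, col, value) at each step. search_trace is a hypothetical helper for illustration.

```python
def search_trace(matrix: list[list[int]], target: int):
    """Hypothetical helper: run the search and record each step."""
    m, n = len(matrix), len(matrix[0])
    left, right = 0, m * n - 1
    steps = []
    while left <= right:
        mid = left + (right - left) // 2
        row, col = mid // n, mid % n
        value = matrix[row][col]
        steps.append((left, right, mid, row, col, value))
        if value == target:
            return True, steps
        if value < target:
            left = mid + 1   # discard mid and everything before it
        else:
            right = mid - 1  # discard mid and everything after it
    return False, steps


matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
found, steps = search_trace(matrix, 13)
for step in steps:
    print(step)
# (0, 11, 5, 1, 1, 11)
# (6, 11, 8, 2, 0, 23)
# (6, 7, 6, 1, 2, 16)
print(found)  # False
```

The three probes read 11, 23, and 16. After the third, left = 6 and right = 5, so the range is empty and the search returns False.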
Correctness
Because each row is sorted and the first element of each row is greater than the last element of the previous row, reading the matrix from left to right and top to bottom gives one globally sorted sequence.
The algorithm performs binary search on this virtual sorted sequence. For every virtual index, the mapping row = index // n and col = index % n returns exactly the corresponding matrix cell in row-major order.
When the middle value is smaller than target, every value at or before mid in the virtual sequence is also smaller, so the algorithm safely discards mid and the left half by setting left = mid + 1. When the middle value is larger than target, every value at or after mid is also larger, so the algorithm safely discards mid and the right half by setting right = mid - 1.
Each iteration therefore keeps every index that holds target inside [left, right], while the range shrinks by at least one. If target exists, its index stays in range until binary search examines it and returns True. If the search range becomes empty, no possible position remains, so returning False is correct.
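The argument depends on one invariant: every index holding target stays inside [left, right]. The sketch below asserts that invariant on every iteration, using a brute-force list of matching indices. It is a testing aid (checked_search is a hypothetical name), not something to ship, because computing the matches costs O(m * n).

```python
def checked_search(matrix: list[list[int]], target: int) -> bool:
    """Hypothetical testing aid: binary search with an invariant assertion."""
    m, n = len(matrix), len(matrix[0])
    # Brute force: every virtual index whose cell equals target.
    matches = [i for i in range(m * n) if matrix[i // n][i % n] == target]
    left, right = 0, m * n - 1
    while left <= right:
        # Invariant: no occurrence of target has been discarded.
        assert all(left <= i <= right for i in matches)
        mid = left + (right - left) // 2
        value = matrix[mid // n][mid % n]
        if value == target:
            return True
        if value < target:
            left = mid + 1
        else:
            right = mid - 1
    # The range is empty, so by the invariant there were no matches.
    assert not matches
    return False


matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
for t in range(-5, 66):
    assert checked_search(matrix, t) == any(t in row for row in matrix)

# Rows are only non-decreasing, so repeated values inside a row are allowed.
dup = [[1, 2, 2, 2], [5, 6, 7, 9]]
for t in range(0, 11):
    assert checked_search(dup, t) == any(t in row for row in dup)
print("invariant held for every target")
```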
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(log(m * n)) | Binary search over m * n virtual positions |
| Space | O(1) | Only index variables are stored |
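To make O(log(m * n)) concrete: binary search over N positions inspects at most floor(log2(N)) + 1 cells, which is N.bit_length() in Python. At the maximum size of 100 x 100, N = 10,000, so at most 14 cells are read. The sketch below counts probes for every target in a range around a 100 x 100 matrix of the even numbers -10000 to 9998 (an arbitrary valid matrix within the stated value limits, chosen for illustration). count_probes is a hypothetical helper.

```python
def count_probes(matrix: list[list[int]], target: int) -> int:
    """Hypothetical helper: number of cells read by the binary search."""
    m, n = len(matrix), len(matrix[0])
    left, right = 0, m * n - 1
    probes = 0
    while left <= right:
        mid = left + (right - left) // 2
        value = matrix[mid // n][mid % n]
        probes += 1
        if value == target:
            break
        if value < target:
            left = mid + 1
        else:
            right = mid - 1
    return probes


m = n = 100
# A valid 100 x 100 matrix: the even numbers -10000, -9998, ..., 9998 in row-major order.
matrix = [[-10000 + 2 * (r * n + c) for c in range(n)] for r in range(m)]

N = m * n
bound = N.bit_length()  # equals floor(log2(N)) + 1 for N >= 1
# Odd targets and out-of-range targets exercise every unsuccessful path.
worst = max(count_probes(matrix, t) for t in range(-10001, 10000))
print(bound)           # 14
print(worst <= bound)  # True
```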
Implementation
```python
class Solution:
    def searchMatrix(
        self,
        matrix: list[list[int]],
        target: int,
    ) -> bool:
        m = len(matrix)
        n = len(matrix[0])

        # Binary search over the virtual indices 0 .. m * n - 1.
        left = 0
        right = m * n - 1

        while left <= right:
            mid = left + (right - left) // 2

            # Map the virtual index back to a matrix cell.
            row = mid // n
            col = mid % n
            value = matrix[row][col]

            if value == target:
                return True
            if value < target:
                left = mid + 1
            else:
                right = mid - 1

        return False
```

Code Explanation
First read the matrix size:

```python
m = len(matrix)
n = len(matrix[0])
```

The virtual sorted array has m * n elements. Set the binary search boundaries:

```python
left = 0
right = m * n - 1
```

Compute the middle index. Writing it as left + (right - left) // 2 avoids integer overflow in languages with fixed-width integers; Python integers do not overflow, so (left + right) // 2 would also work here:

```python
mid = left + (right - left) // 2
```

Convert the virtual index to matrix coordinates:

```python
row = mid // n
col = mid % n
```

Read the matrix value:

```python
value = matrix[row][col]
```

If it matches, return immediately:

```python
if value == target:
    return True
```

If the value is too small, search the right half:

```python
if value < target:
    left = mid + 1
```

Otherwise, search the left half:

```python
else:
    right = mid - 1
```

If the loop ends, the target was not found:

```python
return False
```

Testing
```python
def run_tests():
    s = Solution()

    matrix = [
        [1, 3, 5, 7],
        [10, 11, 16, 20],
        [23, 30, 34, 60],
    ]

    assert s.searchMatrix(matrix, 3) is True
    assert s.searchMatrix(matrix, 13) is False
    assert s.searchMatrix(matrix, 1) is True
    assert s.searchMatrix(matrix, 60) is True
    assert s.searchMatrix(matrix, 0) is False
    assert s.searchMatrix(matrix, 61) is False

    assert s.searchMatrix([[1]], 1) is True
    assert s.searchMatrix([[1]], 2) is False
    assert s.searchMatrix([[1, 3, 5]], 3) is True
    assert s.searchMatrix([[1], [3], [5]], 3) is True

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| Target in middle | Normal successful search |
| Missing target | Normal failed search |
| First element | Left boundary |
| Last element | Right boundary |
| Below minimum | Out of range low |
| Above maximum | Out of range high |
| 1 x 1 matrix | Smallest matrix |
| Single row | Row-only search |
| Single column | Column-only search |
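The fixed cases above can be complemented by a randomized cross-check against the linear scan. To keep the snippet runnable on its own, it includes a standalone copy of the virtual-index search as binary_search_matrix. random_matrix is a hypothetical generator that always satisfies both ordering rules, including repeated values within a row; the sizes, value ranges, and seed are arbitrary choices.

```python
import random


def binary_search_matrix(matrix: list[list[int]], target: int) -> bool:
    # Standalone copy of the virtual-index binary search from the solution.
    m, n = len(matrix), len(matrix[0])
    left, right = 0, m * n - 1
    while left <= right:
        mid = left + (right - left) // 2
        value = matrix[mid // n][mid % n]
        if value == target:
            return True
        if value < target:
            left = mid + 1
        else:
            right = mid - 1
    return False


def random_matrix(rng: random.Random) -> list[list[int]]:
    """Hypothetical generator of matrices that satisfy both ordering rules."""
    m, n = rng.randint(1, 6), rng.randint(1, 6)
    matrix, current = [], rng.randint(-50, 0)
    for _ in range(m):
        row = []
        for _ in range(n):
            # A step of 0 allows repeated values within a row.
            current += rng.randint(0, 3)
            row.append(current)
        matrix.append(row)
        current += 1  # the next row must start strictly above this row's last value
    return matrix


def cross_check(trials: int = 2000, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        matrix = random_matrix(rng)
        # Include targets just below the minimum and just above the maximum.
        for target in range(matrix[0][0] - 2, matrix[-1][-1] + 3):
            expected = any(target in row for row in matrix)
            assert binary_search_matrix(matrix, target) == expected, (matrix, target)
    print("random cross-check passed")


cross_check()
```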