A clear explanation of finding the k smallest pair sums from two sorted arrays using a min heap and best-first search.
Problem Restatement
We are given two sorted integer arrays:
nums1
nums2
We define a pair:
(u, v)
where:
| Value | Comes from |
|---|---|
| u | nums1 |
| v | nums2 |
The pair sum is:
u + v
We need to return the k pairs with the smallest sums.
The arrays are already sorted in ascending order.
The official example is:
nums1 = [1, 7, 11]
nums2 = [2, 4, 6]
k = 3
Output:
[[1, 2], [1, 4], [1, 6]]
because these are the three pairs with the smallest sums. (leetcode.com)
Input and Output
| Item | Meaning |
|---|---|
| Input | Two sorted arrays and integer k |
| Output | k pairs with smallest sums |
| Pair form | [nums1[i], nums2[j]] |
| Arrays | Sorted ascending |
| Valid answers | Any order with correct pairs |
Example function shape:
def kSmallestPairs(
    nums1: list[int],
    nums2: list[int],
    k: int,
) -> list[list[int]]:
    ...
Examples
Example 1:
nums1 = [1, 7, 11]
nums2 = [2, 4, 6]
k = 3
All possible pairs:
| Pair | Sum |
|---|---|
| (1, 2) | 3 |
| (1, 4) | 5 |
| (1, 6) | 7 |
| (7, 2) | 9 |
| (7, 4) | 11 |
| (11, 2) | 13 |
| (7, 6) | 13 |
| (11, 4) | 15 |
| (11, 6) | 17 |
The 3 pairs with the smallest sums are:
[[1, 2], [1, 4], [1, 6]]
Example 2:
nums1 = [1, 1, 2]
nums2 = [1, 2, 3]
k = 2
The two pairs with the smallest sums are:
[[1, 1], [1, 1]]
because nums1 contains two 1s at different indices, and each of them pairs with nums2[0] = 1 to give the sum 2.
First Thought: Generate All Pairs
A direct solution is:
- Generate every pair.
- Compute every sum.
- Sort all pairs by sum.
- Return the first k.
If:
| Array | Size |
|---|---|
| nums1 | m |
| nums2 | n |
then there are:
m * n
pairs.
Sorting them costs:
O(m * n log(m * n))
This is too slow, and the full pair list too large to store, when m * n is large and k is small.
We need to avoid generating every pair.
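For reference, the direct approach described above can be sketched as follows. This is only a baseline for comparison; the function name `k_smallest_pairs_brute_force` is made up for this illustration.

```python
def k_smallest_pairs_brute_force(nums1, nums2, k):
    """Baseline: enumerate all m * n pairs, sort them by sum, keep the first k."""
    # Generate every pair in row-major order.
    pairs = [[u, v] for u in nums1 for v in nums2]
    # list.sort is stable, so pairs with equal sums keep their row-major order.
    pairs.sort(key=lambda p: p[0] + p[1])
    return pairs[:k]


print(k_smallest_pairs_brute_force([1, 7, 11], [2, 4, 6], 3))
# [[1, 2], [1, 4], [1, 6]]
```

It builds a list of m * n pairs before sorting, which is exactly the cost the heap approach avoids.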
Key Insight
The arrays are sorted.
Suppose:
nums1 = [1, 7, 11]
nums2 = [2, 4, 6]
Think of the pair sums as a matrix, with the values of nums1 labeling the rows and the values of nums2 labeling the columns:
| | 2 | 4 | 6 |
|---|---|---|---|
| 1 | 3 | 5 | 7 |
| 7 | 9 | 11 | 13 |
| 11 | 13 | 15 | 17 |
Rows increase left to right.
Columns increase top to bottom.
This is similar to merging sorted lists.
For a fixed index i in nums1, the pairs:
(nums1[i], nums2[0])
(nums1[i], nums2[1])
(nums1[i], nums2[2])
...
appear in non-decreasing sum order.
So:
- Start with the smallest pair from each row.
- Always take the globally smallest available pair.
- After taking (i, j), push (i, j + 1) from the same row.
This is a classic min-heap merge pattern.
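To make the "merge of sorted rows" view concrete, here is a small sketch that treats each row as a lazy sorted stream and lets the standard library's `heapq.merge` do the merging. The helper names `row` and `k_smallest_by_merge` are made up for this illustration, and the sketch opens one stream per row (all m of them), so it is a way to see the pattern rather than the tuned algorithm described in the next section.

```python
import heapq
from itertools import islice


def row(nums1, nums2, i):
    """Lazily yield (sum, i, j) for row i; sums are non-decreasing in j."""
    for j, v in enumerate(nums2):
        yield (nums1[i] + v, i, j)


def k_smallest_by_merge(nums1, nums2, k):
    # One lazy stream per row of the sum matrix.
    rows = (row(nums1, nums2, i) for i in range(len(nums1)))
    # heapq.merge repeatedly yields the smallest head among the sorted streams.
    merged = heapq.merge(*rows)
    # Only the first k merged items are ever computed.
    return [[nums1[i], nums2[j]] for _, i, j in islice(merged, k)]


print(k_smallest_by_merge([1, 7, 11], [2, 4, 6], 3))
# [[1, 2], [1, 4], [1, 6]]
```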
Algorithm
Use a min heap storing:
(sum, i, j)
where:
| Field | Meaning |
|---|---|
| sum | nums1[i] + nums2[j] |
| i | Index in nums1 |
| j | Index in nums2 |
Initialization:
Push (nums1[i] + nums2[0], i, 0) for each of the first min(k, len(nums1)) rows.
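Why only those rows?
Because nums1 is also sorted, the first pair of row i has a sum no smaller than the first pairs of rows 0 through i - 1. Every pair in a row with index k or greater is therefore preceded by at least k pairs whose sums are no larger, so those rows are never needed for the k smallest pairs.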
Then:
- Pop the smallest pair from the heap.
- Add it to the answer.
- If (i, j + 1) exists, push it.
Repeat until:
k == 0
or the heap becomes empty.
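The steps above can be traced on the official example. The sketch below records, for every step, which pair was popped and what the heap holds afterwards; `trace_heap_steps` is a hypothetical helper written only for this walkthrough, and the heap contents are sorted purely for readability.

```python
import heapq


def trace_heap_steps(nums1, nums2, k):
    """Return a list of (popped_pair, sorted heap contents after the step)."""
    if not nums1 or not nums2:
        return []
    # Seed with the first pair of each of the first min(k, m) rows.
    heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, len(nums1)))]
    heapq.heapify(heap)
    steps = []
    while heap and len(steps) < k:
        _, i, j = heapq.heappop(heap)
        # Replace the popped entry with the next pair from the same row.
        if j + 1 < len(nums2):
            heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
        steps.append(((nums1[i], nums2[j]), sorted(heap)))
    return steps


for pair, heap_after in trace_heap_steps([1, 7, 11], [2, 4, 6], 3):
    print(pair, heap_after)
# (1, 2) [(5, 0, 1), (9, 1, 0), (13, 2, 0)]
# (1, 4) [(7, 0, 2), (9, 1, 0), (13, 2, 0)]
# (1, 6) [(9, 1, 0), (13, 2, 0)]
```

After the third pop, row 0 is exhausted, so nothing is pushed and the heap shrinks to two entries.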
Correctness
For each fixed index i, the pairs:
(nums1[i], nums2[0]),
(nums1[i], nums2[1]),
(nums1[i], nums2[2]),
...
have non-decreasing sums because nums2 is sorted.
Therefore, each row forms a sorted list of pair sums.
The heap always stores the smallest not-yet-used pair from each active row.
Initially, the heap contains the smallest pair from every relevant row, so the global minimum pair is in the heap.
When the algorithm removes (i, j), the next candidate from that same row is (i, j + 1). Since rows are sorted, no smaller unused pair from row i exists.
Thus the heap invariant is preserved: it always contains the smallest remaining candidate from each row.
At every step, the heap top is therefore the smallest unused pair globally.
The algorithm outputs pairs in non-decreasing sum order until k pairs are produced or all pairs are exhausted.
Therefore, the returned pairs are exactly the k smallest-sum pairs.
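The argument above can also be checked empirically. The sketch below is a randomized cross-check, not a proof: it compares the sums produced by a compact restatement of the heap method against the sorted sums of all pairs. The names `heap_k_smallest`, `brute_force_sums`, and `cross_check` are made up here, and sums rather than pairs are compared because ties can be broken in more than one valid way.

```python
import heapq
import random


def heap_k_smallest(nums1, nums2, k):
    """Compact restatement of the heap method, kept here for self-containment."""
    if not nums1 or not nums2 or k <= 0:
        return []
    heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, len(nums1)))]
    heapq.heapify(heap)
    result = []
    while heap and len(result) < k:
        _, i, j = heapq.heappop(heap)
        result.append([nums1[i], nums2[j]])
        if j + 1 < len(nums2):
            heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
    return result


def brute_force_sums(nums1, nums2, k):
    """The k smallest sums, computed by sorting all m * n sums."""
    return sorted(u + v for u in nums1 for v in nums2)[:k]


def cross_check(trials=500, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        # Small sorted arrays with duplicates, negatives, and empty cases.
        nums1 = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        nums2 = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        k = rng.randint(0, 40)
        got = [u + v for u, v in heap_k_smallest(nums1, nums2, k)]
        assert got == brute_force_sums(nums1, nums2, k), (nums1, nums2, k)
    return True


cross_check()
```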
Complexity
Let:
| Symbol | Meaning |
|---|---|
| m | len(nums1) |
| n | len(nums2) |
The heap size never exceeds:
min(k, m)
Each pop is followed by at most one push, so the heap never grows beyond its initial size.
We perform at most k heap removals.
| Metric | Value |
|---|---|
| Time | O(k log(min(k, m))) |
| Space | O(min(k, m)) |
This is much better than generating all m * n pairs.
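As a rough sanity check of these bounds, the sketch below instruments the heap method to count pops and track the largest heap size it ever reaches. `heap_stats` is a hypothetical helper for this illustration only.

```python
import heapq


def heap_stats(nums1, nums2, k):
    """Return (pops, max_heap_size) observed while running the heap method."""
    if not nums1 or not nums2 or k <= 0:
        return 0, 0
    heap = [(nums1[i] + nums2[0], i, 0) for i in range(min(k, len(nums1)))]
    heapq.heapify(heap)
    pops = 0
    max_size = len(heap)
    while heap and pops < k:
        _, i, j = heapq.heappop(heap)
        pops += 1
        if j + 1 < len(nums2):
            heapq.heappush(heap, (nums1[i] + nums2[j + 1], i, j + 1))
        max_size = max(max_size, len(heap))
    return pops, max_size


nums = list(range(1000))
print(heap_stats(nums, nums, 10))    # (10, 10)
print(heap_stats(nums, nums, 5000))  # (5000, 1000)
```

With m = n = 1000 there are 1,000,000 pairs in total, yet k = 10 needs only 10 pops on a heap of at most 10 entries, and k = 5000 never holds more than m = 1000 entries.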
Implementation
import heapq
class Solution:
    def kSmallestPairs(
        self,
        nums1: list[int],
        nums2: list[int],
        k: int,
    ) -> list[list[int]]:
        # Nothing to return if either array is empty or k is 0.
        if not nums1 or not nums2 or k == 0:
            return []
        # Seed the heap with the first pair of each of the first min(k, m) rows.
        heap = []
        for i in range(min(k, len(nums1))):
            heapq.heappush(
                heap,
                (nums1[i] + nums2[0], i, 0),
            )
        answer = []
        while heap and k > 0:
            # Pop the smallest remaining candidate (sum, i, j).
            _, i, j = heapq.heappop(heap)
            answer.append([nums1[i], nums2[j]])
            # Advance within row i if it has another column.
            if j + 1 < len(nums2):
                heapq.heappush(
                    heap,
                    (
                        nums1[i] + nums2[j + 1],
                        i,
                        j + 1,
                    ),
                )
            k -= 1
        return answer
Code Explanation
We handle empty arrays and k == 0 first:
if not nums1 or not nums2 or k == 0:
    return []
The heap stores:
(sum, i, j)
We begin with the first pair from each row:
(nums1[i], nums2[0])
because these are the smallest pairs in each row.
Initialization:
for i in range(min(k, len(nums1))):
We only need at most k rows because the answer contains at most k pairs and, with nums1 sorted, a later row never starts with a smaller sum than an earlier one.
Each loop iteration removes the current smallest pair:
_, i, j = heapq.heappop(heap)
Then append it:
answer.append([nums1[i], nums2[j]])
The next candidate in the same row is:
(i, j + 1)
So we push it if it exists:
if j + 1 < len(nums2):
Finally:
k -= 1
because one answer pair has been produced.
Testing
def run_tests():
    s = Solution()
    assert s.kSmallestPairs(
        [1, 7, 11],
        [2, 4, 6],
        3,
    ) == [[1, 2], [1, 4], [1, 6]]
    assert s.kSmallestPairs(
        [1, 1, 2],
        [1, 2, 3],
        2,
    ) == [[1, 1], [1, 1]]
    assert s.kSmallestPairs(
        [1, 2],
        [3],
        3,
    ) == [[1, 3], [2, 3]]
    assert s.kSmallestPairs(
        [],
        [1, 2],
        3,
    ) == []
    assert s.kSmallestPairs(
        [1],
        [2],
        1,
    ) == [[1, 2]]
    print("all tests passed")
run_tests()
Test meaning:
| Test | Why |
|---|---|
| Standard example | Checks basic heap merge |
| Duplicate values | Confirms duplicate pairs are allowed |
| k larger than total pairs | Returns all pairs |
| Empty array | No possible pairs |
| Single-element arrays | Minimum non-empty case |
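Because any order with correct pairs is accepted, exact-equality assertions can fail for a correct solution that breaks ties differently. One possible way around this, sketched below, is to compare sums instead of pairs. `has_k_smallest_sums` is a hypothetical helper, and it is deliberately loose: it checks that each value comes from the right array and that the sums match the k smallest, but it does not verify that no index pair is reused.

```python
def has_k_smallest_sums(nums1, nums2, k, answer):
    """Check an answer by sums only, so tie-breaking and output order do not matter."""
    # Reference: the k smallest sums over all m * n pairs.
    expected = sorted(u + v for u in nums1 for v in nums2)[:k]
    # Every pair must take its first value from nums1 and its second from nums2.
    values_ok = all(u in nums1 and v in nums2 for u, v in answer)
    return values_ok and sorted(u + v for u, v in answer) == expected


# (1, 2) and (2, 1) both have sum 3, so either one is a valid second pair.
print(has_k_smallest_sums([1, 2], [1, 2], 2, [[1, 1], [1, 2]]))  # True
print(has_k_smallest_sums([1, 2], [1, 2], 2, [[2, 1], [1, 1]]))  # True
print(has_k_smallest_sums([1, 2], [1, 2], 2, [[1, 1], [2, 2]]))  # False
```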