A clear explanation of counting pairs where nums[i] is greater than twice nums[j] using merge sort.
Problem Restatement
We are given an integer array nums.
We need to count pairs of indices (i, j) such that:

```
0 <= i < j < nums.length
```

and:

```
nums[i] > 2 * nums[j]
```

These are called reverse pairs. The array length can be up to 5 * 10^4, and each value can be any signed 32-bit integer, including negative values.
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer array nums |
| Output | Number of reverse pairs |
| Pair rule | i < j |
| Value rule | nums[i] > 2 * nums[j] |
Function shape:

```python
def reversePairs(nums: list[int]) -> int:
    ...
```

Examples
Example 1:

```
nums = [1, 3, 2, 3, 1]
```

The valid reverse pairs are:

```
(1, 4): nums[1] = 3, nums[4] = 1, 3 > 2 * 1
(3, 4): nums[3] = 3, nums[4] = 1, 3 > 2 * 1
```

Answer:

```
2
```

Example 2:

```
nums = [2, 4, 3, 5, 1]
```

The valid reverse pairs are:

```
(1, 4): 4 > 2 * 1
(2, 4): 3 > 2 * 1
(3, 4): 5 > 2 * 1
```

Answer:

```
3
```

First Thought: Check Every Pair
The direct solution is to check every pair (i, j).
```python
class Solution:
    def reversePairs(self, nums: list[int]) -> int:
        ans = 0
        n = len(nums)
        for i in range(n):
            for j in range(i + 1, n):
                if nums[i] > 2 * nums[j]:
                    ans += 1
        return ans
```

This is correct, but it is too slow.
Problem With Brute Force
There are exactly:

```
n * (n - 1) / 2
```

pairs, so the brute force solution costs:

```
O(n^2)
```

For n = 50000, that is about 1.25 * 10^9 pair checks, which is far too many.

We need a way to count many pairs at once.
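For a rough sense of scale, this sketch compares the brute-force pair count with an n * ceil(log2 n) estimate for a recursion that halves the range at each level. The "about n units of work per level" model is an assumption that ignores constant factors.

```python
import math

n = 50_000

# Brute force checks every pair (i, j) with i < j.
pair_checks = n * (n - 1) // 2

# A range that is halved repeatedly reaches size 1 after about ceil(log2 n) levels.
levels = math.ceil(math.log2(n))

# Rough estimate, assuming about n units of work per level.
merge_sort_work = n * levels

print(pair_checks)      # 1249975000
print(levels)           # 16
print(merge_sort_work)  # 800000
```

Even with generous constants, the gap of more than three orders of magnitude is what makes a divide-and-conquer approach worth the extra code.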
Key Insight
This problem is similar to counting inversions, which are pairs i < j with nums[i] > nums[j].
The condition:

```
nums[i] > 2 * nums[j]
```

compares one value on the left with one value on the right.
Merge sort is useful because after sorting the left half and right half, we can count cross pairs efficiently.
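Suppose the left half and right half are already sorted:

```
left = [2, 3, 4]
right = [1, 5]
```

For each number in the left half, we want to know how many numbers in the right half satisfy:

```
left_value > 2 * right_value
```

Because right is sorted, once some right[j] is too large for a left value, every later right value is too large as well. Because left is also sorted, the next left value is at least as large, so every right value that already qualified still qualifies. The right pointer therefore never moves backward, and two pointers can count all cross pairs in linear time. In this example, 3 and 4 each pair with 1, giving 2 cross pairs.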
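A minimal sketch of that two-pointer count on plain sorted lists. The function name count_cross is hypothetical and not part of the solution further down; it assumes both inputs are already sorted in ascending order.

```python
def count_cross(left: list[int], right: list[int]) -> int:
    """Count pairs (a, b) with a in left, b in right, and a > 2 * b.

    Both lists must already be sorted in ascending order.
    """
    count = 0
    j = 0  # right[0:j] are the values already known to satisfy a > 2 * b
    for a in left:
        # Advance j while right[j] is small enough for the current a.
        while j < len(right) and a > 2 * right[j]:
            j += 1
        # right[0:j] all pair with a; right[j:] are too large.
        count += j
    return count


print(count_cross([2, 3, 4], [1, 5]))  # 2
```

Note that j is initialized once, outside the loop over left; that single design choice is what keeps the pass linear instead of quadratic.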
Counting Cross Pairs
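During merge sort, after solving the left and right halves, we count pairs where:

- i is in the left half
- j is in the right half

Since every left index comes before every right index, the condition i < j holds automatically.

For each i in the left half, move the pointer j through the right half while:

```
nums[i] > 2 * nums[j]
```

Then every right index from mid + 1 up to j - 1 forms a reverse pair with i, so we add:

```
j - (mid + 1)
```

to the answer. The pointer j is not reset between values of i.

Important: count before merging, while each half is still sorted on its own. After the merge we can no longer tell which values came from the left half, and the merge orders values by nums[i] <= nums[j] rather than by the 2x condition, so counting is done as a separate pass.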
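One combine step can be sketched on a single list using the same left, mid, and right indices as the text. count_then_merge is a hypothetical helper, and it uses the standard library's heapq.merge in place of a hand-written merge loop; it assumes both halves are already sorted.

```python
import heapq


def count_then_merge(nums: list[int], left: int, mid: int, right: int) -> int:
    """Count cross pairs between nums[left:mid + 1] and nums[mid + 1:right + 1], then merge.

    Assumes each half is already sorted in ascending order.
    """
    count = 0
    j = mid + 1
    for i in range(left, mid + 1):
        # j is carried over from the previous i and only moves forward.
        while j <= right and nums[i] > 2 * nums[j]:
            j += 1
        # Right indices mid + 1 .. j - 1 all pair with i.
        count += j - (mid + 1)

    # Merge only after counting: afterwards the left/right split is gone.
    nums[left:right + 1] = list(heapq.merge(nums[left:mid + 1], nums[mid + 1:right + 1]))
    return count


nums = [2, 3, 4, 1, 5]
print(count_then_merge(nums, 0, 2, 4))  # 2
print(nums)                             # [1, 2, 3, 4, 5]
```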
Algorithm
Use merge sort.
For a range nums[left:right + 1]:
- Split it into two halves.
- Recursively count reverse pairs in the left half.
- Recursively count reverse pairs in the right half.
- Count reverse pairs crossing from left half to right half.
- Merge the two sorted halves.
- Return the total count.
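The steps above can be sketched as a recursive function that returns a new sorted list instead of sorting a range in place. This is a hypothetical variant named sort_and_count_list, not the implementation given later; it slices and allocates new lists at every level, so it uses more memory than the in-place version, but each listed step maps to a few lines.

```python
import heapq


def sort_and_count_list(values: list[int]) -> tuple[int, list[int]]:
    """Return (number of reverse pairs in values, sorted copy of values)."""
    if len(values) <= 1:
        return 0, list(values)

    mid = len(values) // 2
    left_count, left = sort_and_count_list(values[:mid])    # pairs inside the left half
    right_count, right = sort_and_count_list(values[mid:])  # pairs inside the right half

    # Cross pairs: first element from the left half, second from the right half.
    cross = 0
    j = 0
    for a in left:
        while j < len(right) and a > 2 * right[j]:
            j += 1
        cross += j

    # Merge the two sorted halves by plain value order.
    merged = list(heapq.merge(left, right))
    return left_count + right_count + cross, merged


print(sort_and_count_list([1, 3, 2, 3, 1])[0])  # 2
print(sort_and_count_list([2, 4, 3, 5, 1])[0])  # 3
```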
Correctness
Every reverse pair (i, j) belongs to exactly one of three groups:
| Case | Where the pair is counted |
|---|---|
| Both indices in the left half | Recursive call on left half |
| Both indices in the right half | Recursive call on right half |
| i in left half and j in right half | Cross-pair counting step |
These groups are disjoint and cover all possible pairs.
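For cross pairs, both halves are sorted before counting. For a fixed left value nums[i], the algorithm advances j through the right half while nums[i] > 2 * nums[j]. Because the left half is sorted, nums[i] is at least as large as every earlier left value, so the right elements already passed for an earlier i still satisfy the condition, and j never needs to move backward.
All right elements before j satisfy the condition. If j is still inside the right half, then nums[i] <= 2 * nums[j], and because the right half is sorted, every right element from j onward fails the condition too.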
Therefore, the algorithm counts exactly the number of valid cross pairs for each i.
After counting, the algorithm merges the two halves so the parent call also receives a sorted range.
By induction over merge sort ranges, every range returns the correct number of reverse pairs inside it. Therefore, the full call returns the number of reverse pairs in the whole array.
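The key step of this argument, that the carried-over pointer gives the right count for every i, can also be spot-checked empirically. For sorted halves, the number of right values b with nums[i] > 2 * b equals the number of doubled right values below nums[i], which binary search computes independently. In this randomized sketch, the helper names, seed, list sizes, and value range are arbitrary choices.

```python
import bisect
import random


def two_pointer_counts(left: list[int], right: list[int]) -> list[int]:
    """Per-left-value counts produced by a shared pointer j that is never reset."""
    counts = []
    j = 0
    for a in left:
        while j < len(right) and a > 2 * right[j]:
            j += 1
        counts.append(j)
    return counts


def bisect_counts(left: list[int], right: list[int]) -> list[int]:
    """Independent count: how many b in right satisfy 2 * b < a, via binary search."""
    doubled = [2 * b for b in right]  # still sorted, because right is sorted
    return [bisect.bisect_left(doubled, a) for a in left]


random.seed(0)
for _ in range(1000):
    left = sorted(random.randint(-20, 20) for _ in range(random.randint(0, 8)))
    right = sorted(random.randint(-20, 20) for _ in range(random.randint(0, 8)))
    assert two_pointer_counts(left, right) == bisect_counts(left, right)
print("two-pointer counts match binary search")
```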
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n log n) | The recursion has O(log n) levels, and the counting and merging passes do O(n) total work per level |
| Space | O(n) | Merge buffers hold O(n) elements; the recursion stack adds O(log n) |
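The time bound follows from the usual merge sort recurrence. This derivation assumes n is a power of two and uses a constant c for the per-element cost of counting and merging in one call:

```latex
\begin{aligned}
T(n) &= 2\,T(n/2) + c\,n, \qquad T(1) = c \\
     &= 2^k\,T(n/2^k) + k\,c\,n \\
     &= c\,n + c\,n \log_2 n \qquad (k = \log_2 n) \\
     &= O(n \log n)
\end{aligned}
```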
Implementation
```python
class Solution:
    def reversePairs(self, nums: list[int]) -> int:
        def sort_and_count(left: int, right: int) -> int:
            # A range with at most one element contains no pairs.
            if left >= right:
                return 0

            mid = (left + right) // 2
            # Count pairs inside each half; afterwards both halves are sorted.
            count = sort_and_count(left, mid)
            count += sort_and_count(mid + 1, right)

            # Count cross pairs with a pointer that only moves forward.
            j = mid + 1
            for i in range(left, mid + 1):
                while j <= right and nums[i] > 2 * nums[j]:
                    j += 1
                count += j - (mid + 1)

            # Standard merge of the two sorted halves.
            merged = []
            i = left
            j = mid + 1
            while i <= mid and j <= right:
                if nums[i] <= nums[j]:
                    merged.append(nums[i])
                    i += 1
                else:
                    merged.append(nums[j])
                    j += 1
            while i <= mid:
                merged.append(nums[i])
                i += 1
            while j <= right:
                merged.append(nums[j])
                j += 1

            # Write back so the parent call sees a sorted range.
            nums[left:right + 1] = merged
            return count

        return sort_and_count(0, len(nums) - 1)
```

Code Explanation
The base case, a range with at most one element, has no pairs:

```python
if left >= right:
    return 0
```

We split the current range:

```python
mid = (left + right) // 2
```

Then count pairs inside both halves:

```python
count = sort_and_count(left, mid)
count += sort_and_count(mid + 1, right)
```

After these calls return, both halves are sorted.

This part counts cross pairs:

```python
j = mid + 1
for i in range(left, mid + 1):
    while j <= right and nums[i] > 2 * nums[j]:
        j += 1
    count += j - (mid + 1)
```

The pointer j never moves backward, so this whole counting loop is linear in the range size.

Then we merge the two sorted halves into merged with the usual two-pointer merge and write the result back:

```python
nums[left:right + 1] = merged
```

This keeps the range sorted for the parent call. As a side effect, the caller's nums list is sorted in place when reversePairs returns.
Testing
```python
def run_tests():
    s = Solution()
    assert s.reversePairs([1, 3, 2, 3, 1]) == 2
    assert s.reversePairs([2, 4, 3, 5, 1]) == 3
    assert s.reversePairs([1, 2, 3, 4, 5]) == 0
    assert s.reversePairs([5, 4, 3, 2, 1]) == 4
    assert s.reversePairs([-5, -5]) == 1
    assert s.reversePairs([2147483647, 1073741823, -2147483648]) == 3
    print("all tests passed")


run_tests()
```

Test meaning:
| Test | Why |
|---|---|
| [1, 3, 2, 3, 1] | Example 1 |
| [2, 4, 3, 5, 1] | Example 2, with three pairs |
| Increasing array | No reverse pairs |
| Decreasing array | Only 4 of the 10 pairs satisfy nums[i] > 2 * nums[j] |
| Negative values | Checks that the inequality still works with negatives: -5 > 2 * (-5) |
| Large 32-bit values | Checks extreme values, where 2 * nums[j] falls outside the 32-bit range (Python integers do not overflow, but fixed-width integers would) |
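The hand-picked cases can be complemented by a randomized comparison against the brute force. This is a sketch: brute_force and merge_count are hypothetical stand-alone versions of the two approaches above, merge_count uses sorted() on each range as a simpler but slower stand-in for the hand-written merge, and the seed, sizes, and value ranges are arbitrary choices.

```python
import random


def brute_force(nums: list[int]) -> int:
    """O(n^2) reference count: check every pair i < j."""
    n = len(nums)
    return sum(1 for i in range(n) for j in range(i + 1, n) if nums[i] > 2 * nums[j])


def merge_count(nums: list[int]) -> int:
    """Merge sort count with the same structure as the implementation above.

    Works on a copy so the caller's list is left unchanged.
    """
    arr = list(nums)

    def sort_and_count(left: int, right: int) -> int:
        if left >= right:
            return 0
        mid = (left + right) // 2
        count = sort_and_count(left, mid) + sort_and_count(mid + 1, right)
        j = mid + 1
        for i in range(left, mid + 1):
            while j <= right and arr[i] > 2 * arr[j]:
                j += 1
            count += j - (mid + 1)
        # Simpler (O(k log k)) stand-in for the linear two-pointer merge.
        arr[left:right + 1] = sorted(arr[left:right + 1])
        return count

    return sort_and_count(0, len(arr) - 1)


random.seed(2024)
for _ in range(500):
    # Mix tiny ranges (many duplicates) with the full signed 32-bit range.
    bound = random.choice([3, 100, 2**31])
    nums = [random.randint(-bound, bound - 1) for _ in range(random.randint(0, 12))]
    assert merge_count(nums) == brute_force(nums), nums
print("random tests passed")
```

Small value ranges force many equal values, which is where an off-by-one between > and >= would show up; the full 32-bit range exercises the boundary values.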