A clear explanation of Count of Range Sum using prefix sums and merge sort counting.
Problem Restatement
We are given an integer array nums and two integers lower and upper.
A range sum is the sum of a contiguous subarray:
```
S(i, j) = nums[i] + nums[i + 1] + ... + nums[j]
```
where i <= j.
We need to count how many range sums lie inside this inclusive range:
```
lower <= S(i, j) <= upper
```
The naive O(n^2) solution is too slow, so we need something better.
Input and Output
| Item | Meaning |
|---|---|
| Input | nums, lower, and upper |
| Output | Number of subarray sums in [lower, upper] |
| Constraints | nums.length can be up to 10^5 |
| Required idea | Better than O(n^2) |
Function shape:
```python
def countRangeSum(nums: list[int], lower: int, upper: int) -> int:
    ...
```
Examples
Example 1:
```
Input: nums = [-2, 5, -1], lower = -2, upper = 2
Output: 3
```
The valid ranges are:
| Index range [i, j] | Sum |
|---|---|
| [0, 0] | -2 |
| [2, 2] | -1 |
| [0, 2] | 2 |
So the answer is 3.
Example 2:
```
Input: nums = [0], lower = 0, upper = 0
Output: 1
```
There is only one subarray: [0].
Its sum is 0, which lies in [0, 0].
First Thought: Brute Force
The direct method is to enumerate every subarray.
For each starting index i, keep extending the ending index j, and update the running sum.
```python
class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        ans = 0
        for i in range(len(nums)):
            total = 0  # running sum of nums[i..j]
            for j in range(i, len(nums)):
                total += nums[j]
                if lower <= total <= upper:
                    ans += 1
        return ans
```
This gives the right answer, but it checks all subarrays.
Problem With Brute Force
There are n(n + 1) / 2 subarrays, which is about n^2 / 2.
So the time complexity is:
```
O(n^2)
```
For n = 100000, that is about 5 * 10^9 subarray checks, which is far too many.
We need to count valid ranges without explicitly checking every subarray.
Key Insight
Use prefix sums.
Define:
```
prefix[0] = 0
prefix[i + 1] = nums[0] + nums[1] + ... + nums[i]
```
Then the sum of the half-open slice nums[i:j] is:
```
prefix[j] - prefix[i]
```
For an inclusive subarray nums[i] through nums[j], the range sum is:
```
prefix[j + 1] - prefix[i]
```
So the problem becomes, with i and j now denoting prefix indices:
Count pairs (i, j) where i < j and:
```
lower <= prefix[j] - prefix[i] <= upper
```
Rearrange this:
```
prefix[j] - upper <= prefix[i] <= prefix[j] - lower
```
For each right prefix sum prefix[j], we need to know how many earlier prefix sums fall inside this value interval.
That is the real problem.
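As a sanity check of this reduction, the sketch below builds the prefix array with itertools.accumulate and counts prefix pairs using the rearranged condition. The helper names prefix_sums and count_prefix_pairs are hypothetical and exist only for this illustration. The pair count is still O(n^2). It only shows that counting prefix pairs gives the same answer as counting subarrays.

```python
from itertools import accumulate


def prefix_sums(nums: list[int]) -> list[int]:
    # prefix[0] = 0 and prefix[k] = nums[0] + ... + nums[k - 1]
    return list(accumulate(nums, initial=0))


def count_prefix_pairs(nums: list[int], lower: int, upper: int) -> int:
    # Quadratic reference: count prefix index pairs (i, j), i < j,
    # with prefix[j] - upper <= prefix[i] <= prefix[j] - lower.
    prefix = prefix_sums(nums)
    count = 0
    for j in range(len(prefix)):
        for i in range(j):
            if prefix[j] - upper <= prefix[i] <= prefix[j] - lower:
                count += 1
    return count


nums = [-2, 5, -1]
prefix = prefix_sums(nums)
# The inclusive subarray nums[0..2] has sum prefix[3] - prefix[0].
assert sum(nums[0:3]) == prefix[3] - prefix[0]
print(prefix)                           # [0, -2, 3, 2]
print(count_prefix_pairs(nums, -2, 2))  # 3
```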
Why Merge Sort Helps
During merge sort, we split the prefix sums into two halves.
For every prefix sum in the left half, we count how many prefix sums in the right half produce a valid difference.
Because both halves are sorted, we can count these pairs with two moving pointers.
Suppose left_value is from the left half.
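We need right values x such that:
```
lower <= x - left_value <= upper
```
Rearrange:
```
left_value + lower <= x <= left_value + upper
```
Since the right half is sorted, all valid x values form one contiguous window.
We can move two pointers:
```
lo: index of the first right value >= left_value + lower
hi: index of the first right value > left_value + upper
```
Then the number of valid right values is:
```
hi - lo
```
Because the left half is also sorted, left_value only increases, so both window bounds only move forward. Each pointer crosses the right half at most once, which makes the counting linear at each merge level.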
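The two-pointer window can be written as a standalone helper. count_crossing is a hypothetical name. The sketch assumes both inputs are sorted in ascending order and that lower <= upper. Because x only increases, neither pointer moves backward, so one call costs O(len(left) + len(right)).

```python
def count_crossing(left: list[int], right: list[int], lower: int, upper: int) -> int:
    """Count pairs (x, y), x from left and y from right, with lower <= y - x <= upper.

    Both lists must be sorted in ascending order.
    """
    count = 0
    lo = hi = 0
    for x in left:
        # x never decreases, so both window bounds only move to the right.
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        count += hi - lo
    return count


# Sorted halves from the top-level merge of the example prefix array [0, -2, 3, 2].
print(count_crossing([-2, 0], [2, 3], -2, 2))  # 1
```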
Algorithm
Build the prefix sum array.
Then run a recursive merge sort function over the prefix sums.
For each recursive range:
- Split into left half and right half.
- Count valid pairs fully inside the left half.
- Count valid pairs fully inside the right half.
- Count valid pairs crossing from left half to right half.
- Sort the current range before returning.
The crossing count is the important part.
For each prefix sum x in the left sorted half:
```
x + lower <= y <= x + upper
```
where y comes from the right sorted half.
Use two pointers to count all such y.
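An equivalent way to find the same window is binary search with the standard bisect module. It is shown here only as a cross-check. bisect_left gives the index of the first y >= x + lower, and bisect_right gives the index of the first y > x + upper. The helper name window_size is hypothetical. This costs O(log n) per left value instead of the amortized O(1) of the two pointers, so using it inside the merge sort would give O(n log^2 n) overall. What it does do is make the window bounds explicit.

```python
from bisect import bisect_left, bisect_right


def window_size(right: list[int], x: int, lower: int, upper: int) -> int:
    # right must be sorted ascending.
    lo = bisect_left(right, x + lower)   # first index with right[lo] >= x + lower
    hi = bisect_right(right, x + upper)  # first index with right[hi] > x + upper
    return hi - lo


right = [2, 3]
print([window_size(right, x, -2, 2) for x in [-2, 0]])  # [0, 1]
```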
Walkthrough
Take:
```
nums = [-2, 5, -1]
lower = -2
upper = 2
```
Build prefix sums:
```
prefix = [0, -2, 3, 2]
```
Every range sum is a difference between two prefix sums.
The valid pairs are:
| Pair of prefix indices | Difference | Subarray |
|---|---|---|
| (0, 1) | -2 - 0 = -2 | [0, 0] |
| (2, 3) | 2 - 3 = -1 | [2, 2] |
| (0, 3) | 2 - 0 = 2 | [0, 2] |
So the answer is 3.
Merge sort counts these pairs while sorting the prefix array.
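To show where each of the three pairs is counted, the sketch below runs the same sort-and-count recursion on the walkthrough prefix array and records the crossing count at every merge. The function name traced_sort_and_count and the trace list are hypothetical additions for illustration. heapq.merge stands in for the hand-written linear merge to keep the sketch short.

```python
import heapq


def traced_sort_and_count(arr, lower, upper, trace):
    # Returns (sorted arr, number of valid prefix pairs inside arr) and
    # appends (left, right, crossing_count) to trace at every merge.
    if len(arr) <= 1:
        return arr, 0
    mid = len(arr) // 2
    left, left_count = traced_sort_and_count(arr[:mid], lower, upper, trace)
    right, right_count = traced_sort_and_count(arr[mid:], lower, upper, trace)

    # Count crossing pairs with the two-pointer window.
    crossing = 0
    lo = hi = 0
    for x in left:
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        crossing += hi - lo
    trace.append((left, right, crossing))

    merged = list(heapq.merge(left, right))
    return merged, left_count + right_count + crossing


trace = []
_, total = traced_sort_and_count([0, -2, 3, 2], -2, 2, trace)
for left, right, crossing in trace:
    print(left, right, crossing)
# [0] [-2] 1
# [3] [2] 1
# [-2, 0] [2, 3] 1
print(total)  # 3
```

Each merge discovers exactly one pair from the table. The first merge finds (0, 1). The second finds (2, 3). The top-level merge finds (0, 3), where x = prefix[0] = 0 meets y = prefix[3] = 2.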
Correctness
Every subarray sum can be represented as the difference of two prefix sums.
For any subarray from index i to index j, its sum is:
```
prefix[j + 1] - prefix[i]
```
Each subarray corresponds to exactly one prefix index pair (i, j + 1). So counting valid range sums is exactly the same as counting prefix index pairs (a, b) with a < b and:
```
lower <= prefix[b] - prefix[a] <= upper
```
The recursive algorithm counts all such pairs in three disjoint groups:
| Group | Meaning |
|---|---|
| Left half | Both prefix indices are in the left half |
| Right half | Both prefix indices are in the right half |
| Crossing | First prefix index is in the left half, second prefix index is in the right half |
The recursive calls correctly count the left and right groups.
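For crossing pairs, every prefix index in the left half is smaller than every prefix index in the right half. Sorting a half only reorders values within that half, so any left value paired with any right value is still an (earlier, later) pair. With both halves sorted, for each left value x, all right values y satisfying:
```
x + lower <= y <= x + upper
```
form a contiguous interval in the sorted right half.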
The two pointers find exactly the bounds of this interval. Therefore hi - lo counts exactly the valid crossing pairs for x.
Since every valid pair belongs to exactly one of the three groups, and each group is counted exactly once, the algorithm returns the correct count.
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n log n) | Merge sort has log n levels, and each level does linear work |
| Space | O(n) | Prefix sums and merged arrays require linear space |
The prefix array has length n + 1, so the complexity is still O(n log n).
Implementation
```python
class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        # prefix[k] is the sum of the first k elements of nums.
        prefix = [0]
        for num in nums:
            prefix.append(prefix[-1] + num)

        def sort_and_count(arr: list[int]) -> tuple[list[int], int]:
            # Returns arr sorted, plus the number of index pairs p < q in arr
            # with lower <= arr[q] - arr[p] <= upper.
            n = len(arr)
            if n <= 1:
                return arr, 0

            mid = n // 2
            left, left_count = sort_and_count(arr[:mid])
            right, right_count = sort_and_count(arr[mid:])
            count = left_count + right_count

            # Crossing pairs: for each x in left, count y in right
            # with x + lower <= y <= x + upper.
            lo = 0
            hi = 0
            for x in left:
                while lo < len(right) and right[lo] < x + lower:
                    lo += 1
                while hi < len(right) and right[hi] <= x + upper:
                    hi += 1
                count += hi - lo

            # Standard merge of the two sorted halves.
            merged = []
            i = 0
            j = 0
            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    merged.append(left[i])
                    i += 1
                else:
                    merged.append(right[j])
                    j += 1
            merged.extend(left[i:])
            merged.extend(right[j:])
            return merged, count

        _, answer = sort_and_count(prefix)
        return answer
```
Code Explanation
We first build prefix sums:
```python
prefix = [0]
for num in nums:
    prefix.append(prefix[-1] + num)
```
The initial 0 represents the sum before taking any elements. This is needed for subarrays that start at index 0.
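To see why the leading 0 matters, the sketch below counts valid prefix-pair differences on Example 1 with and without it. count_pairs is a hypothetical quadratic helper used only for this comparison.

```python
def count_pairs(prefix: list[int], lower: int, upper: int) -> int:
    # Count index pairs i < j with lower <= prefix[j] - prefix[i] <= upper.
    return sum(
        1
        for j in range(len(prefix))
        for i in range(j)
        if lower <= prefix[j] - prefix[i] <= upper
    )


nums = [-2, 5, -1]
with_zero = [0]
for num in nums:
    with_zero.append(with_zero[-1] + num)  # [0, -2, 3, 2]
without_zero = with_zero[1:]               # [-2, 3, 2]

print(count_pairs(with_zero, -2, 2))     # 3
print(count_pairs(without_zero, -2, 2))  # 1: only nums[2..2] survives
```

The two ranges that disappear, [0, 0] and [0, 2], are exactly the ones starting at index 0.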
The recursive function returns two things:
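```python
return merged, count
```
The sorted array (merged) is needed by the parent call.
The count is the number of valid prefix pairs inside that array segment, that is, the valid range sums whose two prefix indices both fall in the segment.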
The base case has one prefix sum:
```python
if n <= 1:
    return arr, 0
```
A single prefix sum cannot form a pair, so the count is 0.
Then we split and recurse:
```python
left, left_count = sort_and_count(arr[:mid])
right, right_count = sort_and_count(arr[mid:])
```
The recursive calls count pairs fully inside each half.
Then we count crossing pairs.
For each x in left, valid right values must satisfy:
```
x + lower <= y <= x + upper
```
The first pointer skips values that are too small:
```python
while lo < len(right) and right[lo] < x + lower:
    lo += 1
```
The second pointer includes values that are still small enough:
```python
while hi < len(right) and right[hi] <= x + upper:
    hi += 1
```
So this many values are valid:
```python
count += hi - lo
```
Finally, we merge the two sorted halves so the parent level can use them.
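The implementation above slices arr[:mid] and arr[mid:] at every call. One possible variant sorts the prefix array in place over index ranges [start, end) instead, so no slices are created on the way down and only the merge step allocates temporary lists. This is a sketch under the same counting rule, not the only way to structure it. The name count_range_sum_in_place is hypothetical.

```python
def count_range_sum_in_place(nums: list[int], lower: int, upper: int) -> int:
    prefix = [0]
    for num in nums:
        prefix.append(prefix[-1] + num)

    def solve(start: int, end: int) -> int:
        # Sort prefix[start:end] in place and return the valid pairs inside it.
        if end - start <= 1:
            return 0
        mid = (start + end) // 2
        count = solve(start, mid) + solve(mid, end)

        # Crossing pairs: prefix[start:mid] and prefix[mid:end] are each sorted.
        lo = hi = mid
        for i in range(start, mid):
            x = prefix[i]
            while lo < end and prefix[lo] < x + lower:
                lo += 1
            while hi < end and prefix[hi] <= x + upper:
                hi += 1
            count += hi - lo

        # Merge the two sorted runs, then write the result back in place.
        merged = []
        i, j = start, mid
        while i < mid and j < end:
            if prefix[i] <= prefix[j]:
                merged.append(prefix[i])
                i += 1
            else:
                merged.append(prefix[j])
                j += 1
        merged.extend(prefix[i:mid])
        merged.extend(prefix[j:end])
        prefix[start:end] = merged
        return count

    return solve(0, len(prefix))


print(count_range_sum_in_place([-2, 5, -1], -2, 2))  # 3
```

The counting loop is unchanged. Only the bookkeeping moves from list slices to index arithmetic.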
Testing
```python
def run_tests():
    s = Solution()
    assert s.countRangeSum([-2, 5, -1], -2, 2) == 3
    assert s.countRangeSum([0], 0, 0) == 1
    assert s.countRangeSum([1], 0, 0) == 0
    assert s.countRangeSum([1, -1], 0, 0) == 1
    assert s.countRangeSum([1, 2, 3], 3, 5) == 3
    assert s.countRangeSum([-1, -1, -1], -2, -1) == 5
    print("all tests passed")


run_tests()
```
Test meaning:
| Test (nums, [lower, upper]) | Why |
|---|---|
| [-2, 5, -1], [-2, 2] | Official-style example |
| [0], [0, 0] | Single zero forms one valid range |
| [1], [0, 0] | Single value outside range |
| [1, -1], [0, 0] | Valid zero-sum range |
| [1, 2, 3], [3, 5] | Multiple positive ranges |
| [-1, -1, -1], [-2, -1] | Negative sums and repeated values |
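Beyond the fixed cases, a randomized comparison against the brute-force solution can catch off-by-one mistakes in the pointer bounds. The sketch below is self-contained, so it repeats compact copies of both approaches. The names brute_force, merge_count and random_check are hypothetical. The parameters (array length 1 to 8, values in [-5, 5], 2000 trials) are arbitrary choices. heapq.merge stands in for the hand-written merge.

```python
import heapq
import random


def brute_force(nums, lower, upper):
    count = 0
    for i in range(len(nums)):
        total = 0
        for j in range(i, len(nums)):
            total += nums[j]
            if lower <= total <= upper:
                count += 1
    return count


def merge_count(nums, lower, upper):
    prefix = [0]
    for num in nums:
        prefix.append(prefix[-1] + num)

    def sort_and_count(arr):
        if len(arr) <= 1:
            return arr, 0
        mid = len(arr) // 2
        left, a = sort_and_count(arr[:mid])
        right, b = sort_and_count(arr[mid:])
        count = a + b
        lo = hi = 0
        for x in left:
            while lo < len(right) and right[lo] < x + lower:
                lo += 1
            while hi < len(right) and right[hi] <= x + upper:
                hi += 1
            count += hi - lo
        return list(heapq.merge(left, right)), count

    return sort_and_count(prefix)[1]


def random_check(trials=2000, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(-5, 5) for _ in range(rng.randint(1, 8))]
        lower = rng.randint(-10, 10)
        upper = lower + rng.randint(0, 10)  # keeps lower <= upper
        expected = brute_force(nums, lower, upper)
        assert merge_count(nums, lower, upper) == expected, (nums, lower, upper)
    print("random check passed")


random_check()
```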