
LeetCode 327: Count of Range Sum

A clear explanation of Count of Range Sum using prefix sums and merge sort counting.

Problem Restatement

We are given an integer array nums and two integers lower and upper.

A range sum is the sum of a contiguous subarray:

S(i, j) = nums[i] + nums[i + 1] + ... + nums[j]

where i <= j.

We need to count how many range sums lie inside this inclusive range:

lower <= S(i, j) <= upper

The naive O(n^2) solution is too slow, so we need something better.

Input and Output

Item | Meaning
--- | ---
Input | nums, lower, and upper
Output | Number of subarray sums in [lower, upper]
Constraints | nums.length can be up to 10^5
Required idea | Better than O(n^2)

Function shape:

def countRangeSum(nums: list[int], lower: int, upper: int) -> int:
    ...

Examples

Example 1:

Input: nums = [-2, 5, -1], lower = -2, upper = 2
Output: 3

The valid ranges are:

Range | Sum
--- | ---
[0, 0] | -2
[2, 2] | -1
[0, 2] | 2

So the answer is 3.

Example 2:

Input: nums = [0], lower = 0, upper = 0
Output: 1

There is only one subarray: [0].

Its sum is 0, which lies in [0, 0].

First Thought: Brute Force

The direct method is to enumerate every subarray.

For each starting index i, keep extending the ending index j, and update the running sum.

class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        ans = 0

        for i in range(len(nums)):
            total = 0

            for j in range(i, len(nums)):
                total += nums[j]

                if lower <= total <= upper:
                    ans += 1

        return ans

This gives the right answer, but it checks all subarrays.

Problem With Brute Force

There are n(n + 1) / 2 subarrays, which is about n^2 / 2.

So the time complexity is:

O(n^2)

For n = 100000, that is about 5 × 10^9 subarrays, which is far too many.

We need to count valid ranges without explicitly checking every subarray.

Key Insight

Use prefix sums.

Define:

prefix[0] = 0
prefix[i + 1] = nums[0] + nums[1] + ... + nums[i]

Then the sum of the slice nums[i:j] (elements i through j - 1) is:

prefix[j] - prefix[i]

For an inclusive subarray nums[i] through nums[j], the range sum is:

prefix[j + 1] - prefix[i]
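As a quick check of these identities, here is a small sketch that builds the prefix array with itertools.accumulate and compares every inclusive subarray sum against prefix[j + 1] - prefix[i]. The initial= argument of accumulate needs Python 3.8 or later. The helper name build_prefix is hypothetical and not part of the problem's interface.

```python
from itertools import accumulate


def build_prefix(nums: list[int]) -> list[int]:
    # prefix[0] = 0 and prefix[k] = sum of the first k elements of nums
    return list(accumulate(nums, initial=0))


nums = [-2, 5, -1]
prefix = build_prefix(nums)
print(prefix)  # [0, -2, 3, 2]

# The inclusive subarray nums[i..j] has sum prefix[j + 1] - prefix[i]
for i in range(len(nums)):
    for j in range(i, len(nums)):
        assert sum(nums[i:j + 1]) == prefix[j + 1] - prefix[i]
```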

So the problem becomes:

Count prefix index pairs (i, j) with 0 <= i < j <= n, where n is the length of nums, such that:

lower <= prefix[j] - prefix[i] <= upper

Rearrange this:

prefix[j] - upper <= prefix[i] <= prefix[j] - lower

For each right prefix sum prefix[j], we need to know how many earlier prefix sums fall inside a value interval.

That is the real problem.
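You can see this reformulation at work before merge sort comes in. Keep the earlier prefix sums in a sorted Python list and count the ones inside [prefix[j] - upper, prefix[j] - lower] with the bisect module. This is a swapped-in illustration, not the approach this article builds toward. bisect.insort shifts list elements on every insertion, so the sketch is still O(n^2) in the worst case. The function name count_range_sum_sorted_list is hypothetical.

```python
from bisect import bisect_left, bisect_right, insort


def count_range_sum_sorted_list(nums: list[int], lower: int, upper: int) -> int:
    seen = [0]  # sorted earlier prefix sums; starts with prefix[0] = 0
    running = 0  # current prefix sum, prefix[j]
    count = 0
    for num in nums:
        running += num
        # Earlier prefix[i] must lie in [running - upper, running - lower]
        count += bisect_right(seen, running - lower) - bisect_left(seen, running - upper)
        # insort keeps seen sorted but shifts elements: O(n) per insertion
        insort(seen, running)
    return count


print(count_range_sum_sorted_list([-2, 5, -1], -2, 2))  # 3
```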

Why Merge Sort Helps

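During merge sort, we split the prefix sums, kept in their original index order, into a left half and a right half. Every index in the left half is smaller than every index in the right half, so a left value can play the role of prefix[i] and a right value the role of prefix[j].

For every prefix sum in the left half, we count how many prefix sums in the right half produce a valid difference.

Sorting a half only reorders values inside that half, so this index relationship still holds after the recursive calls return. Because both halves are sorted, we can count these pairs with two moving pointers.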

Suppose left_value is from the left half.

We need right values x such that:

lower <= x - left_value <= upper

Rearrange:

left_value + lower <= x <= left_value + upper

Since the right half is sorted, all valid x values form one contiguous window.
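Because the window is contiguous, its two ends could also be found by binary search instead of moving pointers. The sketch below uses bisect_left and bisect_right to return the half-open index range [lo, hi) of right values x with left_value + lower <= x <= left_value + upper. The helper window_for and the sample list are illustrative assumptions.

```python
from bisect import bisect_left, bisect_right


def window_for(left_value: int, right_sorted: list[int], lower: int, upper: int) -> tuple[int, int]:
    # lo: first index whose value is >= left_value + lower
    lo = bisect_left(right_sorted, left_value + lower)
    # hi: first index whose value is > left_value + upper
    hi = bisect_right(right_sorted, left_value + upper)
    return lo, hi


right_sorted = [-1, 2, 4, 7]
lo, hi = window_for(1, right_sorted, -2, 2)
print(lo, hi, right_sorted[lo:hi])  # 0 2 [-1, 2]
```

Each call costs O(log n), so a binary-search version would be O(n log n) per merge level. The two-pointer scan avoids that extra factor because the left half is also sorted.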

We can move two pointers:

lo: first right value >= left_value + lower
hi: first right value >  left_value + upper

Then the number of valid right values is:

hi - lo

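Because left_value only increases as we scan the sorted left half, both bounds only grow, so lo and hi only move forward. Each merge therefore counts its crossing pairs in time linear in the size of the two halves, which gives linear counting per merge level.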
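Here is the crossing count on its own. This sketch assumes left and right are already sorted ascending and that lower <= upper, as the problem guarantees. The function name count_crossing is hypothetical. It uses the same two-pointer loop as the implementation, pulled out so it can be checked against a direct double loop.

```python
def count_crossing(left: list[int], right: list[int], lower: int, upper: int) -> int:
    """Count (x, y) with x in left, y in right, and lower <= y - x <= upper.

    Assumes both lists are sorted ascending and lower <= upper.
    """
    count = 0
    lo = 0  # first index with right[lo] >= x + lower
    hi = 0  # first index with right[hi] > x + upper
    for x in left:
        # x only grows, so both bounds grow and the pointers never move back
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        count += hi - lo
    return count


# Sorted halves of prefix = [0, -2, 3, 2] at the top-level split
left, right = [-2, 0], [2, 3]
fast = count_crossing(left, right, -2, 2)
slow = sum(1 for x in left for y in right if -2 <= y - x <= 2)
print(fast, slow)  # 1 1
```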

Algorithm

Build the prefix sum array.

Then run a recursive merge sort function over the prefix sums.

For each recursive range:

  1. Split into left half and right half.
  2. Count valid pairs fully inside the left half.
  3. Count valid pairs fully inside the right half.
  4. Count valid pairs crossing from left half to right half.
  5. Merge the two sorted halves so the current range is sorted before returning.

The crossing count is the important part.

For each prefix sum x in the left sorted half:

x + lower <= y <= x + upper

where y comes from the right sorted half.

Use two pointers to count all such y.

Walkthrough

Take:

nums = [-2, 5, -1]
lower = -2
upper = 2

Build prefix sums:

prefix = [0, -2, 3, 2]

Every range sum is a difference between two prefix sums.

The valid pairs are:

Pair of prefix indices | Difference | Subarray
--- | --- | ---
(0, 1) | -2 - 0 = -2 | [0, 0]
(2, 3) | 2 - 3 = -1 | [2, 2]
(0, 3) | 2 - 0 = 2 | [0, 2]

So the answer is 3.

Merge sort counts these pairs while sorting the prefix array.
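For small inputs, you can reproduce the table above by brute force over all prefix index pairs. The sketch below lists each valid pair together with its difference and the inclusive subarray it stands for. valid_prefix_pairs is a hypothetical helper. itertools.combinations yields pairs in lexicographic order, so the rows come out in a different order than the table.

```python
from itertools import accumulate, combinations


def valid_prefix_pairs(
    nums: list[int], lower: int, upper: int
) -> list[tuple[tuple[int, int], int, tuple[int, int]]]:
    prefix = list(accumulate(nums, initial=0))
    rows = []
    for i, j in combinations(range(len(prefix)), 2):  # always i < j
        diff = prefix[j] - prefix[i]
        if lower <= diff <= upper:
            # Prefix pair (i, j) is the inclusive subarray [i, j - 1]
            rows.append(((i, j), diff, (i, j - 1)))
    return rows


for row in valid_prefix_pairs([-2, 5, -1], -2, 2):
    print(row)
# ((0, 1), -2, (0, 0))
# ((0, 3), 2, (0, 2))
# ((2, 3), -1, (2, 2))
```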

Correctness

Every subarray sum can be represented as the difference of two prefix sums.

For any subarray from index i to index j, its sum is:

prefix[j + 1] - prefix[i]

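A subarray from i to j corresponds to the prefix pair (i, j + 1), and every prefix pair (i, j) with i < j corresponds to the subarray from i to j - 1. So, treating i and j as prefix indices from here on, counting valid range sums is exactly the same as counting prefix pairs (i, j) with 0 <= i < j <= n, where n is the length of nums, and:

lower <= prefix[j] - prefix[i] <= upper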

The recursive algorithm counts all such pairs in three disjoint groups:

Group | Meaning
--- | ---
Left half | Both prefix indices are in the left half
Right half | Both prefix indices are in the right half
Crossing | First prefix index is in the left half, second prefix index is in the right half

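By induction on the segment length, the recursive calls correctly count the left and right groups.

For crossing pairs, the left and right halves are sorted. Sorting only reorders values inside each half, so every left value still comes from an earlier prefix index than every right value. For each left value x, all right values y satisfying: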

x + lower <= y <= x + upper

form a contiguous interval in the sorted right half.

The two pointers find exactly the bounds of this interval. Therefore hi - lo counts exactly the valid crossing pairs for x.

Since every valid pair belongs to exactly one of the three groups, and each group is counted exactly once, the algorithm returns the correct count.
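You can check the three-group argument numerically on the walkthrough input. This sketch splits the prefix array at the same midpoint the implementation uses, counts each group separately, and confirms that the groups add up to the total. count_pairs, count_cross, and split_counts are hypothetical helpers, all brute force. The sketch also confirms that sorting the halves leaves the crossing count unchanged, since that group depends only on which half each value is in.

```python
from itertools import accumulate


def count_pairs(values: list[int], lower: int, upper: int) -> int:
    """Brute-force count of pairs i < j with lower <= values[j] - values[i] <= upper."""
    n = len(values)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if lower <= values[j] - values[i] <= upper
    )


def count_cross(left: list[int], right: list[int], lower: int, upper: int) -> int:
    """Crossing pairs: x from left (earlier index), y from right (later index)."""
    return sum(1 for x in left for y in right if lower <= y - x <= upper)


def split_counts(values: list[int], lower: int, upper: int) -> tuple[int, int, int]:
    mid = len(values) // 2  # same split point as sort_and_count
    left, right = values[:mid], values[mid:]
    return (
        count_pairs(left, lower, upper),  # both indices in the left half
        count_pairs(right, lower, upper),  # both indices in the right half
        count_cross(left, right, lower, upper),  # one index in each half
    )


prefix = list(accumulate([-2, 5, -1], initial=0))  # [0, -2, 3, 2]
groups = split_counts(prefix, -2, 2)
print(groups, sum(groups), count_pairs(prefix, -2, 2))  # (1, 1, 1) 3 3

# Sorting each half does not change the crossing count
mid = len(prefix) // 2
assert count_cross(sorted(prefix[:mid]), sorted(prefix[mid:]), -2, 2) == groups[2]
```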

Complexity

Metric | Value | Why
--- | --- | ---
Time | O(n log n) | Merge sort has O(log n) levels, and each level does linear work
Space | O(n) | Prefix sums and merged arrays require linear space

The prefix array has length n + 1, so the complexity is still O(n log n).

Implementation

class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        prefix = [0]

        for num in nums:
            prefix.append(prefix[-1] + num)

        def sort_and_count(arr: list[int]) -> tuple[list[int], int]:
            n = len(arr)

            if n <= 1:
                return arr, 0

            mid = n // 2

            left, left_count = sort_and_count(arr[:mid])
            right, right_count = sort_and_count(arr[mid:])

            count = left_count + right_count

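            # Crossing pairs: x comes from the left half (earlier prefix
            # index), y from the right half (later prefix index).
            # lo = first index in right with right[lo] >= x + lower
            # hi = first index in right with right[hi] > x + upper
            lo = 0
            hi = 0

            # left is sorted, so x + lower and x + upper only grow;
            # neither pointer ever needs to move backward.
            for x in left:
                while lo < len(right) and right[lo] < x + lower:
                    lo += 1

                while hi < len(right) and right[hi] <= x + upper:
                    hi += 1

                # right[lo:hi] holds exactly the y with x + lower <= y <= x + upper
                count += hi - lo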

            # Standard merge of two sorted lists, so the caller gets a sorted segment.
            merged = []
            i = 0
            j = 0

            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    merged.append(left[i])
                    i += 1
                else:
                    merged.append(right[j])
                    j += 1

            merged.extend(left[i:])
            merged.extend(right[j:])

            return merged, count

        _, answer = sort_and_count(prefix)
        return answer

Code Explanation

We first build prefix sums:

prefix = [0]

for num in nums:
    prefix.append(prefix[-1] + num)

The initial 0 represents the sum before taking any elements. This is needed for subarrays that start at index 0.
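To see what goes wrong without that leading 0, compare the two prefix arrays on the first example. This sketch uses a hypothetical brute-force helper, count_from_prefix. Dropping the 0 loses exactly the ranges [0, 0] and [0, 2], which both start at index 0.

```python
from itertools import accumulate


def count_from_prefix(prefix: list[int], lower: int, upper: int) -> int:
    # Brute-force count of pairs i < j with lower <= prefix[j] - prefix[i] <= upper
    n = len(prefix)
    return sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if lower <= prefix[j] - prefix[i] <= upper
    )


nums = [-2, 5, -1]
with_zero = list(accumulate(nums, initial=0))  # [0, -2, 3, 2]
without_zero = list(accumulate(nums))  # [-2, 3, 2]

print(count_from_prefix(with_zero, -2, 2))  # 3
print(count_from_prefix(without_zero, -2, 2))  # 1, only [2, 2] remains
```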

The recursive function returns two things:

return merged, count

The sorted array is needed by the parent call.

The count is the number of valid prefix pairs (i, j), with i < j, whose indices both lie inside that segment.

The base case has at most one prefix sum:

if n <= 1:
    return arr, 0

A single prefix sum cannot form a pair, so the count is 0.

Then we split and recurse:

left, left_count = sort_and_count(arr[:mid])
right, right_count = sort_and_count(arr[mid:])

The recursive calls count pairs fully inside each half.

Then we count crossing pairs.

For each x in left, valid right values must satisfy:

x + lower <= y <= x + upper

The first pointer skips values that are too small:

while lo < len(right) and right[lo] < x + lower:
    lo += 1

The second pointer includes values that are still small enough:

while hi < len(right) and right[hi] <= x + upper:
    hi += 1

So this many values are valid:

count += hi - lo

Finally, we merge the two sorted halves so the parent level can use them.
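The hand-written merge loop could also be replaced with a standard library call. As a hedged alternative, and not what the implementation above does, heapq.merge lazily merges inputs that are already sorted, and wrapping it in list() gives the same merged list. Another option is sorted(left + right), which is simpler but relies on Timsort detecting the two sorted runs to stay fast in practice. The function name merge_sorted is hypothetical.

```python
import heapq


def merge_sorted(left: list[int], right: list[int]) -> list[int]:
    # heapq.merge expects each input to be sorted already and yields values lazily
    return list(heapq.merge(left, right))


print(merge_sorted([-2, 0, 4], [-1, 2, 3]))  # [-2, -1, 0, 2, 3, 4]
```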

Testing

def run_tests():
    s = Solution()

    assert s.countRangeSum([-2, 5, -1], -2, 2) == 3
    assert s.countRangeSum([0], 0, 0) == 1
    assert s.countRangeSum([1], 0, 0) == 0
    assert s.countRangeSum([1, -1], 0, 0) == 1
    assert s.countRangeSum([1, 2, 3], 3, 5) == 3
    assert s.countRangeSum([-1, -1, -1], -2, -1) == 5

    print("all tests passed")

run_tests()

Test meaning:

Test | Why
--- | ---
[-2, 5, -1], [-2, 2] | Example 1 from the problem statement
[0], [0, 0] | Single zero forms one valid range
[1], [0, 0] | Single value outside the range
[1, -1], [0, 0] | Valid zero-sum range
[1, 2, 3], [3, 5] | Multiple positive ranges
[-1, -1, -1], [-2, -1] | Negative sums and repeated values
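Hand-picked tests can miss edge cases. One common complement is randomized differential testing against the brute-force method. This sketch defines a brute-force reference and a cross_check harness that takes any solver with the (nums, lower, upper) signature. Both names, the value ranges, and the fixed seed are illustrative assumptions. With the Solution class above in scope, you would call cross_check(Solution().countRangeSum).

```python
import random
from typing import Callable


def brute_force(nums: list[int], lower: int, upper: int) -> int:
    # O(n^2) reference: extend every start index and check each running sum
    count = 0
    for i in range(len(nums)):
        total = 0
        for j in range(i, len(nums)):
            total += nums[j]
            if lower <= total <= upper:
                count += 1
    return count


def cross_check(
    solve: Callable[[list[int], int, int], int], trials: int = 300, seed: int = 0
) -> int:
    """Compare solve against brute_force on small random inputs; return trials run."""
    rng = random.Random(seed)  # fixed seed so any failure is reproducible
    for _ in range(trials):
        nums = [rng.randint(-5, 5) for _ in range(rng.randint(1, 8))]
        lower = rng.randint(-6, 6)
        upper = lower + rng.randint(0, 6)  # keeps lower <= upper
        expected = brute_force(nums, lower, upper)
        got = solve(nums, lower, upper)
        assert got == expected, (nums, lower, upper, got, expected)
    return trials
```

Small values and short arrays keep the brute force cheap while still producing repeated prefix sums, negative ranges, and empty windows.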