# LeetCode 327: Count of Range Sum

## Problem Restatement

We are given an integer array `nums` and two integers `lower` and `upper`.

A range sum is the sum of a contiguous subarray:

```text
S(i, j) = nums[i] + nums[i + 1] + ... + nums[j]
```

where `i <= j`.

We need to count how many range sums lie inside this inclusive range:

```text
lower <= S(i, j) <= upper
```

The naive `O(n^2)` solution is too slow, so we need something better.

## Input and Output

| Item | Meaning |
|---|---|
| Input | `nums`, `lower`, and `upper` |
| Output | Number of subarray sums in `[lower, upper]` |
| Constraints | `nums.length` can be up to `10^5` |
| Required idea | Better than `O(n^2)` |

Function shape:

```python
def countRangeSum(nums: list[int], lower: int, upper: int) -> int:
    ...
```

## Examples

Example 1:

```text
Input: nums = [-2, 5, -1], lower = -2, upper = 2
Output: 3
```

The valid ranges are:

| Range | Sum |
|---|---:|
| `[0, 0]` | `-2` |
| `[2, 2]` | `-1` |
| `[0, 2]` | `2` |

So the answer is `3`.

Example 2:

```text
Input: nums = [0], lower = 0, upper = 0
Output: 1
```

There is only one subarray: `[0]`.

Its sum is `0`, which lies in `[0, 0]`.

## First Thought: Brute Force

The direct method is to enumerate every subarray.

For each starting index `i`, keep extending the ending index `j`, and update the running sum.

```python
class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        ans = 0

        for i in range(len(nums)):
            total = 0

            for j in range(i, len(nums)):
                total += nums[j]

                if lower <= total <= upper:
                    ans += 1

        return ans
```

This gives the right answer, but it checks all subarrays.

## Problem With Brute Force

There are about `n^2 / 2` subarrays.

So the time complexity is:

```text
O(n^2)
```

For `n = 100000`, this is far too large.
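To get a sense of the scale, here is a quick count of the subarrays the brute force visits. A subarray is a choice of start `i` and end `j` with `i <= j`, which gives exactly `n(n + 1) / 2` of them. The helper name `subarray_count` is only for illustration.

```python
def subarray_count(n: int) -> int:
    """Number of non-empty contiguous subarrays of an array of length n."""
    # Each subarray is a pair (i, j) with 0 <= i <= j < n.
    return n * (n + 1) // 2


print(subarray_count(3))        # 6
print(subarray_count(100_000))  # 5000050000
```

At the maximum input size, that is about five billion running-sum updates, and each one is followed by a comparison.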

We need to count valid ranges without explicitly checking every subarray.

## Key Insight

Use prefix sums.

Define:

```text
prefix[0] = 0
prefix[i + 1] = nums[0] + nums[1] + ... + nums[i]
```

Then the sum of the half-open slice `nums[i:j]` (elements `i` through `j - 1`) is:

```text
prefix[j] - prefix[i]
```

For an inclusive subarray `nums[i]` through `nums[j]`, the range sum is:

```text
prefix[j + 1] - prefix[i]
```
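Here is a quick sanity check of this identity on Example 1, written as a sketch. It builds `prefix` directly from the definition above and compares every inclusive range sum with the matching prefix difference.

```python
nums = [-2, 5, -1]

# prefix[k] = nums[0] + ... + nums[k - 1], with prefix[0] = 0 for the empty prefix.
prefix = [0] * (len(nums) + 1)
for i, num in enumerate(nums):
    prefix[i + 1] = prefix[i] + num

# Every inclusive range nums[i] .. nums[j] equals prefix[j + 1] - prefix[i].
for i in range(len(nums)):
    for j in range(i, len(nums)):
        assert sum(nums[i : j + 1]) == prefix[j + 1] - prefix[i]

print(prefix)  # [0, -2, 3, 2]
```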

So the problem becomes:

Count pairs of prefix indices `(i, j)` with `0 <= i < j <= n`, where `n = len(nums)`, such that:

```text
lower <= prefix[j] - prefix[i] <= upper
```

Rearrange this:

```text
prefix[j] - upper <= prefix[i] <= prefix[j] - lower
```

For each prefix sum `prefix[j]`, we need to count how many earlier prefix sums `prefix[i]` (with `i < j`) fall inside the value interval `[prefix[j] - upper, prefix[j] - lower]`.

That is the real problem.
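One way to see this reformulation at work, before any merge sort, is to keep the earlier prefix sums in a sorted list and binary-search the interval for each new `prefix[j]`. The sketch below uses Python's `bisect` module, and the function name is hypothetical. This is not the approach this article develops: `insort` shifts list elements, so each insertion costs `O(n)` and the worst case is still `O(n^2)`. It only shows that counting earlier prefix sums inside a value window is enough to solve the problem.

```python
from bisect import bisect_left, bisect_right, insort


def count_range_sum_sorted_list(nums: list[int], lower: int, upper: int) -> int:
    """Count pairs i < j with prefix[j] - upper <= prefix[i] <= prefix[j] - lower."""
    seen = [0]  # sorted earlier prefix sums; starts with prefix[0] = 0
    running = 0
    count = 0
    for num in nums:
        running += num  # running is now prefix[j]
        # Number of earlier prefix sums inside the closed interval.
        count += bisect_right(seen, running - lower) - bisect_left(seen, running - upper)
        insort(seen, running)  # make prefix[j] available to later j
    return count


print(count_range_sum_sorted_list([-2, 5, -1], -2, 2))  # 3
```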

## Why Merge Sort Helps

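During merge sort, we split the prefix sum array by position into a left half and a right half.

Every prefix index in the left half is smaller than every prefix index in the right half. Sorting a half only rearranges values inside that half, so this ordering between the halves survives. Any pair that takes its first value from the left half and its second value from the right half therefore satisfies `i < j`.

For every prefix sum in the left half, we count how many prefix sums in the right half produce a valid difference.

Because both halves are sorted, we can count these pairs with two moving pointers.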

Suppose `left_value` is from the left half.

We need right values `x` such that:

```text
lower <= x - left_value <= upper
```

Rearrange:

```text
left_value + lower <= x <= left_value + upper
```

Since the right half is sorted, all valid `x` values form one contiguous window.

We can move two pointers:

```text
lo: first right value >= left_value + lower
hi: first right value >  left_value + upper
```

Then the number of valid right values is:

```text
hi - lo
```

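As `left_value` increases through the sorted left half, both bounds increase, so `lo` and `hi` only move forward. Each pointer crosses the right half at most once. Counting the crossing pairs for one merge is therefore linear in the size of the two halves, and each level of the recursion does linear work in total.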
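Here is the crossing count on its own, as a minimal sketch. `count_crossing` is a hypothetical helper name, and its loop body is the same two-pointer scan used in the implementation later on. The sample halves `[-2, 0]` and `[2, 3]` are the sorted halves of Example 1's prefix array `[0, -2, 3, 2]`. A brute-force count over all pairs confirms the result.

```python
def count_crossing(left: list[int], right: list[int], lower: int, upper: int) -> int:
    """Count pairs (x, y) with x in left, y in right, and lower <= y - x <= upper.

    Both lists must be sorted in ascending order.
    """
    lo = hi = 0
    count = 0
    for x in left:
        # lo: first index with right[lo] >= x + lower
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        # hi: first index with right[hi] > x + upper
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        count += hi - lo
    return count


left, right = [-2, 0], [2, 3]
fast = count_crossing(left, right, -2, 2)
slow = sum(1 for x in left for y in right if -2 <= y - x <= 2)
print(fast, slow)  # 1 1
```

The function never resets `lo` or `hi` between iterations. That is safe only because `left` is sorted, so the window `[x + lower, x + upper]` only slides to the right.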

## Algorithm

Build the prefix sum array.

Then run a recursive merge sort function over the prefix sums.

For each recursive range:

1. Split the range into a left half and a right half.
2. Count the valid pairs fully inside the left half.
3. Count the valid pairs fully inside the right half.
4. Count the valid pairs that cross from the left half to the right half.
5. Merge the two sorted halves so the current range is sorted before returning.

The crossing count is the important part.

For each prefix sum `x` in the left sorted half:

```text
x + lower <= y <= x + upper
```

where `y` comes from the right sorted half.

Use two pointers to count all such `y`.

## Walkthrough

Take:

```text
nums = [-2, 5, -1]
lower = -2
upper = 2
```

Build prefix sums:

```text
prefix = [0, -2, 3, 2]
```

Every range sum is a difference between two prefix sums.

The valid pairs are:

| Pair of prefix indices | Difference | Subarray |
|---|---:|---|
| `(0, 1)` | `-2 - 0 = -2` | `[0, 0]` |
| `(2, 3)` | `2 - 3 = -1` | `[2, 2]` |
| `(0, 3)` | `2 - 0 = 2` | `[0, 2]` |

So the answer is `3`.

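Merge sort finds the same three pairs while sorting the prefix array. With `mid = len(arr) // 2`, the array splits into `[0, -2]` and `[3, 2]`. The pair `(0, 1)` is counted inside the left half, `(2, 3)` is counted inside the right half, and `(0, 3)` is counted as a crossing pair at the top level.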
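To watch the recursion produce those three pairs, here is an instrumented sketch of the same divide-and-conquer step. `trace` is a hypothetical name. It prints the crossing count at each split, and to stay short it uses `sorted(left + right)` in place of a hand-written merge. That gives the same sorted result at a higher cost.

```python
def trace(arr: list[int], lower: int, upper: int, depth: int = 0) -> tuple[list[int], int]:
    """Sort arr and count valid pairs, printing the crossing count at each split."""
    if len(arr) <= 1:
        return arr, 0
    mid = len(arr) // 2
    left, left_count = trace(arr[:mid], lower, upper, depth + 1)
    right, right_count = trace(arr[mid:], lower, upper, depth + 1)

    # Two-pointer crossing count over the sorted halves.
    cross = 0
    lo = hi = 0
    for x in left:
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        cross += hi - lo

    print("  " * depth + f"{left} | {right} -> crossing {cross}")
    # sorted() stands in for the linear merge to keep the sketch short.
    return sorted(left + right), left_count + right_count + cross


_, total = trace([0, -2, 3, 2], -2, 2)
print("total:", total)
```

Running it prints the two inner splits before the top-level one:

```text
  [0] | [-2] -> crossing 1
  [3] | [2] -> crossing 1
[-2, 0] | [2, 3] -> crossing 1
total: 3
```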

## Correctness

Every subarray sum can be represented as the difference of two prefix sums.

For any subarray from index `i` to index `j`, its sum is:

```text
prefix[j + 1] - prefix[i]
```

So counting valid range sums is exactly the same as counting prefix index pairs `(i, j)` with `0 <= i < j <= n` and:

```text
lower <= prefix[j] - prefix[i] <= upper
```

The recursive algorithm counts all such pairs in three disjoint groups:

| Group | Meaning |
|---|---|
| Left half | Both prefix indices are in the left half |
| Right half | Both prefix indices are in the right half |
| Crossing | First prefix index is in the left half, second prefix index is in the right half |

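By induction on segment length, the recursive calls correctly count the left and right groups. The base case holds because a segment with at most one prefix sum contains no pairs.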

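For crossing pairs, every left value comes from a smaller prefix index than every right value, because sorting only reorders values within each half. The condition `i < j` therefore holds automatically. Both halves are sorted, and for each left value `x`, all right values `y` satisfying: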

```text
x + lower <= y <= x + upper
```

form a contiguous interval in the sorted right half.

The two pointers find exactly the bounds of this interval. Therefore `hi - lo` counts exactly the valid crossing pairs for `x`.

Since every valid pair belongs to exactly one of the three groups, and each group is counted exactly once, the algorithm returns the correct count.

## Complexity

| Metric | Value | Why |
|---|---|---|
| Time | `O(n log n)` | Merge sort has `log n` levels, and each level does linear work |
| Space | `O(n)` | The prefix array plus the slices and merged lists alive along one recursion path sum to `O(n)`; the recursion stack adds `O(log n)` |

The prefix array has length `n + 1` rather than `n`, which does not change these bounds.
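The recursion halves the segment at each step, so its depth is small as well. The sketch below uses a hypothetical helper `recursion_depth` that follows the larger half (`arr[mid:]` when `mid = n // 2`) down to a single element. For the largest prefix array, of length `10^5 + 1`, it reports 17 levels. That is far below CPython's default recursion limit of 1000, so the recursive implementation should not need a `sys.setrecursionlimit` call.

```python
import math


def recursion_depth(n: int) -> int:
    """Nesting depth of the recursive calls below the top call for a segment of length n."""
    depth = 0
    while n > 1:
        n -= n // 2  # the larger half is arr[mid:], of length n - n // 2
        depth += 1
    return depth


n = 100_000 + 1  # prefix array length for the largest allowed nums
print(recursion_depth(n))       # 17
print(math.ceil(math.log2(n)))  # 17
```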

## Implementation

```python
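class Solution:
    def countRangeSum(self, nums: list[int], lower: int, upper: int) -> int:
        # prefix[k] is the sum of the first k elements; prefix[0] = 0.
        prefix = [0]

        for num in nums:
            prefix.append(prefix[-1] + num)

        def sort_and_count(arr: list[int]) -> tuple[list[int], int]:
            """Return arr sorted, plus the number of position pairs i < j
            in arr with lower <= arr[j] - arr[i] <= upper."""
            n = len(arr)

            # Zero or one prefix sum cannot form a pair.
            if n <= 1:
                return arr, 0

            mid = n // 2

            # Count pairs entirely inside each half; each half comes back sorted.
            left, left_count = sort_and_count(arr[:mid])
            right, right_count = sort_and_count(arr[mid:])

            count = left_count + right_count

            # Count crossing pairs: for each x in left, right[lo:hi] is the
            # window of right values in [x + lower, x + upper].
            lo = 0
            hi = 0

            for x in left:
                # Advance lo past right values that are too small.
                while lo < len(right) and right[lo] < x + lower:
                    lo += 1

                # Advance hi over right values that are still small enough.
                while hi < len(right) and right[hi] <= x + upper:
                    hi += 1

                count += hi - lo

            # Merge the two sorted halves so the parent call gets a sorted list.
            merged = []
            i = 0
            j = 0

            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    merged.append(left[i])
                    i += 1
                else:
                    merged.append(right[j])
                    j += 1

            # At most one of these still has elements remaining.
            merged.extend(left[i:])
            merged.extend(right[j:])

            return merged, count

        # Only the count is needed; the sorted prefix array is discarded.
        _, answer = sort_and_count(prefix)
        return answer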
```

## Code Explanation

We first build prefix sums:

```python
prefix = [0]

for num in nums:
    prefix.append(prefix[-1] + num)
```

The initial `0` represents the sum before taking any elements. This is needed for subarrays that start at index `0`.
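The standard library can build the same array. As an optional alternative to the loop, `itertools.accumulate` with `initial=0` (available since Python 3.8) prepends the empty-prefix sum:

```python
from itertools import accumulate

nums = [-2, 5, -1]
# initial=0 plays the role of prefix = [0] in the loop version.
prefix = list(accumulate(nums, initial=0))
print(prefix)  # [0, -2, 3, 2]
```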

The recursive function returns two things:

```python
return merged, count  # sorted copy of arr, pairs counted inside arr
```

The sorted array is needed by the parent call.

The count is the number of valid prefix pairs, that is, valid range sums, whose two prefix indices both lie inside that array segment.

The base case has at most one prefix sum:

```python
if n <= 1:
    return arr, 0
```

A single prefix sum cannot form a pair, so the count is `0`.

Then we split and recurse:

```python
left, left_count = sort_and_count(arr[:mid])
right, right_count = sort_and_count(arr[mid:])
```

The recursive calls count pairs fully inside each half.

Then we count crossing pairs.

For each `x` in `left`, valid right values must satisfy:

```text
x + lower <= y <= x + upper
```

The first pointer skips values that are too small:

```python
while lo < len(right) and right[lo] < x + lower:
    lo += 1
```

The second pointer includes values that are still small enough:

```python
while hi < len(right) and right[hi] <= x + upper:
    hi += 1
```

So `right[lo:hi]` holds exactly the valid values, and we add how many there are:

```python
count += hi - lo
```
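Here is a quick randomized check, as a sketch, that the two loops stop exactly where the text says. After each `x` is processed, `lo` should equal `bisect_left(right, x + lower)` and `hi` should equal `bisect_right(right, x + upper)`. The helper names are hypothetical. The check depends on `left` being sorted, which is why the pointers never have to move back.

```python
import random
from bisect import bisect_left, bisect_right


def pointer_positions(left, right, lower, upper):
    """Record (lo, hi) after the two while loops run for each x in left."""
    positions = []
    lo = hi = 0
    for x in left:
        while lo < len(right) and right[lo] < x + lower:
            lo += 1
        while hi < len(right) and right[hi] <= x + upper:
            hi += 1
        positions.append((lo, hi))
    return positions


def bisect_positions(left, right, lower, upper):
    """The same positions, found independently by binary search."""
    return [(bisect_left(right, x + lower), bisect_right(right, x + upper)) for x in left]


random.seed(0)
for _ in range(1000):
    left = sorted(random.randint(-10, 10) for _ in range(random.randint(0, 6)))
    right = sorted(random.randint(-10, 10) for _ in range(random.randint(0, 6)))
    lower = random.randint(-5, 5)
    upper = lower + random.randint(0, 5)  # keep lower <= upper
    assert pointer_positions(left, right, lower, upper) == bisect_positions(left, right, lower, upper)
print("pointer positions match bisect_left / bisect_right")
```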

Finally, we merge the two sorted halves so the parent level can use them.

## Testing

```python
def run_tests():
    s = Solution()

    assert s.countRangeSum([-2, 5, -1], -2, 2) == 3
    assert s.countRangeSum([0], 0, 0) == 1
    assert s.countRangeSum([1], 0, 0) == 0
    assert s.countRangeSum([1, -1], 0, 0) == 1
    assert s.countRangeSum([1, 2, 3], 3, 5) == 3
    assert s.countRangeSum([-1, -1, -1], -2, -1) == 5

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Why |
|---|---|
| `[-2, 5, -1]`, `[-2, 2]` | Example 1 from the problem statement |
| `[0]`, `[0, 0]` | Single zero forms one valid range |
| `[1]`, `[0, 0]` | Single value outside range |
| `[1, -1]`, `[0, 0]` | Valid zero-sum range |
| `[1, 2, 3]`, `[3, 5]` | Multiple positive ranges |
| `[-1, -1, -1]`, `[-2, -1]` | Negative sums and repeated values |

