Count arithmetic subsequences of length at least three using dynamic programming with one hash map per ending index.
Problem Restatement
We are given an integer array nums.
We need to count how many arithmetic subsequences exist in the array.
A sequence is arithmetic if:
- It has at least three elements.
- The difference between every two consecutive elements is the same.
For example, [1, 3, 5, 7, 9] is arithmetic because every step increases by 2.
[7, 7, 7, 7] is also arithmetic, because every step has difference 0.
A subsequence can skip elements, but it must keep the original order.
For example, [2, 5, 10] is a subsequence of [1, 2, 1, 2, 4, 1, 5, 10].
The official problem asks for the number of arithmetic subsequences of length at least 3. The input length is at most 1000, and values may be as small as -2^31 or as large as 2^31 - 1. The answer is guaranteed to fit in a 32-bit integer.
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer array nums |
| Output | Number of arithmetic subsequences |
| Minimum length | 3 |
| Order rule | Must keep original index order |
| Difference | Can be positive, zero, or negative |
Example function shape:
```python
def numberOfArithmeticSlices(nums: list[int]) -> int:
    ...
```
Examples
Example 1:
```python
nums = [2, 4, 6, 8, 10]
```
The arithmetic subsequences are:
```
[2, 4, 6]
[4, 6, 8]
[6, 8, 10]
[2, 4, 6, 8]
[4, 6, 8, 10]
[2, 4, 6, 8, 10]
[2, 6, 10]
```
Answer:
```
7
```
Example 2:
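```python
nums = [7, 7, 7, 7, 7]
```
Every subsequence of length at least 3 is arithmetic because all differences are 0.
Choosing 3, 4, or 5 of the 5 positions gives C(5, 3) + C(5, 4) + C(5, 5) = 10 + 5 + 1 = 16 valid subsequences.
Answer:
```
16
```
First Thought: Generate All Subsequences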
A direct approach is:
- Generate every subsequence.
- Keep only subsequences with length at least 3.
- Check whether each one is arithmetic.
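For very small inputs, the direct approach still works as a reference answer to test faster code against. This is a minimal sketch. brute_force_count is a hypothetical helper name. It enumerates index combinations rather than values, so equal values at different positions count as different subsequences:

```python
from itertools import combinations


def brute_force_count(nums: list[int]) -> int:
    """Count arithmetic subsequences of length >= 3 by enumeration.

    Exponential time: only practical for very small arrays.
    """
    n = len(nums)
    total = 0
    for length in range(3, n + 1):
        # Choose indices, not values, so equal values at different
        # positions produce distinct subsequences
        for idx in combinations(range(n), length):
            diffs = {nums[idx[k + 1]] - nums[idx[k]] for k in range(length - 1)}
            # Arithmetic means every consecutive difference is the same
            if len(diffs) == 1:
                total += 1
    return total


print(brute_force_count([2, 4, 6, 8, 10]))  # 7
print(brute_force_count([7, 7, 7, 7, 7]))  # 16
```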
This is too slow.
An array of length n has 2 ** n subsequences.
Since n can be 1000, enumeration is impossible.
We need to count subsequences without explicitly building them.
Key Insight
An arithmetic subsequence can be extended if the next number keeps the same difference.
Suppose we have an arithmetic subsequence ending at index j with difference d.
If a later index i > j satisfies nums[i] - nums[j] == d, then we can append nums[i] to that subsequence.
So we track, for each ending index i, how many subsequences end there with each difference.
Define dp[i][d] as the number of arithmetic subsequences ending at index i with common difference d.
Important detail:
dp[i][d] counts subsequences of length at least 2.
Why include length 2?
Because a pair is not a valid answer yet, but it can become valid after adding one more number.
For example, [2, 4] does not count yet.
But when we later see 6, it becomes [2, 4, 6], which does count.
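To make the state concrete, this small sketch builds the maps for [2, 4, 6] using the update rule from the Transition section and prints them. dp_states is a hypothetical helper name. It uses plain dicts with .get instead of defaultdict, so reading a missing difference does not insert a zero entry into the printed state:

```python
def dp_states(nums: list[int]) -> tuple[list[dict[int, int]], int]:
    """Return the per-index maps dp[i][d] and the final answer."""
    dp: list[dict[int, int]] = [{} for _ in nums]
    answer = 0
    for i in range(len(nums)):
        for j in range(i):
            d = nums[i] - nums[j]
            # .get reads without inserting a zero entry into dp[j]
            count = dp[j].get(d, 0)
            answer += count
            dp[i][d] = dp[i].get(d, 0) + count + 1
    return dp, answer


dp, answer = dp_states([2, 4, 6])
print(dp)      # [{}, {2: 1}, {4: 1, 2: 2}]
print(answer)  # 1
```

Here dp[1][2] = 1 is the pair [2, 4], dp[2][4] = 1 is the pair [2, 6], and dp[2][2] = 2 counts [4, 6] and [2, 4, 6]. Only [2, 4, 6] has length 3, so the answer is 1.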
Transition
For every pair of indices with j < i, compute:
```python
d = nums[i] - nums[j]
```
Let:
```python
count = dp[j][d]
```
This means there are count subsequences of length at least 2 ending at j with difference d.
By appending nums[i], each of them becomes a valid arithmetic subsequence of length at least 3.
So we add:
```python
answer += count
```
Then update dp[i][d].
There are two sources:
- Existing subsequences ending at j with difference d, extended by nums[i].
- The new pair [nums[j], nums[i]].

So:
```python
dp[i][d] += count + 1
```
The +1 represents the length-2 pair.
We do not add this pair to the answer yet, because the problem requires length at least 3.
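For reference, the two update rules can also be written as a single recurrence. Here $d_{ji} = \text{nums}[i] - \text{nums}[j]$ is shorthand introduced only for this restatement, and $dp[j]$ is read after all pairs ending at $j$ have been processed:

$$
\text{answer} = \sum_{i} \sum_{j < i} dp[j][d_{ji}],
\qquad
dp[i][d] = \sum_{\substack{j < i \\ d_{ji} = d}} \bigl(dp[j][d] + 1\bigr)
$$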
Algorithm
Create one hash map per index:
```python
dp = [defaultdict(int) for _ in nums]
```
Initialize:
```python
answer = 0
```
For each ending index i:
- For every previous index j < i:
  - Compute d = nums[i] - nums[j].
  - Let count = dp[j][d].
  - Add count to answer.
  - Add count + 1 to dp[i][d].

Return answer.
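The steps above can be traced pair by pair. This sketch records (i, j, d, count, answer) after each update for [2, 4, 6, 8]. trace is a hypothetical helper name, and it uses plain dicts with .get rather than defaultdict:

```python
def trace(nums: list[int]) -> list[tuple[int, int, int, int, int]]:
    """Record (i, j, d, count, answer) after each pair update."""
    dp: list[dict[int, int]] = [{} for _ in nums]
    answer = 0
    rows = []
    for i in range(len(nums)):       # ending index
        for j in range(i):           # every previous index
            d = nums[i] - nums[j]
            count = dp[j].get(d, 0)  # length >= 2 subsequences ending at j
            answer += count          # their extensions have length >= 3
            dp[i][d] = dp[i].get(d, 0) + count + 1  # extensions + new pair
            rows.append((i, j, d, count, answer))
    return rows


for row in trace([2, 4, 6, 8]):
    print(row)
# (1, 0, 2, 0, 0)
# (2, 0, 4, 0, 0)
# (2, 1, 2, 1, 1)
# (3, 0, 6, 0, 1)
# (3, 1, 4, 0, 1)
# (3, 2, 2, 2, 3)
```

The last row adds count = 2 because dp[2][2] holds [4, 6] and [2, 4, 6]. Appending 8 yields [4, 6, 8] and [2, 4, 6, 8], which together with [2, 4, 6] from the third row give the answer 3.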
Correctness
For each index i and difference d, dp[i][d] stores the number of arithmetic subsequences of length at least 2 that end at i with difference d.
When processing a pair (j, i), the difference is d = nums[i] - nums[j].
Every subsequence counted in dp[j][d] can be extended by nums[i], because its last value is nums[j] and the new difference remains d.
Each such extension has length at least 3, so all dp[j][d] extensions should be added to the final answer.
The pair [nums[j], nums[i]] also has difference d, so it must be stored in dp[i][d] for possible future extensions. It is not added to the answer because its length is only 2.
Thus the update:
```python
answer += dp[j][d]
dp[i][d] += dp[j][d] + 1
```
counts exactly the newly formed valid arithmetic subsequences ending at i, while preserving length-2 pairs for later.
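Every arithmetic subsequence of length at least 3 has a unique pair of last two indices (j, i) and a unique difference d = nums[i] - nums[j]. Removing its last element leaves an arithmetic subsequence of length at least 2 that ends at j with difference d, so it is one of the dp[j][d] subsequences extended when the algorithm processes (j, i). All pairs ending at j are handled while j is the ending index, before i, so dp[j][d] is already final at that moment and the subsequence is counted exactly once.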
Therefore, the final answer is exactly the number of arithmetic subsequences.
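The argument can also be checked empirically. This sketch compares the DP against exhaustive enumeration on random small arrays. dp_count, brute_count, and cross_check are hypothetical helper names, and a passing run is evidence rather than proof:

```python
import random
from itertools import combinations


def dp_count(nums: list[int]) -> int:
    """O(n^2) DP: dp[i][d] counts length >= 2 arithmetic subsequences ending at i."""
    dp: list[dict[int, int]] = [{} for _ in nums]
    answer = 0
    for i in range(len(nums)):
        for j in range(i):
            d = nums[i] - nums[j]
            count = dp[j].get(d, 0)
            answer += count
            dp[i][d] = dp[i].get(d, 0) + count + 1
    return answer


def brute_count(nums: list[int]) -> int:
    """Exponential reference: test every index combination of length >= 3."""
    total = 0
    for length in range(3, len(nums) + 1):
        for idx in combinations(range(len(nums)), length):
            diffs = {nums[b] - nums[a] for a, b in zip(idx, idx[1:])}
            if len(diffs) == 1:
                total += 1
    return total


def cross_check(trials: int = 300, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        # Short arrays with a narrow value range produce many repeated differences
        nums = [rng.randint(-3, 3) for _ in range(rng.randint(0, 9))]
        assert dp_count(nums) == brute_count(nums), nums
    print("DP matches brute force")


cross_check()
```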
Complexity
| Metric | Value | Why |
|---|---|---|
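| Time | O(n^2) | Every pair (j, i) is processed once, with expected O(1) dictionary operations per pair |
| Space | O(n^2) | Each pair (j, i) adds at most one key to dp[i], and the defaultdict lookup of dp[j][d] can add at most one zero entry to dp[j] |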
Here, n = len(nums).
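The space bound is reached when almost every pair has its own difference. This sketch counts how many keys the maps hold. stored_entries is a hypothetical helper name, and powers of two are an assumed worst-case input chosen because 2^i - 2^j is different for every pair j < i. An all-equal array, by contrast, collapses to one key per index:

```python
def stored_entries(nums: list[int]) -> int:
    """Build the per-index maps and return the total number of stored keys."""
    dp: list[dict[int, int]] = [{} for _ in nums]
    for i in range(len(nums)):
        for j in range(i):
            d = nums[i] - nums[j]
            count = dp[j].get(d, 0)  # read without inserting into dp[j]
            dp[i][d] = dp[i].get(d, 0) + count + 1
    return sum(len(m) for m in dp)


n = 10
# Every pair has a distinct difference, so dp[i] holds i keys
print(stored_entries([2 ** k for k in range(n)]))  # 45 == n * (n - 1) // 2
# Every difference is 0, so each index after the first holds one key
print(stored_entries([7] * n))                     # 9
```

This sketch reads with .get. The defaultdict lookup in the implementation below also inserts a zero entry into dp[j] for each missing difference, which can at most double the key count and keeps it O(n^2).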
Implementation
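```python
from collections import defaultdict


class Solution:
    def numberOfArithmeticSlices(self, nums: list[int]) -> int:
        n = len(nums)
        # dp[i][diff] = number of arithmetic subsequences of length >= 2
        # that end at index i with common difference diff
        dp = [defaultdict(int) for _ in range(n)]
        answer = 0
        for i in range(n):
            for j in range(i):
                diff = nums[i] - nums[j]
                # Subsequences ending at j with this difference can take nums[i]
                count = dp[j][diff]
                # Each extension has length >= 3, so it counts toward the answer
                answer += count
                # Extended subsequences plus the new pair [nums[j], nums[i]]
                dp[i][diff] += count + 1
        return answer
```
Code Explanation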
We keep one dictionary per index:
```python
dp = [defaultdict(int) for _ in range(n)]
```
For each i, dp[i] maps:
```
difference -> number of subsequences (length at least 2) ending at i with that difference
```
We try every previous index:
```python
for j in range(i):
```
The common difference between nums[j] and nums[i] is:
```python
diff = nums[i] - nums[j]
```
The number of extendable subsequences is:
```python
count = dp[j][diff]
```
These become valid subsequences of length at least 3 after adding nums[i], so we add them:
```python
answer += count
```
Then we update the state at i:
```python
dp[i][diff] += count + 1
```
The count part extends old subsequences.
The +1 part creates the new length-2 pair [nums[j], nums[i]].
Testing
```python
def run_tests():
    s = Solution()
    assert s.numberOfArithmeticSlices([2, 4, 6, 8, 10]) == 7
    assert s.numberOfArithmeticSlices([7, 7, 7, 7, 7]) == 16
    assert s.numberOfArithmeticSlices([1, 2, 3, 4]) == 3
    assert s.numberOfArithmeticSlices([1, 3, 5, 7, 9]) == 7
    assert s.numberOfArithmeticSlices([1, 1, 2, 5, 7]) == 0
    assert s.numberOfArithmeticSlices([3, -1, -5, -9]) == 3
    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| [2,4,6,8,10] | Checks the standard example |
| All equal values [7,7,7,7,7] | Checks zero difference |
| [1,2,3,4] | Checks length-3 and length-4 subsequences |
| [1,3,5,7,9] | Checks multiple subsequence lengths, including [1,5,9] with difference 4 |
| [1,1,2,5,7] | Checks an input with no arithmetic subsequence |
| [3,-1,-5,-9] | Checks a decreasing sequence (negative difference) |
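The table does not cover values at the constraint limits. At those limits, nums[i] - nums[j] can reach 2^32 - 1 in magnitude, which does not fit in a signed 32-bit integer. Python integers are arbitrary-precision, so the implementation needs no change, but a port to a fixed-width language would need 64-bit differences. This sketch of such a test uses a standalone copy of the DP under the hypothetical name count_slices:

```python
def count_slices(nums: list[int]) -> int:
    """Same O(n^2) DP as the implementation, written as a standalone function."""
    dp: list[dict[int, int]] = [{} for _ in nums]
    answer = 0
    for i in range(len(nums)):
        for j in range(i):
            d = nums[i] - nums[j]  # may need 33 bits; Python ints do not overflow
            count = dp[j].get(d, 0)
            answer += count
            dp[i][d] = dp[i].get(d, 0) + count + 1
    return answer


INT_MIN, INT_MAX = -2 ** 31, 2 ** 31 - 1

# Both steps equal INT_MAX, so this is one arithmetic triple
assert count_slices([INT_MIN, -1, INT_MAX - 1]) == 1
# Both steps equal -INT_MAX
assert count_slices([INT_MAX, 0, -INT_MAX]) == 1
# Largest possible gap (2**32 - 1) between the extremes; the steps differ in sign
assert count_slices([INT_MIN, INT_MAX, INT_MIN]) == 0
print("extreme-value tests passed")
```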