A clear explanation of counting subarrays with exactly k distinct integers using the at-most-k sliding window trick.
Problem Restatement
We are given an integer array nums and an integer k.
A good subarray is a contiguous subarray that contains exactly k different integers.
We need to return the number of good subarrays.
For example:
```
nums = [1, 2, 1, 2, 3]
k = 2
```

The good subarrays are:

```
[1, 2]
[2, 1]
[1, 2]
[2, 3]
[1, 2, 1]
[2, 1, 2]
[1, 2, 1, 2]
```

So the answer is 7.
The official constraints are 1 <= nums.length <= 2 * 10^4 and 1 <= nums[i], k <= nums.length.
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer array nums and integer k |
| Output | Number of contiguous subarrays with exactly k distinct integers |
| Subarray | Must be contiguous |
| Distinct count | Count unique values inside the subarray |
Function shape:
```python
def subarraysWithKDistinct(nums: list[int], k: int) -> int:
    ...
```

Examples
Example 1:
```
nums = [1, 2, 1, 2, 3]
k = 2
```

The subarrays with exactly 2 distinct integers are:

```
[1, 2]
[2, 1]
[1, 2]
[2, 3]
[1, 2, 1]
[2, 1, 2]
[1, 2, 1, 2]
```

Answer: 7

Example 2:
```
nums = [1, 2, 1, 3, 4]
k = 3
```

The good subarrays are:

```
[1, 2, 1, 3]
[2, 1, 3]
[1, 3, 4]
```

Answer: 3

First Thought: Check Every Subarray
The direct solution is to enumerate every subarray.
For each starting index, extend the subarray to the right and keep a set of distinct values.
Whenever the set size is exactly k, increment the answer.
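```python
class Solution:
    def subarraysWithKDistinct(self, nums: list[int], k: int) -> int:
        n = len(nums)
        answer = 0
        for left in range(n):
            distinct = set()  # distinct values in nums[left..right]
            for right in range(left, n):
                distinct.add(nums[right])
                if len(distinct) == k:
                    answer += 1
                elif len(distinct) > k:
                    # Extending further can only add values, so stop early.
                    break
        return answer
```

This is correct, but the early break only helps when the distinct count grows quickly. In the worst case, such as an array of equal values with k = 1, it never triggers, so the method can still be too slow.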
Problem With Brute Force
There are O(n^2) subarrays.
Since nums.length can be as large as 2 * 10^4, there can be n(n + 1) / 2 ≈ 2 * 10^8 subarrays, which is too many to check one by one.
We need a linear-time method.
Key Insight
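Counting subarrays with exactly k distinct integers directly is awkward: for a fixed right end, several different left ends can give exactly k distinct values, so a single window boundary does not capture the count.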
Counting subarrays with at most k distinct integers is much easier with a sliding window.
The key identity is:
```
exactly(k) = at_most(k) - at_most(k - 1)
```

Why this works:
| Count | Includes subarrays with |
|---|---|
| at_most(k) | 1, 2, ..., k distinct values |
| at_most(k - 1) | 1, 2, ..., k - 1 distinct values |
Subtracting removes all subarrays with fewer than k distinct values.
Only subarrays with exactly k distinct values remain.
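To check the identity on the first example, here is a small brute-force sketch. The helper names count_at_most and count_exactly are hypothetical and exist only for this illustration; both test every subarray, so they are meant for tiny inputs only.

```python
def count_at_most(nums: list[int], limit: int) -> int:
    """Brute force: count subarrays with at most `limit` distinct values."""
    n = len(nums)
    return sum(
        1
        for left in range(n)
        for right in range(left, n)
        if len(set(nums[left:right + 1])) <= limit
    )


def count_exactly(nums: list[int], k: int) -> int:
    """Brute force: count subarrays with exactly `k` distinct values."""
    n = len(nums)
    return sum(
        1
        for left in range(n)
        for right in range(left, n)
        if len(set(nums[left:right + 1])) == k
    )


nums = [1, 2, 1, 2, 3]
print(count_at_most(nums, 2), count_at_most(nums, 1))  # 12 5
print(count_exactly(nums, 2))  # 7
assert count_exactly(nums, 2) == count_at_most(nums, 2) - count_at_most(nums, 1)
```

For nums = [1, 2, 1, 2, 3] there are 12 subarrays with at most 2 distinct values and 5 with at most 1 (just the five single elements, since no two adjacent values are equal), and 12 - 5 = 7 matches the expected answer.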
Counting Subarrays With At Most K Distinct Values
Use a sliding window [left, right].
Maintain a frequency map for values inside the window.
When adding nums[right] makes the number of distinct values greater than k, move left until the window has at most k distinct values again.
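A minimal sketch of just the window maintenance, assuming a hypothetical helper named valid_windows that records the pair (left, right) remaining after shrinking for each right:

```python
from collections import defaultdict


def valid_windows(nums: list[int], limit: int) -> list[tuple[int, int]]:
    """Hypothetical helper: record the window [left, right] after shrinking,
    for every right, using the at-most-`limit` sliding window."""
    freq = defaultdict(int)  # value -> count inside the window
    left = 0
    distinct = 0
    windows = []
    for right, value in enumerate(nums):
        # Grow the window by one element on the right.
        if freq[value] == 0:
            distinct += 1
        freq[value] += 1
        # Shrink from the left until at most `limit` distinct values remain.
        while distinct > limit:
            freq[nums[left]] -= 1
            if freq[nums[left]] == 0:
                distinct -= 1
            left += 1
        windows.append((left, right))
    return windows


print(valid_windows([1, 2, 1, 2, 3], 2))
# [(0, 0), (0, 1), (0, 2), (0, 3), (3, 4)]
```

With limit = 2 on [1, 2, 1, 2, 3], the window only shrinks when 3 arrives at index 4, and left jumps from 0 to 3 because the last 1 inside the window sits at index 2.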
For each right, once the window is valid, every subarray ending at right and starting between left and right is valid.
The number of such subarrays is:
```
right - left + 1
```

Algorithm
Define a helper:
```
at_most(k)
```

It returns the number of subarrays with at most k distinct values.
Inside at_most(k):
- Initialize left = 0, answer = 0, and an empty frequency map.
- Move right across the array from index 0 to n - 1.
- Add nums[right] to the frequency map.
- If the number of distinct values becomes greater than k, move left forward until the window is valid again.
- Add right - left + 1 to answer.
Then return:
```
at_most(k) - at_most(k - 1)
```

Correctness
For at_most(k), the sliding window maintains the invariant that the current window has at most k distinct values after shrinking.
For a fixed right, once the window [left, right] is valid, every subarray ending at right and starting at any index from left to right is also valid. Removing elements from the left cannot increase the number of distinct values.
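There are exactly right - left + 1 such subarrays. Subarrays ending at right that start before left are invalid: left only advances past an index i when the window starting at i had more than k distinct values, and extending that window further to the right cannot reduce its distinct count. So the helper counts exactly the valid subarrays ending at each right.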
Every subarray has exactly one ending index, so at_most(k) counts every subarray with at most k distinct values exactly once.
The final answer subtracts the number of subarrays with at most k - 1 distinct values from the number of subarrays with at most k distinct values. This leaves exactly the subarrays with k distinct values.
Therefore, the algorithm returns the correct count.
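The step that carries most of this argument is that left is the smallest valid start for each right. The sketch below checks that claim against brute force on small inputs; check_minimal_left is a hypothetical helper written only for this illustration, and an exhaustive check on tiny arrays builds confidence but is not a proof.

```python
from collections import defaultdict


def check_minimal_left(nums: list[int], limit: int) -> bool:
    """Hypothetical helper: for every right, check that the sliding window's
    left equals the smallest start whose subarray ending at right has at most
    `limit` distinct values (right + 1, i.e. an empty window, if none does)."""
    freq = defaultdict(int)
    left = 0
    distinct = 0
    for right, value in enumerate(nums):
        if freq[value] == 0:
            distinct += 1
        freq[value] += 1
        while distinct > limit:
            freq[nums[left]] -= 1
            if freq[nums[left]] == 0:
                distinct -= 1
            left += 1
        # Brute-force the smallest valid start for this right.
        smallest = next(
            (
                start
                for start in range(right + 1)
                if len(set(nums[start:right + 1])) <= limit
            ),
            right + 1,
        )
        if smallest != left:
            return False
    return True


# Exhaustively check all arrays of length 1..6 over the values {1, 2, 3}.
for length in range(1, 7):
    for code in range(3 ** length):
        nums = [(code // 3 ** i) % 3 + 1 for i in range(length)]
        for limit in range(4):
            assert check_minimal_left(nums, limit)
print("invariant holds on all small cases")
```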
Complexity
Let n = len(nums).
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each at_most call moves each pointer at most n times |
| Space | O(n) | The frequency map can store up to n distinct values |
The helper runs twice, so the time is still linear.
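To see the linear bound concretely, one can instrument the helper with two counters. The sketch below (at_most_with_moves is a hypothetical name) is the same counting loop as at_most, plus counts of how many times each pointer advances.

```python
from collections import defaultdict


def at_most_with_moves(nums: list[int], limit: int) -> tuple[int, int, int]:
    """Hypothetical instrumented copy of the at-most loop.
    Returns (count, right_moves, left_moves)."""
    freq = defaultdict(int)
    left = 0
    distinct = 0
    count = 0
    right_moves = 0
    left_moves = 0
    for right, value in enumerate(nums):
        right_moves += 1  # right advances exactly once per element
        if freq[value] == 0:
            distinct += 1
        freq[value] += 1
        while distinct > limit:
            freq[nums[left]] -= 1
            if freq[nums[left]] == 0:
                distinct -= 1
            left += 1
            left_moves += 1  # left never moves backward, so at most n total
        count += right - left + 1
    return count, right_moves, left_moves


nums = [1, 2, 1, 2, 3]
print(at_most_with_moves(nums, 2))  # (12, 5, 3)
print(at_most_with_moves(nums, 1))  # (5, 5, 4)
```

On the first example, right advances 5 times in both calls, while left advances 3 times for limit = 2 and 4 times for limit = 1, so each call does at most 2n pointer moves in total.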
Implementation
```python
from collections import defaultdict


class Solution:
    def subarraysWithKDistinct(self, nums: list[int], k: int) -> int:
        def at_most(limit: int) -> int:
            # No non-empty subarray has zero distinct values.
            if limit == 0:
                return 0

            freq = defaultdict(int)  # value -> count inside the window
            left = 0
            answer = 0
            distinct = 0  # number of values with positive count

            for right, value in enumerate(nums):
                if freq[value] == 0:
                    distinct += 1
                freq[value] += 1

                # Shrink until the window has at most `limit` distinct values.
                while distinct > limit:
                    left_value = nums[left]
                    freq[left_value] -= 1
                    if freq[left_value] == 0:
                        distinct -= 1
                    left += 1

                # Every start in [left, right] gives a valid subarray.
                answer += right - left + 1

            return answer

        return at_most(k) - at_most(k - 1)
```

Code Explanation
The helper counts subarrays with at most limit distinct values:

```python
def at_most(limit: int) -> int:
```

When limit == 0, no non-empty subarray can be valid:

```python
if limit == 0:
    return 0
```

We use a frequency map:

```python
freq = defaultdict(int)
```

The variable distinct stores how many values currently have positive frequency in the window:

```python
distinct = 0
```

When a value that is not currently in the window enters it, increase distinct:

```python
if freq[value] == 0:
    distinct += 1
freq[value] += 1
```

If the window has too many distinct values, shrink from the left:

```python
while distinct > limit:
```

When a value’s frequency becomes zero, it leaves the window completely:

```python
if freq[left_value] == 0:
    distinct -= 1
```

After the window is valid, all subarrays ending at right and starting from left through right are valid:

```python
answer += right - left + 1
```

Finally,

```python
return at_most(k) - at_most(k - 1)
```

keeps only subarrays with exactly k distinct values.
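A common alternative, sketched here as a design choice rather than as the article's implementation, uses a plain dict and deletes a key when its count drops to zero. Then len(freq) is the distinct count and no separate distinct variable is needed. The function name subarrays_with_k_distinct is hypothetical.

```python
def subarrays_with_k_distinct(nums: list[int], k: int) -> int:
    """Hypothetical variant: deleting zero-count keys keeps len(freq)
    equal to the number of distinct values in the window."""

    def at_most(limit: int) -> int:
        freq: dict[int, int] = {}
        left = 0
        answer = 0
        for right, value in enumerate(nums):
            freq[value] = freq.get(value, 0) + 1
            while len(freq) > limit:
                left_value = nums[left]
                freq[left_value] -= 1
                if freq[left_value] == 0:
                    del freq[left_value]  # value has left the window entirely
                left += 1
            answer += right - left + 1
        return answer

    return at_most(k) - at_most(k - 1)
```

A side effect is that the limit == 0 guard becomes optional: with limit = 0 the window shrinks to empty after every step, so each step adds right - left + 1 = 0.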
Testing
```python
def run_tests():
    s = Solution()
    assert s.subarraysWithKDistinct([1, 2, 1, 2, 3], 2) == 7
    assert s.subarraysWithKDistinct([1, 2, 1, 3, 4], 3) == 3
    assert s.subarraysWithKDistinct([1, 2, 2], 1) == 4
    assert s.subarraysWithKDistinct([1, 1, 1], 1) == 6
    assert s.subarraysWithKDistinct([1, 2, 3], 3) == 1
    print("all tests passed")


run_tests()
```

| Test | Expected | Why |
|---|---|---|
| [1,2,1,2,3], 2 | 7 | Standard sample |
| [1,2,1,3,4], 3 | 3 | Standard sample |
| [1,2,2], 1 | 4 | Counts repeated-value subarrays |
| [1,1,1], 1 | 6 | Every subarray is valid |
| [1,2,3], 3 | 1 | Only the full array has three distinct values |
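For extra confidence, the fast solution can be compared with a brute-force reference on many small random inputs. The harness below is a sketch; brute_force and stress_test are hypothetical names, and the input sizes, trial count, and seed are arbitrary choices.

```python
import random
from typing import Callable, Optional


def brute_force(nums: list[int], k: int) -> int:
    """Reference answer: check every subarray directly (O(n^2) time)."""
    answer = 0
    for left in range(len(nums)):
        seen = set()
        for right in range(left, len(nums)):
            seen.add(nums[right])
            if len(seen) == k:
                answer += 1
    return answer


def stress_test(
    candidate: Callable[[list[int], int], int],
    trials: int = 500,
    seed: int = 0,
) -> Optional[tuple[list[int], int]]:
    """Compare `candidate` with brute_force on small random inputs.

    Inputs follow the stated constraints: 1 <= nums[i], k <= len(nums).
    Returns the first (nums, k) where the answers differ, or None.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 8)
        nums = [rng.randint(1, n) for _ in range(n)]
        k = rng.randint(1, n)
        if candidate(nums, k) != brute_force(nums, k):
            return nums, k
    return None
```

With the Solution class above in scope, stress_test(Solution().subarraysWithKDistinct) should return None; any returned (nums, k) pair is a concrete counterexample worth debugging.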