A clear explanation of picking a uniformly random index for a target value using reservoir sampling, with an alternative hash map approach.
Problem Restatement
We are given an integer array nums.
The array may contain duplicate values.
We need to implement a class with one operation:

    pick(target)

This operation returns a random index i such that:

    nums[i] == target

If the target appears at multiple indices, each valid index must have equal probability of being returned.
The problem guarantees that target exists in nums. The constraints allow nums.length <= 2 * 10^4 and at most 10^4 calls to pick.
Input and Output
| Method | Input | Output |
|---|---|---|
| Solution(nums) | Integer array nums | Initializes the object |
| pick(target) | Integer target | Random index where nums[index] == target |
Example class shape:

    class Solution:
        def __init__(self, nums: list[int]):
            ...

        def pick(self, target: int) -> int:
            ...

Examples
Example:

    solution = Solution([1, 2, 3, 3, 3])

Call:

    solution.pick(3)

The valid indices are:

    [2, 3, 4]

So pick(3) should return 2, 3, or 4.
Each index should have probability:

    1 / 3

Call:

    solution.pick(1)

The only valid index is:

    0

So the method must return:

    0

First Thought: Store All Indices
A simple approach is to preprocess the array.
Build a hash map:

    value -> list of indices

For:

    nums = [1, 2, 3, 3, 3]

the map is:

    {
        1: [0],
        2: [1],
        3: [2, 3, 4],
    }

Then pick(target) chooses a random index from the list.

    import random
    from collections import defaultdict

    class Solution:
        def __init__(self, nums: list[int]):
            # Map each value to the list of indices where it occurs.
            self.indices = defaultdict(list)
            for i, num in enumerate(nums):
                self.indices[num].append(i)

        def pick(self, target: int) -> int:
            # random.choice picks uniformly from the stored indices.
            return random.choice(self.indices[target])

This is clean and fast.
Its tradeoff is memory: it stores all indices.
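One detail of the hash map version is that indexing a defaultdict with a missing key inserts an empty list, and random.choice raises IndexError on an empty sequence. The problem guarantees that target exists, so this never happens here. As a minimal sketch of what a more defensive variant could look like if that guarantee were dropped (the class name IndexPicker and the choice of KeyError are assumptions for illustration, not part of the original solution):

```python
import random


class IndexPicker:
    """Hypothetical variant of the hash map approach with an explicit missing-target check."""

    def __init__(self, nums: list[int]):
        # Plain dict, so a lookup never creates an entry as a side effect.
        self.indices = {}
        for i, num in enumerate(nums):
            self.indices.setdefault(num, []).append(i)

    def pick(self, target: int) -> int:
        positions = self.indices.get(target)
        if not positions:
            # Assumption: fail loudly instead of relying on the problem's guarantee.
            raise KeyError(f"target {target!r} does not occur in nums")
        return random.choice(positions)
```

The memory cost is the same as before: one stored index per array element.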
Key Insight
We can also solve the problem without storing all target indices.
When scanning nums, let count be the number of matches seen so far, including the current one.
At each matching index, choose it as the answer with probability:

    1 / count

This is reservoir sampling with reservoir size 1.
The idea:
| Match number | What we do |
|---|---|
| 1st match | Choose it with probability 1 / 1 |
| 2nd match | Replace answer with probability 1 / 2 |
| 3rd match | Replace answer with probability 1 / 3 |
| kth match | Replace answer with probability 1 / k |
At the end, every matching index has equal probability.
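The table can be traced exactly with rational arithmetic. The sketch below tracks, after each match, the probability that each matching index is the current answer. The helper name selection_probabilities is hypothetical, and it takes the list of matching indices directly rather than scanning nums.

```python
from fractions import Fraction


def selection_probabilities(match_indices: list[int]) -> dict[int, Fraction]:
    """Exact probability that each matching index is the answer after all matches."""
    probs: dict[int, Fraction] = {}
    for k, index in enumerate(match_indices, start=1):
        keep = Fraction(k - 1, k)  # chance the previous answer survives the kth match
        probs = {i: p * keep for i, p in probs.items()}
        probs[index] = Fraction(1, k)  # chance the kth match becomes the answer
    return probs


print(selection_probabilities([2, 3, 4]))
```

For the matches at indices 2, 3, and 4, every entry comes out as Fraction(1, 3).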
Algorithm
For each call to pick(target):
- Set count = 0.
- Set answer = -1.
- Scan all indices i.
- If nums[i] == target:
  - Increase count.
  - Replace answer with i with probability 1 / count.
- Return answer.
In Python, this probability can be written as:

    if random.randrange(count) == 0:
        answer = i

random.randrange(count) returns an integer chosen uniformly from 0 to count - 1, so the condition is true with probability 1 / count.
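A quick empirical check of that probability, as a rough sketch: the helper replacement_rate, the trial count, and the seed are all assumptions chosen for illustration, and a seeded random.Random instance is used only to make the run repeatable.

```python
import random


def replacement_rate(count: int, trials: int = 100_000, seed: int = 0) -> float:
    """Fraction of trials in which randrange(count) == 0, using a seeded generator."""
    rng = random.Random(seed)  # local generator so the sketch is repeatable
    hits = sum(1 for _ in range(trials) if rng.randrange(count) == 0)
    return hits / trials


for count in (1, 2, 3, 4):
    # Observed rates should be close to 1 / count.
    print(count, replacement_rate(count))
```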
Correctness
Consider the jth occurrence of target.
When the algorithm sees this occurrence, it chooses that index with probability:

    1 / j

Then it must survive all later replacement chances.
At the next occurrence, it is not replaced with probability:

    j / (j + 1)

At the following occurrence, it is not replaced with probability:

    (j + 1) / (j + 2)

This continues until there are m total occurrences.
So the final probability that the jth occurrence is returned is:

    (1 / j) * (j / (j + 1)) * ((j + 1) / (j + 2)) * ... * ((m - 1) / m)

Everything cancels except:

    1 / m

Thus every valid index has equal probability.
Since the target is guaranteed to exist, at least one matching index is seen, and the returned answer is always valid.
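The telescoping product can also be checked numerically. This is a small sketch that multiplies the factors from the argument above with exact fractions; the function name return_probability is hypothetical.

```python
from fractions import Fraction


def return_probability(j: int, m: int) -> Fraction:
    """Probability that the jth of m occurrences is returned (1 <= j <= m)."""
    prob = Fraction(1, j)  # chosen when first seen
    for k in range(j + 1, m + 1):
        prob *= Fraction(k - 1, k)  # not replaced at the kth occurrence
    return prob


m = 5
print([return_probability(j, m) for j in range(1, m + 1)])
```

With m = 5, every entry of the printed list is Fraction(1, 5), matching the 1 / m result.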
Complexity
Let n = len(nums).
| Method | Time | Space |
|---|---|---|
| Constructor | O(1) | O(1) |
| pick(target) | O(n) | O(1) |
This is the reservoir sampling version.
The hash map version has a different tradeoff:
| Method | Time | Space |
|---|---|---|
| Constructor | O(n) | O(n) |
| pick(target) | O(1) average | O(1) extra |
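To put the two tables side by side, here is a rough worst-case step count under the stated constraints. These are counts of elementary steps implied by the tables, not measured running times, and the variable names are illustrative.

```python
MAX_N = 2 * 10**4    # upper bound on len(nums) from the problem statement
MAX_CALLS = 10**4    # upper bound on the number of pick calls

reservoir_total = MAX_N * MAX_CALLS  # each pick scans the whole array
hash_map_total = MAX_N + MAX_CALLS   # one preprocessing pass, then O(1) per pick

print(reservoir_total)  # 200000000
print(hash_map_total)   # 30000
```

The reservoir version trades this extra time for O(1) extra space.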
Implementation
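
    import random

    class Solution:
        def __init__(self, nums: list[int]):
            # Keep a reference to the array; no preprocessing.
            self.nums = nums

        def pick(self, target: int) -> int:
            count = 0     # matches seen so far
            answer = -1   # current reservoir (size 1)
            for i, num in enumerate(self.nums):
                if num == target:
                    count += 1
                    # Replace the answer with probability 1 / count.
                    if random.randrange(count) == 0:
                        answer = i
            return answer

Code Explanation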
The constructor stores the array:

    self.nums = nums

For each pick, we scan the array:

    for i, num in enumerate(self.nums):

Whenever we find the target:

    if num == target:

we increase the number of matches seen:

    count += 1

Then we choose the current index with probability 1 / count:

    if random.randrange(count) == 0:
        answer = i

At the end:

    return answer

The answer is uniformly random among all matching indices.
Testing
Randomized code should be tested by checking validity and rough distribution, not by expecting one fixed output.

    from collections import Counter

    def test_solution():
        s = Solution([1, 2, 3, 3, 3])

        # Validity: every returned index must point at the target.
        for _ in range(100):
            assert s.pick(1) == 0
            assert s.pick(2) == 1
            assert s.pick(3) in {2, 3, 4}

        # Rough distribution: each index is expected about 2000 times.
        counts = Counter(s.pick(3) for _ in range(6000))
        assert set(counts) == {2, 3, 4}
        for index in [2, 3, 4]:
            assert 1500 <= counts[index] <= 2500

        print("all tests passed")

    test_solution()

Test meaning:
| Test | Why |
|---|---|
| pick(1) | Only one valid index |
| pick(2) | Only one valid index |
| pick(3) | Must return one of several valid indices |
| Many calls to pick(3) | Basic sanity check for uniform randomness |
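The fixed bounds 1500 and 2500 work for this example. One possible generalization is to derive the tolerance from the binomial standard deviation. The sketch below is an assumption-laden illustration: the helper roughly_uniform and the default of 5 standard deviations are hypothetical choices, and a seeded rng.choice stands in for repeated pick(3) calls so the block runs on its own.

```python
import math
import random
from collections import Counter


def roughly_uniform(counts: Counter, outcomes: list[int], trials: int, sigmas: float = 5.0) -> bool:
    """Check each outcome count is within `sigmas` standard deviations of trials / len(outcomes)."""
    p = 1 / len(outcomes)
    expected = trials * p
    std = math.sqrt(trials * p * (1 - p))  # binomial standard deviation
    return all(abs(counts[o] - expected) <= sigmas * std for o in outcomes)


rng = random.Random(12345)  # seeded stand-in for repeated pick(3) calls
trials = 6000
counts = Counter(rng.choice([2, 3, 4]) for _ in range(trials))
print(roughly_uniform(counts, [2, 3, 4], trials))
```

A wider sigmas value makes false failures rarer at the cost of missing small biases.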