A clear explanation of weighted random sampling using prefix sums and binary search.
Problem Restatement
We are given a 0-indexed array of positive integers w.
Each w[i] is the weight of index i.
We need to implement a class with a method:

    pickIndex()

This method returns an index from 0 to len(w) - 1.
The probability of returning index i must be:

    w[i] / sum(w)

So larger weights should be picked more often.
For example, if:

    w = [1, 3]

then index 0 should be picked with probability:

    1 / 4

and index 1 should be picked with probability:

    3 / 4

Since this is a randomized problem, many output sequences can be accepted. The important part is that the long-run distribution matches the weights.
Input and Output
| Item | Meaning |
|---|---|
| Constructor input | An array of positive integers w |
| Method output | One random index |
| Probability rule | Index i is returned with probability w[i] / sum(w) |
| Array length | 1 <= w.length <= 10^4 |
| Weight value | 1 <= w[i] <= 10^5 |
| Calls | pickIndex is called at most 10^4 times |
Class shape:

    class Solution:
        def __init__(self, w: list[int]):
            ...

        def pickIndex(self) -> int:
            ...

Examples
Consider:

    w = [1, 3]

The total weight is:

    1 + 3 = 4

So we can think of four slots:

    [0, 1, 1, 1]

Index 0 owns one slot.
Index 1 owns three slots.
If we pick one slot uniformly at random, then index 0 has probability 1 / 4, and index 1 has probability 3 / 4.
Now consider:

    w = [2, 5, 3]

The total weight is:

    10

The probability distribution should be:
| Index | Weight | Probability |
|---|---|---|
| 0 | 2 | 2 / 10 |
| 1 | 5 | 5 / 10 |
| 2 | 3 | 3 / 10 |
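
As a quick check of the table above, the target probabilities can be computed exactly with Python's fractions module. This is only an illustrative sketch: the helper name expected_probabilities is hypothetical and is not part of the required class.

```python
from fractions import Fraction


def expected_probabilities(w: list[int]) -> list[Fraction]:
    """Return the exact probability w[i] / sum(w) for every index."""
    total = sum(w)
    return [Fraction(weight, total) for weight in w]


probs = expected_probabilities([2, 5, 3])
print(probs)  # Fraction reduces automatically, so 2 / 10 is shown as 1/5
```

The probabilities always sum to 1, which is a cheap sanity check when experimenting with other weight arrays.
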
We do not need to actually build this repeated list:

    [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

That list explains the idea, but it is not memory-efficient.
First Thought: Expand the Array
A simple approach is to create a large array where each index appears as many times as its weight.
For:

    w = [2, 5, 3]

we would build:

    expanded = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

Then pickIndex() chooses one random element from expanded.
This gives the correct probability because every slot is equally likely.
But this approach can use too much memory.
Under the given constraints, the sum of weights can be as large as:

    10^4 * 10^5 = 10^9

So the expanded array could need up to a billion entries, which is impractical to build.
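
For reference, the expanded-array idea could be sketched as follows. The class name ExpandedSolution is hypothetical, and the sketch is only meant for small inputs, since its memory use grows with sum(w).

```python
import random


class ExpandedSolution:
    """Naive version: one list slot per unit of weight (hypothetical name)."""

    def __init__(self, w: list[int]):
        # Index i is repeated w[i] times, so memory grows with sum(w).
        self.expanded = [i for i, weight in enumerate(w) for _ in range(weight)]

    def pickIndex(self) -> int:
        # Every slot is equally likely, so index i has probability w[i] / sum(w).
        return random.choice(self.expanded)


naive = ExpandedSolution([2, 5, 3])
print(naive.expanded)  # [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
```
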
Key Insight
Instead of storing every repeated index, store the boundary where each index ends.
For:

    w = [2, 5, 3]

the prefix sums are:

    [2, 7, 10]

These prefix sums divide the range 1 through 10 into intervals:
| Random value | Picked index |
|---|---|
| 1, 2 | 0 |
| 3, 4, 5, 6, 7 | 1 |
| 8, 9, 10 | 2 |
Index 0 gets 2 values.
Index 1 gets 5 values.
Index 2 gets 3 values.
So the interval sizes exactly match the weights.
Now pickIndex() can:
- Generate a random integer from 1 to total_weight.
- Find the first prefix sum greater than or equal to that random integer.
- Return that prefix sum’s index.

The search can be a binary search because the prefix sums are sorted.
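
To make the interval layout concrete, here is a small sketch that derives the inclusive range of random values owned by each index from the prefix sums. The helper name owned_ranges is hypothetical and is used only for illustration.

```python
from itertools import accumulate


def owned_ranges(w: list[int]) -> list[tuple[int, int]]:
    """Return the inclusive (low, high) range of random values owned by each index."""
    prefix = list(accumulate(w))
    ranges = []
    low = 1
    for high in prefix:
        ranges.append((low, high))
        low = high + 1  # the next index starts right after this boundary
    return ranges


print(owned_ranges([2, 5, 3]))  # [(1, 2), (3, 7), (8, 10)]
```

Each range has exactly w[i] values, which is why storing only the boundaries is enough.
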
Algorithm
During initialization:
- Create an empty array prefix.
- Keep a running total.
- For each weight, add it to the running total.
- Append the running total to prefix.

During pickIndex():
- Generate a random integer target in the inclusive range 1 <= target <= total_weight.
- Binary search for the leftmost index i such that prefix[i] >= target.
- Return i.
Correctness
The prefix sum array partitions the integers from 1 to total_weight into consecutive ranges.
Index 0 receives the range:

    1 through prefix[0]

For every later index i, it receives the range:

    prefix[i - 1] + 1 through prefix[i]

The number of integers in this range is:

    prefix[i] - prefix[i - 1]

By construction, that value is exactly w[i].
Therefore, index i owns exactly w[i] random values out of sum(w) total values.
Since target is chosen uniformly from 1 through sum(w), the probability that target lands inside index i’s range is:

    w[i] / sum(w)

The binary search returns the first prefix sum greater than or equal to target, which is exactly the index whose range contains target.
So pickIndex() returns each index with the required probability.
Complexity
Let n = len(w).
| Operation | Time | Space |
|---|---|---|
| Constructor | O(n) | O(n) |
| pickIndex() | O(log n) | O(1) |
The constructor builds one prefix sum array.
Each call to pickIndex() performs one binary search.
Implementation

    import random
    from bisect import bisect_left


    class Solution:
        def __init__(self, w: list[int]):
            # prefix[i] holds w[0] + ... + w[i]
            self.prefix = []
            total = 0
            for weight in w:
                total += weight
                self.prefix.append(total)
            self.total = total

        def pickIndex(self) -> int:
            # Uniform integer in [1, total], inclusive on both ends
            target = random.randint(1, self.total)
            # Leftmost index whose prefix sum is >= target
            return bisect_left(self.prefix, target)

Code Explanation
The constructor builds cumulative weights:

    self.prefix = []
    total = 0

For each weight:

    total += weight
    self.prefix.append(total)

If w = [2, 5, 3], then self.prefix becomes:

    [2, 7, 10]

The total weight is stored separately:

    self.total = total

In pickIndex(), we choose a random target:

    target = random.randint(1, self.total)

This is inclusive on both ends.
Then we find the first prefix sum that is at least target:

    bisect_left(self.prefix, target)

For example, if:

    self.prefix = [2, 7, 10]

then:
| Target | First prefix >= target | Returned index |
|---|---|---|
| 1 | 2 | 0 |
| 2 | 2 | 0 |
| 3 | 7 | 1 |
| 7 | 7 | 1 |
| 8 | 10 | 2 |
| 10 | 10 | 2 |
This gives exactly the weighted distribution.
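
For comparison, Python's standard library can perform a weighted pick directly through random.choices, which accepts cumulative weights via its cum_weights parameter. The sketch below is an alternative, not the method described in this article: random.choices draws a floating-point number internally rather than an integer target, and the class name ChoicesSolution is hypothetical.

```python
import random
from itertools import accumulate


class ChoicesSolution:
    """Alternative sketch that delegates the weighted pick to random.choices."""

    def __init__(self, w: list[int]):
        self.indices = range(len(w))
        # Same prefix sums as before, passed to random.choices as cum_weights.
        self.prefix = list(accumulate(w))

    def pickIndex(self) -> int:
        # random.choices returns a list of k items; take the single element.
        return random.choices(self.indices, cum_weights=self.prefix, k=1)[0]


alt = ChoicesSolution([2, 5, 3])
print(alt.pickIndex() in (0, 1, 2))  # True
```

Writing the prefix sums and binary search by hand is still worthwhile in an interview setting, since it shows why the distribution is correct.
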
Manual Binary Search Version
Some interviews prefer writing the binary search directly.

    import random


    class Solution:
        def __init__(self, w: list[int]):
            self.prefix = []
            total = 0
            for weight in w:
                total += weight
                self.prefix.append(total)
            self.total = total

        def pickIndex(self) -> int:
            target = random.randint(1, self.total)
            left = 0
            right = len(self.prefix) - 1
            while left < right:
                mid = (left + right) // 2
                if self.prefix[mid] >= target:
                    right = mid
                else:
                    left = mid + 1
            return left

The invariant is that the answer is always inside [left, right].
If self.prefix[mid] >= target, then mid might be the answer, but there could be an earlier valid index. So we move right to mid.
If self.prefix[mid] < target, then mid cannot be the answer. So we move left to mid + 1.
When the loop ends, left == right, and that index is the first prefix sum greater than or equal to target.
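
One way to gain confidence in the manual loop is to compare it with bisect_left for every possible target. The sketch below pulls the loop into a standalone function; the name leftmost_at_least is hypothetical.

```python
from bisect import bisect_left


def leftmost_at_least(prefix: list[int], target: int) -> int:
    """Manual binary search: first index i with prefix[i] >= target."""
    left, right = 0, len(prefix) - 1
    while left < right:
        mid = (left + right) // 2
        if prefix[mid] >= target:
            right = mid  # mid may be the answer, keep it in range
        else:
            left = mid + 1  # mid is too small, discard it
    return left


prefix = [2, 7, 10]
for target in range(1, prefix[-1] + 1):
    assert leftmost_at_least(prefix, target) == bisect_left(prefix, target)
print("manual search agrees with bisect_left for every target")
```

The comparison relies on target never exceeding prefix[-1], which holds here because target is drawn from 1 through the total weight.
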
Testing
Randomized algorithms are tested differently from deterministic algorithms.
We should not expect an exact sequence of outputs.
Instead, we can test that every returned index is valid and that the distribution is approximately correct after many trials.

    from collections import Counter


    def run_tests():
        s = Solution([1, 3])
        for _ in range(100):
            idx = s.pickIndex()
            assert idx in [0, 1]

        s = Solution([2, 5, 3])
        counts = Counter(s.pickIndex() for _ in range(10000))
        assert counts[0] < counts[1]
        assert counts[2] < counts[1]

        s = Solution([10])
        for _ in range(100):
            assert s.pickIndex() == 0

        print("all tests passed")


    run_tests()

Test meaning:
| Test | Why |
|---|---|
| [1, 3] | Checks valid indices for a simple weighted case |
| [2, 5, 3] | Checks that the largest weight is picked most often |
| [10] | Checks single-index input |