LeetCode 528: Random Pick with Weight

A clear explanation of weighted random sampling using prefix sums and binary search.

Problem Restatement

We are given a 0-indexed array of positive integers w.

Each w[i] is the weight of index i.

We need to implement a class with a method:

pickIndex()

This method returns an index from 0 to len(w) - 1.

The probability of returning index i must be:

w[i] / sum(w)

So larger weights should be picked more often.

For example, if:

w = [1, 3]

then index 0 should be picked with probability:

1 / 4

and index 1 should be picked with probability:

3 / 4

Since this is a randomized problem, many output sequences can be accepted. The important part is that the long-run distribution matches the weights.

Input and Output

| Item | Meaning |
| --- | --- |
| Constructor input | An array of positive integers w |
| Method output | One random index |
| Probability rule | Index i is returned with probability w[i] / sum(w) |
| Array length | 1 <= w.length <= 10^4 |
| Weight value | 1 <= w[i] <= 10^5 |
| Calls | pickIndex is called at most 10^4 times |

Class shape:

class Solution:

    def __init__(self, w: list[int]):
        ...

    def pickIndex(self) -> int:
        ...

Examples

Consider:

w = [1, 3]

The total weight is:

1 + 3 = 4

So we can think of four slots:

[0, 1, 1, 1]

Index 0 owns one slot.

Index 1 owns three slots.

If we pick one slot uniformly at random, then index 0 has probability 1 / 4, and index 1 has probability 3 / 4.

Now consider:

w = [2, 5, 3]

The total weight is:

10

The probability distribution should be:

| Index | Weight | Probability |
| --- | --- | --- |
| 0 | 2 | 2 / 10 |
| 1 | 5 | 5 / 10 |
| 2 | 3 | 3 / 10 |

We do not need to actually build this repeated list:

[0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

That list explains the idea, but it is not memory-efficient.

First Thought: Expand the Array

A simple approach is to create a large array where each index appears as many times as its weight.

For:

w = [2, 5, 3]

we would build:

expanded = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

Then pickIndex() chooses one random element from expanded.

This gives the correct probability because every slot is equally likely.

But this approach can use too much memory.

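With the maximum array length and the maximum weight, the sum of weights can reach:

10^4 * 10^5 = 10^9

So in the worst case the expanded array would need 10^9 entries, which is too large to build.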
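For comparison, here is a minimal sketch of the expanded-array idea. The class name ExpandedSolution is chosen here for illustration and is not part of the problem; the sketch is only practical when sum(w) is small.

```python
import random


class ExpandedSolution:
    """Naive baseline (hypothetical name): one list slot per unit of weight."""

    def __init__(self, w: list[int]):
        self.expanded = []
        for index, weight in enumerate(w):
            # Index `index` appears `weight` times in the expanded list.
            self.expanded.extend([index] * weight)

    def pickIndex(self) -> int:
        # Every slot is equally likely, so index i is returned
        # with probability w[i] / sum(w).
        return random.choice(self.expanded)
```

Memory grows with sum(w) rather than len(w), which is why the prefix-sum approach replaces it.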

Key Insight

Instead of storing every repeated index, store the boundary where each index ends.

For:

w = [2, 5, 3]

the prefix sums are:

[2, 7, 10]

These prefix sums divide the range 1 through 10 into intervals:

| Random value | Picked index |
| --- | --- |
| 1, 2 | 0 |
| 3, 4, 5, 6, 7 | 1 |
| 8, 9, 10 | 2 |

Index 0 gets 2 values.

Index 1 gets 5 values.

Index 2 gets 3 values.

So the interval sizes exactly match the weights.

Now pickIndex() can:

  1. Generate a random integer from 1 to total_weight.
  2. Find the first prefix sum greater than or equal to that random integer.
  3. Return that prefix sum’s index.

The search is binary search because the prefix sums are sorted.
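The three steps above can be sketched on the running example. This is an illustrative snippet, not the final class: the names prefix and owners are chosen here, and itertools.accumulate stands in for a manual running total.

```python
from bisect import bisect_left
from itertools import accumulate

w = [2, 5, 3]

# Running totals of the weights: [2, 7, 10].
prefix = list(accumulate(w))

# For every possible random value, find the first prefix sum >= that value.
# bisect_left returns exactly that position in a sorted list.
owners = {
    target: bisect_left(prefix, target)
    for target in range(1, prefix[-1] + 1)
}
```

Here owners maps each value from 1 through 10 to the index that would be returned for it.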

Algorithm

During initialization:

  1. Create an empty array prefix.
  2. Keep a running total.
  3. For each weight, add it to the running total.
  4. Append the running total to prefix.

During pickIndex():

  1. Generate a random integer target in the inclusive range:
1 <= target <= total_weight
  2. Binary search for the leftmost index i such that:
prefix[i] >= target
  3. Return i.

Correctness

The prefix sum array partitions the integers from 1 to total_weight into consecutive ranges.

Index 0 receives the range:

1 through prefix[0]

Every later index i receives the range:

prefix[i - 1] + 1 through prefix[i]

The number of integers in this range is:

prefix[i] - prefix[i - 1]

By construction, that value is exactly w[i]. The same holds for index 0, whose range contains prefix[0] = w[0] integers.

Therefore, index i owns exactly w[i] random values out of sum(w) total values.

Since target is chosen uniformly from 1 through sum(w), the probability that target lands inside index i’s range is:

w[i] / sum(w)

The binary search returns the first prefix sum greater than or equal to target, which is exactly the index whose range contains target.

So pickIndex() returns each index with the required probability.
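The counting argument can also be checked exhaustively for small inputs. The helper below is a sketch with a hypothetical name, owned_counts; it walks every possible target instead of sampling, so the result is deterministic.

```python
from bisect import bisect_left
from collections import Counter
from itertools import accumulate


def owned_counts(w: list[int]) -> list[int]:
    """Count how many targets in 1..sum(w) map to each index."""
    prefix = list(accumulate(w))
    counts = Counter(
        bisect_left(prefix, target)
        for target in range(1, prefix[-1] + 1)
    )
    return [counts[i] for i in range(len(w))]
```

If the partition argument is right, owned_counts(w) should equal w itself for any list of positive weights.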

Complexity

Let n = len(w).

| Operation | Time | Space |
| --- | --- | --- |
| Constructor | O(n) | O(n) |
| pickIndex() | O(log n) | O(1) |

The constructor builds one prefix sum array.

Each call to pickIndex() performs one binary search.

Implementation

import random
from bisect import bisect_left

class Solution:

    def __init__(self, w: list[int]):
        self.prefix = []
        total = 0

        for weight in w:
            total += weight
            self.prefix.append(total)

        self.total = total

    def pickIndex(self) -> int:
        target = random.randint(1, self.total)
        return bisect_left(self.prefix, target)

Code Explanation

The constructor builds cumulative weights:

self.prefix = []
total = 0

For each weight:

total += weight
self.prefix.append(total)

If w = [2, 5, 3], then self.prefix becomes:

[2, 7, 10]

The total weight is stored separately:

self.total = total

In pickIndex(), we choose a random target:

target = random.randint(1, self.total)

This is inclusive on both ends.

Then we find the first prefix sum that is at least target:

bisect_left(self.prefix, target)

For example, if:

self.prefix = [2, 7, 10]

then:

| Target | First prefix >= target | Returned index |
| --- | --- | --- |
| 1 | 2 | 0 |
| 2 | 2 | 0 |
| 3 | 7 | 1 |
| 7 | 7 | 1 |
| 8 | 10 | 2 |
| 10 | 10 | 2 |

This gives exactly the weighted distribution.

Manual Binary Search Version

Some interviewers prefer that you write the binary search directly.

import random

class Solution:

    def __init__(self, w: list[int]):
        self.prefix = []
        total = 0

        for weight in w:
            total += weight
            self.prefix.append(total)

        self.total = total

    def pickIndex(self) -> int:
        target = random.randint(1, self.total)

        left = 0
        right = len(self.prefix) - 1

        while left < right:
            mid = (left + right) // 2

            if self.prefix[mid] >= target:
                right = mid
            else:
                left = mid + 1

        return left

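The invariant is that the answer is always inside [left, right]. It holds at the start because target never exceeds self.total, which is the last prefix sum, so at least one index satisfies the condition.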

If self.prefix[mid] >= target, then mid might be the answer, but there could be an earlier valid index. So we move right to mid.

If self.prefix[mid] < target, then mid cannot be the answer. So we move left to mid + 1.

When the loop ends, left == right, and that index is the first prefix sum greater than or equal to target.
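One way to gain confidence in the hand-written loop is to pull it into a pure function and compare it with bisect_left on every possible target. The function name leftmost_at_least is hypothetical; the loop body is the same as in pickIndex() above, and the sketch assumes target <= prefix[-1].

```python
from bisect import bisect_left


def leftmost_at_least(prefix: list[int], target: int) -> int:
    """Return the first index i with prefix[i] >= target.

    Assumes prefix is sorted and target <= prefix[-1].
    """
    left, right = 0, len(prefix) - 1
    while left < right:
        mid = (left + right) // 2
        if prefix[mid] >= target:
            right = mid      # mid may be the answer; keep it in range
        else:
            left = mid + 1   # mid is too small; discard it
    return left
```

Because the random draw is separated from the search, this helper can be tested deterministically.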

Testing

Randomized algorithms are tested differently from deterministic algorithms.

We should not expect an exact sequence of outputs.

Instead, we can test that every returned index is valid and that the distribution is approximately correct after many trials.

from collections import Counter

def run_tests():
    s = Solution([1, 3])

    for _ in range(100):
        idx = s.pickIndex()
        assert idx in [0, 1]

    s = Solution([2, 5, 3])
    counts = Counter(s.pickIndex() for _ in range(10000))

    assert counts[0] < counts[1]
    assert counts[2] < counts[1]

    s = Solution([10])
    for _ in range(100):
        assert s.pickIndex() == 0

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
| --- | --- |
| [1, 3] | Checks valid indices for a simple weighted case |
| [2, 5, 3] | Checks that the largest weight is picked most often |
| [10] | Checks single-index input |
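A stricter check compares the observed frequencies with the expected probabilities. The sketch below is one possible way to do that: the function name max_frequency_error, the trial count, the fixed seed, and the 0.01 tolerance are all assumptions picked for illustration, and the Solution class is restated compactly so the snippet runs on its own.

```python
import random
from bisect import bisect_left
from collections import Counter
from itertools import accumulate


class Solution:
    """Compact restatement of the prefix-sum solution."""

    def __init__(self, w: list[int]):
        self.prefix = list(accumulate(w))
        self.total = self.prefix[-1]

    def pickIndex(self) -> int:
        return bisect_left(self.prefix, random.randint(1, self.total))


def max_frequency_error(w: list[int], trials: int = 100_000, seed: int = 0) -> float:
    """Largest gap between observed frequency and w[i] / sum(w)."""
    random.seed(seed)  # fixed seed so the check is repeatable
    s = Solution(w)
    counts = Counter(s.pickIndex() for _ in range(trials))
    total = sum(w)
    return max(abs(counts[i] / trials - w[i] / total) for i in range(len(w)))


# Assumed tolerance: loose enough that sampling noise should not trip it.
assert max_frequency_error([2, 5, 3]) < 0.01
```

With 100,000 trials the sampling noise per index is on the order of 0.002, so a 0.01 tolerance leaves a wide margin while still catching an off-by-one in the boundaries.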