
LeetCode 398: Random Pick Index

A clear explanation of picking a uniformly random index for a target value using reservoir sampling, with an alternative hash map approach.

Problem Restatement

We are given an integer array nums.

The array may contain duplicate values.

We need to implement a class with one operation:

pick(target)

This operation returns a random index i such that:

nums[i] == target

If the target appears at multiple indices, each valid index must have equal probability of being returned.

The problem guarantees that target exists in nums. The constraints allow nums.length <= 2 * 10^4 and at most 10^4 calls to pick.

Input and Output

| Method | Input | Output |
|---|---|---|
| Solution(nums) | Integer array nums | Initializes the object |
| pick(target) | Integer target | Random index where nums[index] == target |

Example class shape:

class Solution:

    def __init__(self, nums: list[int]):
        ...

    def pick(self, target: int) -> int:
        ...

Examples

Example:

solution = Solution([1, 2, 3, 3, 3])

Call:

solution.pick(3)

The valid indices are:

[2, 3, 4]

So pick(3) should return 2, 3, or 4.

Each index should be returned with probability:

1 / 3

Call:

solution.pick(1)

The only valid index is:

0

So the method must return:

0
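For testing, it helps to write the required distribution down explicitly. The sketch below uses a hypothetical helper, `expected_distribution`, that lists every valid index and gives each one probability 1 / m (as a `fractions.Fraction`), where m is the number of matches.

```python
from fractions import Fraction


def expected_distribution(nums: list[int], target: int) -> dict[int, Fraction]:
    """Map each index i with nums[i] == target to its required probability."""
    # Collect every valid index for the target.
    valid = [i for i, num in enumerate(nums) if num == target]
    # Uniform requirement: each of the m valid indices gets 1 / m.
    return {i: Fraction(1, len(valid)) for i in valid}


print(expected_distribution([1, 2, 3, 3, 3], 3))
# {2: Fraction(1, 3), 3: Fraction(1, 3), 4: Fraction(1, 3)}
print(expected_distribution([1, 2, 3, 3, 3], 1))
# {0: Fraction(1, 1)}
```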

First Thought: Store All Indices

A simple approach is to preprocess the array.

Build a hash map:

value -> list of indices

For:

nums = [1, 2, 3, 3, 3]

the map is:

{
    1: [0],
    2: [1],
    3: [2, 3, 4],
}

Then pick(target) chooses a random index from the list.

import random
from collections import defaultdict

class Solution:

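    def __init__(self, nums: list[int]):
        # Map each value to the list of indices where it appears.
        self.indices = defaultdict(list)

        for i, num in enumerate(nums):
            self.indices[num].append(i)

    def pick(self, target: int) -> int:
        # Every stored index for target is equally likely.
        return random.choice(self.indices[target])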

This is clean, and each pick runs in O(1) time.

The tradeoff is memory: the map stores every index of nums, which is O(n) extra space.
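A small design note on the map: looking up a missing key in a `defaultdict` inserts an empty list, and `random.choice` on an empty list raises `IndexError`. The problem guarantees that target exists, so this never happens here. Still, the sketch below (the class name `IndexMapPicker` is hypothetical) builds the groups with a `defaultdict` and then freezes them into a plain `dict`, so an absent target raises `KeyError` without mutating the map.

```python
import random
from collections import defaultdict


class IndexMapPicker:
    """Hash map approach: value -> list of indices, built once."""

    def __init__(self, nums: list[int]):
        groups: defaultdict[int, list[int]] = defaultdict(list)
        for i, num in enumerate(nums):
            groups[num].append(i)
        # Freeze into a plain dict so lookups of absent keys do not insert.
        self.indices: dict[int, list[int]] = dict(groups)

    def pick(self, target: int) -> int:
        # O(1) average: one dict lookup plus one uniform choice.
        return random.choice(self.indices[target])


picker = IndexMapPicker([1, 2, 3, 3, 3])
print(picker.pick(3))  # one of 2, 3, 4
```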

Key Insight

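We can also solve the problem without storing any indices, keeping only one candidate answer while scanning.

While scanning nums, keep a counter count of the matches seen so far. When we reach a matching index, first increase count, so that this index is the count-th match, and then choose it with probability: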

1 / count

This is reservoir sampling with reservoir size 1.
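Reservoir sampling is usually stated for a stream whose length is not known in advance, which is why it needs only one pass and O(1) memory. The sketch below (the function name `sample_one` is hypothetical) keeps a single item from any iterable, replacing it with the kth item with probability 1 / k.

```python
import random
from typing import Iterable, Optional, TypeVar

T = TypeVar("T")


def sample_one(items: Iterable[T], rng: random.Random) -> Optional[T]:
    """Return a uniformly random item from items, or None if it is empty."""
    chosen: Optional[T] = None
    for k, item in enumerate(items, start=1):
        # Keep the kth item with probability 1 / k.
        if rng.randrange(k) == 0:
            chosen = item
    return chosen


rng = random.Random(42)
# A generator: its length is never known up front.
evens = (x for x in range(100) if x % 2 == 0)
print(sample_one(evens, rng))
```

Applied to this problem, the stream is just the matching indices: `sample_one((i for i, num in enumerate(nums) if num == target), rng)` behaves like pick(target).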

The idea:

| Match number | What we do |
|---|---|
| 1st match | Choose it with probability 1 / 1 |
| 2nd match | Replace answer with probability 1 / 2 |
| 3rd match | Replace answer with probability 1 / 3 |
| kth match | Replace answer with probability 1 / k |

At the end, every matching index has equal probability.
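The table above can be checked exactly with rational arithmetic. The sketch below (the function name `held_distribution` is hypothetical) tracks, after each match, the probability that each match seen so far is the current answer. The invariant is that after k matches, every one of them is held with probability 1 / k.

```python
from fractions import Fraction


def held_distribution(m: int) -> list[list[Fraction]]:
    """Probability that match j (1-based) is held, after each of m matches."""
    history: list[list[Fraction]] = []
    probs: list[Fraction] = []
    for k in range(1, m + 1):
        keep = 1 - Fraction(1, k)  # the current answer survives the kth match
        probs = [p * keep for p in probs]  # earlier matches must survive
        probs.append(Fraction(1, k))  # the kth match is taken with prob 1 / k
        history.append(list(probs))
    return history


for step in held_distribution(3):
    print([str(p) for p in step])
# ['1']
# ['1/2', '1/2']
# ['1/3', '1/3', '1/3']
```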

Algorithm

For each call to pick(target):

  1. Set count = 0.
  2. Set answer = -1.
  3. Scan all indices i.
  4. If nums[i] == target:
    • Increase count.
    • Replace answer with i with probability 1 / count.
  5. Return answer.

In Python, this probability can be written as:

if random.randrange(count) == 0:
    answer = i

random.randrange(count) returns an integer chosen uniformly from 0 to count - 1 inclusive, so the condition is true with probability 1 / count.
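To see these steps on the running example, the sketch below runs the same loop on [1, 2, 3, 3, 3] and records, for every match, the count and whether the answer was replaced. The function `trace_pick` and its seeded `random.Random` argument are illustrative assumptions, not part of the problem's interface.

```python
import random


def trace_pick(
    nums: list[int], target: int, rng: random.Random
) -> tuple[int, list[tuple[int, int, bool]]]:
    """Run the reservoir loop and record (index, count, replaced) per match."""
    count = 0
    answer = -1
    events = []
    for i, num in enumerate(nums):
        if num == target:
            count += 1
            replaced = rng.randrange(count) == 0  # true with prob 1 / count
            if replaced:
                answer = i
            events.append((i, count, replaced))
    return answer, events


answer, events = trace_pick([1, 2, 3, 3, 3], 3, random.Random(7))
for i, count, replaced in events:
    print(f"index {i}: count={count}, replaced={replaced}")
print("answer:", answer)
```

The first match is always taken, because randrange(1) can only return 0; after that, the answer is whichever match was replaced last.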

Correctness

Consider the jth occurrence of target.

When the algorithm sees this occurrence, it chooses that index with probability:

1 / j

Then it must survive all later replacement chances.

At the next occurrence, it is not replaced with probability:

j / (j + 1)

At the following occurrence, it is not replaced with probability:

(j + 1) / (j + 2)

This continues through the last occurrence, where m is the total number of occurrences.

So the final probability that the jth occurrence is returned is:

(1 / j) * (j / (j + 1)) * ((j + 1) / (j + 2)) * ... * ((m - 1) / m)

Everything cancels except:

1 / m

Thus every valid index has equal probability.

Since the target is guaranteed to exist, at least one matching index is seen, and the returned answer is always valid.
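The telescoping product can be confirmed with exact fractions. This sketch (the function name `survival_probability` is hypothetical) multiplies the 1 / j selection factor by the survival factors k / (k + 1) for k = j through m - 1, and compares the result with 1 / m.

```python
from fractions import Fraction


def survival_probability(j: int, m: int) -> Fraction:
    """Probability that the jth of m occurrences is the returned answer."""
    prob = Fraction(1, j)  # selected when it is seen
    for k in range(j, m):
        prob *= Fraction(k, k + 1)  # not replaced by occurrence k + 1
    return prob


for m in range(1, 8):
    assert all(survival_probability(j, m) == Fraction(1, m) for j in range(1, m + 1))
print("every occurrence has probability 1 / m for m = 1..7")
```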

Complexity

Let n = len(nums).

| Method | Time | Space |
|---|---|---|
| Constructor | O(1) | O(1) |
| pick(target) | O(n) | O(1) |

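These bounds are for the reservoir sampling version. The constructor only keeps a reference to nums, so it uses no extra space beyond the input.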

The hash map version has a different tradeoff:

| Method | Time | Space |
|---|---|---|
| Constructor | O(n) | O(n) |
| pick(target) | O(1) average | O(1) extra |

Implementation

import random

class Solution:

    def __init__(self, nums: list[int]):
        self.nums = nums

    def pick(self, target: int) -> int:
        count = 0
        answer = -1

        for i, num in enumerate(self.nums):
            if num == target:
                count += 1

                if random.randrange(count) == 0:
                    answer = i

        return answer

Code Explanation

The constructor stores the array:

self.nums = nums

For each pick, we scan the array:

for i, num in enumerate(self.nums):

Whenever we find the target:

if num == target:

we increase the number of matches seen:

count += 1

Then we choose the current index with probability 1 / count:

if random.randrange(count) == 0:
    answer = i

At the end:

return answer

The answer is uniformly random among all matching indices.
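Because the only randomness is one randrange(count) call per match, this claim can also be checked exhaustively instead of statistically. With m matches there are 1 * 2 * ... * m = m! equally likely sequences of draws. The sketch below (the helpers `pick_with_draws` and `outcome_counts` are hypothetical) replays the loop for every sequence and counts which index is returned; each of the m indices should win exactly (m - 1)! times.

```python
from collections import Counter
from itertools import product


def pick_with_draws(nums: list[int], target: int, draws: tuple[int, ...]) -> int:
    """Same loop as pick, but the kth match uses draws[k - 1] as its randrange result."""
    count = 0
    answer = -1
    for i, num in enumerate(nums):
        if num == target:
            count += 1
            if draws[count - 1] == 0:
                answer = i
    return answer


def outcome_counts(nums: list[int], target: int) -> Counter:
    """Count returned indices over every possible sequence of draws."""
    m = nums.count(target)
    # The kth draw ranges over 0..k-1, exactly like randrange(k).
    all_draws = product(*(range(k) for k in range(1, m + 1)))
    return Counter(pick_with_draws(nums, target, d) for d in all_draws)


print(outcome_counts([1, 2, 3, 3, 3], 3))
```

For [1, 2, 3, 3, 3] and target 3, there are 3! = 6 draw sequences, and each of the indices 2, 3 and 4 is returned in exactly 2 of them.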

Testing

Randomized code should be tested by checking validity and rough distribution, not by expecting one fixed output.

from collections import Counter

def test_solution():
    s = Solution([1, 2, 3, 3, 3])

    for _ in range(100):
        assert s.pick(1) == 0
        assert s.pick(2) == 1
        assert s.pick(3) in {2, 3, 4}

    counts = Counter(s.pick(3) for _ in range(6000))

    assert set(counts) == {2, 3, 4}

    for index in [2, 3, 4]:
        assert 1500 <= counts[index] <= 2500

    print("all tests passed")

test_solution()

Test meaning:

| Test | Why |
|---|---|
| pick(1) | Only one valid index |
| pick(2) | Only one valid index |
| pick(3) | Must return one of several valid indices |
| Many calls to pick(3) | Basic sanity check for uniform randomness |
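The 1500 to 2500 band in the test above is deliberately loose. A slightly more principled variant, sketched below, computes Pearson's chi-square statistic against the uniform expectation. With 3 valid indices there are 2 degrees of freedom, and 13.816 is the 0.001 critical value for that case. The reservoir `Solution` class is repeated so the block runs on its own; the helper name `chi_square`, the seed and the sample size are arbitrary choices for illustration.

```python
import random
from collections import Counter


class Solution:
    def __init__(self, nums: list[int]):
        self.nums = nums

    def pick(self, target: int) -> int:
        count = 0
        answer = -1
        for i, num in enumerate(self.nums):
            if num == target:
                count += 1
                if random.randrange(count) == 0:
                    answer = i
        return answer


def chi_square(counts: Counter, categories: list[int], trials: int) -> float:
    """Pearson's statistic for observed counts against a uniform expectation."""
    expected = trials / len(categories)
    return sum((counts[c] - expected) ** 2 / expected for c in categories)


random.seed(398)  # fixed seed so the run is reproducible
s = Solution([1, 2, 3, 3, 3])
trials = 6000
counts = Counter(s.pick(3) for _ in range(trials))
stat = chi_square(counts, [2, 3, 4], trials)
print(f"chi-square = {stat:.3f}")
# With 2 degrees of freedom, a truly uniform picker exceeds 13.816
# only with probability 0.001.
assert stat < 13.816
```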