# LeetCode 398: Random Pick Index

## Problem Restatement

We are given an integer array `nums`.

The array may contain duplicate values.

We need to implement a class with one operation:

```python
pick(target)
```

This operation returns a random index `i` such that:

```python
nums[i] == target
```

If the target appears at multiple indices, each valid index must have equal probability of being returned.

The problem guarantees that `target` exists in `nums`. The constraints allow `nums.length <= 2 * 10^4` and at most `10^4` calls to `pick`.

## Input and Output

| Method | Input | Output |
|---|---|---|
| `Solution(nums)` | Integer array `nums` | Initializes the object |
| `pick(target)` | Integer `target` | Random index where `nums[index] == target` |

Example class shape:

```python
class Solution:

    def __init__(self, nums: list[int]):
        ...

    def pick(self, target: int) -> int:
        ...
```

## Examples

Example:

```python
solution = Solution([1, 2, 3, 3, 3])
```

Call:

```python
solution.pick(3)
```

The valid indices are:

```python
[2, 3, 4]
```

So `pick(3)` should return `2`, `3`, or `4`.

Each index should have probability:

```text
1 / 3
```

Call:

```python
solution.pick(1)
```

The only valid index is:

```python
0
```

So the method must return:

```python
0
```
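A brute-force reference is useful for checking any implementation against these examples. The sketch below lists the valid indices and the probability each should get. The helper names `valid_indices` and `expected_probability` are hypothetical and are not part of the LeetCode interface.

```python
from fractions import Fraction

def valid_indices(nums: list[int], target: int) -> list[int]:
    """Return every index i with nums[i] == target, in increasing order."""
    return [i for i, num in enumerate(nums) if num == target]

def expected_probability(nums: list[int], target: int) -> Fraction:
    """Probability that each valid index should have under a uniform pick."""
    return Fraction(1, len(valid_indices(nums, target)))

nums = [1, 2, 3, 3, 3]
print(valid_indices(nums, 3))         # [2, 3, 4]
print(expected_probability(nums, 3))  # 1/3
print(valid_indices(nums, 1))         # [0]
```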

## First Thought: Store All Indices

A simple approach is to preprocess the array.

Build a hash map:

```python
value -> list of indices
```

For:

```python
nums = [1, 2, 3, 3, 3]
```

the map is:

```python
{
    1: [0],
    2: [1],
    3: [2, 3, 4],
}
```

Then `pick(target)` looks up the list for `target` and returns one of its entries uniformly at random with `random.choice`.

```python
import random
from collections import defaultdict

class Solution:

    def __init__(self, nums: list[int]):
        self.indices = defaultdict(list)

        for i, num in enumerate(nums):
            self.indices[num].append(i)

    def pick(self, target: int) -> int:
        return random.choice(self.indices[target])
```

This is short, and each `pick` runs in `O(1)` average time.

Its tradeoff is memory: the map stores every index of `nums`, which is `O(n)` extra space.
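To make the memory cost concrete, the sketch below repeats the constructor's preprocessing as a standalone function and counts how many indices the map holds. The function name `build_index_map` is our own, for illustration only.

```python
from collections import defaultdict

def build_index_map(nums: list[int]) -> dict[int, list[int]]:
    # Same preprocessing as the constructor above: value -> list of indices.
    indices = defaultdict(list)
    for i, num in enumerate(nums):
        indices[num].append(i)
    return dict(indices)

nums = [1, 2, 3, 3, 3]
index_map = build_index_map(nums)
print(index_map)  # {1: [0], 2: [1], 3: [2, 3, 4]}

# Every position of nums is stored exactly once, so the map holds n indices in total.
total_stored = sum(len(positions) for positions in index_map.values())
print(total_stored == len(nums))  # True
```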

## Key Insight

We can also solve the problem without storing any indices ahead of time.

On each call, scan `nums` and keep a running `count` of matches seen so far, including the current one. When we reach a matching index, make it the answer with probability:

```text
1 / count
```

This is reservoir sampling with reservoir size `1`.

The idea:

| Match number | What we do |
|---|---|
| 1st match | Choose it with probability `1 / 1` |
| 2nd match | Replace answer with probability `1 / 2` |
| 3rd match | Replace answer with probability `1 / 3` |
| kth match | Replace answer with probability `1 / k` |

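After the scan, every matching index has the same probability of being the answer: `1 / m` when there are `m` matches. The Correctness section below proves this.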
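The replacement rule never needs to know how many matches exist in advance, so it works on any stream of items. Below is a minimal sketch of size-1 reservoir sampling over an arbitrary iterable; the name `reservoir_sample_one` and the optional `rng` parameter are illustrative choices, not part of the problem.

```python
import random
from typing import Iterable, Optional, TypeVar

T = TypeVar("T")

def reservoir_sample_one(stream: Iterable[T], rng: Optional[random.Random] = None) -> Optional[T]:
    """Return one uniformly random item from a stream of unknown length.

    The k-th item replaces the current choice with probability 1 / k.
    Returns None for an empty stream.
    """
    if rng is None:
        rng = random.Random()
    chosen: Optional[T] = None
    for k, item in enumerate(stream, start=1):
        # rng.randrange(k) is uniform over 0..k-1, so this runs with probability 1 / k.
        if rng.randrange(k) == 0:
            chosen = item
    return chosen

nums = [1, 2, 3, 3, 3]
# Feed only the matching indices, lazily, as a generator.
print(reservoir_sample_one(i for i, num in enumerate(nums) if num == 3))  # one of 2, 3, 4
```

Applied to the generator of matching indices, this is exactly the `pick` logic; the array version simply inlines the filter.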

## Algorithm

For each call to `pick(target)`:

1. Set `count = 0`.
2. Set `answer = -1`.
3. Scan all indices `i`.
4. If `nums[i] == target`:
   - Increase `count`.
   - Replace `answer` with `i` with probability `1 / count`.
5. Return `answer`.
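To watch the steps above run, the sketch below replays the algorithm with pre-chosen random draws instead of real randomness. The function `trace_pick` and its `draws` argument are hypothetical: `draws[k - 1]` stands in for the result of `random.randrange(k)` at the `k`th match.

```python
def trace_pick(nums: list[int], target: int, draws: list[int]) -> int:
    """Replay the reservoir pick with fixed draws and print each step."""
    count = 0
    answer = -1
    for i, num in enumerate(nums):
        if num == target:
            count += 1
            draw = draws[count - 1]
            if not 0 <= draw < count:
                raise ValueError("draw must be a possible randrange(count) result")
            replaced = draw == 0  # the 1 / count event
            if replaced:
                answer = i
            print(f"i={i}, count={count}, draw={draw}, replaced={replaced}, answer={answer}")
    return answer

trace_pick([1, 2, 3, 3, 3], 3, [0, 1, 0])
# i=2, count=1, draw=0, replaced=True, answer=2
# i=3, count=2, draw=1, replaced=False, answer=2
# i=4, count=3, draw=0, replaced=True, answer=4
```

The first match always has draw `0` because `randrange(1)` can only return `0`, which is why `answer` never stays `-1` when the target exists.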

In Python, this probability can be written as:

```python
if random.randrange(count) == 0:
    answer = i
```

`random.randrange(count)` returns an integer chosen uniformly from `0` to `count - 1` inclusive. Exactly one of those `count` values is `0`, so the condition is true with probability `1 / count`.
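A quick empirical check of that claim, using a seeded `random.Random` instance so the run is reproducible. The helper `hit_rate` is hypothetical.

```python
import random

def hit_rate(count: int, trials: int, seed: int = 0) -> float:
    """Fraction of trials in which randrange(count) returns 0."""
    rng = random.Random(seed)  # seeded so the check is reproducible
    hits = sum(1 for _ in range(trials) if rng.randrange(count) == 0)
    return hits / trials

# Observed rate next to the target 1 / count.
for count in range(1, 6):
    print(count, round(hit_rate(count, 100_000), 3), round(1 / count, 3))
```

An alternative test is `random.random() < 1 / count`, which also works in practice; the integer form avoids a floating-point comparison altogether.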

## Correctness

Consider the `j`th occurrence of `target`.

When the algorithm sees this occurrence, it chooses that index with probability:

```text
1 / j
```

Then it must survive all later replacement chances.

At the next occurrence, it is not replaced with probability:

```text
j / (j + 1)
```

At the following occurrence, it is not replaced with probability:

```text
(j + 1) / (j + 2)
```

Let `m` be the total number of occurrences of `target`. The survival factors continue up to the `m`th occurrence.

So the final probability that the `j`th occurrence is returned is:

```text
(1 / j) * (j / (j + 1)) * ((j + 1) / (j + 2)) * ... * ((m - 1) / m)
```

Everything cancels except:

```text
1 / m
```
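The cancellation can be checked with exact rational arithmetic. The sketch below multiplies out the product for every `j` using `fractions.Fraction`; `survival_probability` is a helper name introduced here for illustration.

```python
from fractions import Fraction

def survival_probability(j: int, m: int) -> Fraction:
    """Probability that the j-th of m occurrences is the final answer.

    Chosen at step j with probability 1/j, then kept at each later step k
    (k = j + 1, ..., m) with probability (k - 1) / k.
    """
    prob = Fraction(1, j)
    for k in range(j + 1, m + 1):
        prob *= Fraction(k - 1, k)
    return prob

m = 5
print([str(survival_probability(j, m)) for j in range(1, m + 1)])
# ['1/5', '1/5', '1/5', '1/5', '1/5']
```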

Thus every valid index has equal probability.

Since the target is guaranteed to exist, at least one matching index is seen, and the returned answer is always valid.
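For small inputs the argument can also be confirmed by brute force. At the `k`th match the algorithm calls `randrange(k)`, which has `k` equally likely outcomes, so there are `m!` equally likely random paths in total. The sketch below enumerates all of them with `itertools.product` and tallies the final answers; `exact_distribution` is a hypothetical helper, not part of the solution.

```python
from collections import Counter
from fractions import Fraction
from itertools import product

def exact_distribution(nums: list[int], target: int) -> dict[int, Fraction]:
    """Exact output distribution of the reservoir pick, by enumerating all draws."""
    matches = [i for i, num in enumerate(nums) if num == target]
    outcomes = Counter()
    # One tuple per random path: draws[k - 1] is the result of randrange(k).
    for draws in product(*(range(k) for k in range(1, len(matches) + 1))):
        answer = -1
        for index, draw in zip(matches, draws):
            if draw == 0:  # replace with probability 1 / k
                answer = index
        outcomes[answer] += 1
    total = sum(outcomes.values())
    return {index: Fraction(n, total) for index, n in sorted(outcomes.items())}

print(exact_distribution([1, 2, 3, 3, 3], 3))
# {2: Fraction(1, 3), 3: Fraction(1, 3), 4: Fraction(1, 3)}
```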

## Complexity

Let `n = len(nums)`.

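| Method | Time | Extra space |
|---|---|---|
| Constructor | `O(1)` | `O(1)` |
| `pick(target)` | `O(n)` | `O(1)` |

This is the reservoir sampling version. The constructor only keeps a reference to `nums`, and each `pick` scans the whole array once.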

The hash map version has a different tradeoff:

| Method | Time | Extra space |
|---|---|---|
| Constructor | `O(n)` | `O(n)` |
| `pick(target)` | `O(1)` average | `O(1)` |

## Implementation

```python
import random

class Solution:

    def __init__(self, nums: list[int]):
        self.nums = nums

    def pick(self, target: int) -> int:
        count = 0
        answer = -1

        for i, num in enumerate(self.nums):
            if num == target:
                count += 1

                if random.randrange(count) == 0:
                    answer = i

        return answer
```
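Because `pick` calls the global `random` module, its results are hard to reproduce. One option, sketched below under the assumption that we are free to change the constructor outside LeetCode, is to inject a `random.Random` instance. The class name `SeededSolution` and its `rng` parameter are hypothetical; LeetCode itself only calls `Solution(nums)`.

```python
import random
from typing import Optional

class SeededSolution:
    """Same reservoir-sampling pick, with an injectable random generator."""

    def __init__(self, nums: list[int], rng: Optional[random.Random] = None):
        self.nums = nums
        # A seeded random.Random makes runs reproducible in tests.
        self.rng = rng if rng is not None else random.Random()

    def pick(self, target: int) -> int:
        count = 0
        answer = -1
        for i, num in enumerate(self.nums):
            if num == target:
                count += 1
                # Keep index i with probability 1 / count.
                if self.rng.randrange(count) == 0:
                    answer = i
        return answer

a = SeededSolution([1, 2, 3, 3, 3], random.Random(42))
b = SeededSolution([1, 2, 3, 3, 3], random.Random(42))
print([a.pick(3) for _ in range(5)] == [b.pick(3) for _ in range(5)])  # True
```

Two objects with the same seed make the same sequence of draws, so they return the same sequence of indices.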

## Code Explanation

The constructor stores the array:

```python
self.nums = nums
```

For each `pick`, we scan the array:

```python
for i, num in enumerate(self.nums):
```

Whenever we find the target:

```python
if num == target:
```

we increase the number of matches seen:

```python
count += 1
```

Then we choose the current index with probability `1 / count`:

```python
if random.randrange(count) == 0:
    answer = i
```

At the end:

```python
return answer
```

The answer is uniformly random among all matching indices.

## Testing

Randomized code should be tested by checking validity and rough distribution, not by expecting one fixed output.

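```python
from collections import Counter

def test_solution():
    s = Solution([1, 2, 3, 3, 3])

    # Every result must be a valid index for its target.
    for _ in range(100):
        assert s.pick(1) == 0
        assert s.pick(2) == 1
        assert s.pick(3) in {2, 3, 4}

    # 6000 draws over 3 valid indices: about 2000 of each is expected.
    counts = Counter(s.pick(3) for _ in range(6000))

    assert set(counts) == {2, 3, 4}

    # Wide bounds so normal random variation does not make the test flaky.
    for index in [2, 3, 4]:
        assert 1500 <= counts[index] <= 2500

    print("all tests passed")

test_solution()
```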
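Randomness can also be removed from a test entirely by forcing the draws. The sketch below uses `unittest.mock.patch` to replace `random.randrange` with fixed rules, which pins down which index must be returned. It repeats the implementation above so the snippet runs on its own; the test function name is our own.

```python
import random
from unittest import mock

class Solution:
    # Copy of the implementation above, so this snippet is self-contained.
    def __init__(self, nums: list[int]):
        self.nums = nums

    def pick(self, target: int) -> int:
        count = 0
        answer = -1
        for i, num in enumerate(self.nums):
            if num == target:
                count += 1
                if random.randrange(count) == 0:
                    answer = i
        return answer

def test_forced_draws():
    s = Solution([1, 2, 3, 3, 3])

    # randrange always returns 0: every match replaces the answer, so the last match wins.
    with mock.patch("random.randrange", side_effect=lambda k: 0):
        assert s.pick(3) == 4

    # randrange returns k - 1: only the first match (k == 1) ever draws 0.
    with mock.patch("random.randrange", side_effect=lambda k: k - 1):
        assert s.pick(3) == 2

    print("forced-draw tests passed")

test_forced_draws()
```

The patch works because `pick` looks up `random.randrange` on the module at call time.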

Test meaning:

| Test | Why |
|---|---|
| `pick(1)` | Only one valid index |
| `pick(2)` | Only one valid index |
| `pick(3)` | Must return one of several valid indices |
| Many calls to `pick(3)` | Rough check that each valid index appears about one third of the time |

