
LeetCode 382: Linked List Random Node

A clear explanation of selecting a random linked list node with equal probability using reservoir sampling.

Problem Restatement

We are given the head of a singly linked list.

We need to implement a class with one main operation:

getRandom()

This operation returns the value of a random node from the linked list.

Every node must have the same probability of being chosen.

The linked list has at least one node, so getRandom() always has a valid value to return. The follow-up asks whether we can solve it efficiently when the linked list is extremely large and its length is unknown, without using extra space.

Input and Output

| Method | Input | Output |
| --- | --- | --- |
| Solution(head) | Head of a singly linked list | Initializes the object |
| getRandom() | None | A random node value |

Constraints:

| Item | Constraint |
| --- | --- |
| Number of nodes | 1 <= n <= 10^4 |
| Node value | -10^4 <= Node.val <= 10^4 |
| Calls to getRandom() | At most 10^4 |

Examples

Suppose the linked list is:

1 -> 2 -> 3

A call to:

getRandom()

should return one of:

1, 2, 3

Each value should have probability:

1 / 3

So across many calls, the output should look roughly balanced.

It may return:

1, 3, 2, 2, 3

or any other random sequence where each call chooses independently.

First Thought: Store Values in an Array

The simplest approach is to traverse the linked list once in the constructor and store all node values in an array.

Then getRandom() can choose a random index.

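import random
from typing import Optional

# ListNode is the singly-linked list node class that LeetCode provides.

class Solution:

    def __init__(self, head: Optional[ListNode]):
        # Copy every node value into a list once, up front.
        self.values = []

        curr = head
        while curr:
            self.values.append(curr.val)
            curr = curr.next

    def getRandom(self) -> int:
        # Each index is equally likely, so each node is equally likely.
        return random.choice(self.values)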

This solution is simple and valid for the base constraints.

But it uses O(n) extra space.

The follow-up asks for a solution without extra space, especially when the list is very large or the length is unknown.

Key Insight

We can choose a random node while scanning the linked list once.

This method is called reservoir sampling.

The idea is simple:

When we have seen i nodes, each of those i nodes should have probability 1 / i of being the current answer.

So when we visit the ith node:

  1. Choose it with probability 1 / i.
  2. Otherwise keep the previous answer.

For example:

| Node position | Probability of replacing answer |
| --- | --- |
| 1st node | 1 / 1 |
| 2nd node | 1 / 2 |
| 3rd node | 1 / 3 |
| 4th node | 1 / 4 |

At the end, every node has equal probability.
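To make the claim concrete, here is a small sketch that enumerates every equally likely combination of draws instead of sampling at random. The helper name enumerate_outcomes is hypothetical and only serves this illustration; it assumes the draw at position i is uniform over range(i) and that a draw of 0 means "replace the answer".

```python
from collections import Counter
from itertools import product

def enumerate_outcomes(values):
    """Count how often each value is the final answer over all possible draws."""
    n = len(values)
    wins = Counter()
    # Position i (1-based) draws uniformly from range(i), so every tuple of
    # draws below is equally likely.
    for draws in product(*(range(i) for i in range(1, n + 1))):
        answer = None
        for value, draw in zip(values, draws):
            if draw == 0:  # replace the answer with probability 1 / i
                answer = value
        wins[answer] += 1
    return wins

print(enumerate_outcomes([1, 2, 3]))
```

For the list 1 -> 2 -> 3 there are 1 * 2 * 3 = 6 equally likely combinations, and each value is the final answer in exactly 2 of them.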

Algorithm

Store the head pointer in the constructor.

self.head = head

For each call to getRandom():

  1. Set curr = self.head.
  2. Set count = 0.
  3. Set answer = None.
  4. Traverse the list.
  5. For each node:
    • Increase count.
    • Replace answer with the current node value with probability 1 / count.
  6. Return answer.

In Python, this probability can be implemented with:

if random.randrange(count) == 0:
    answer = curr.val

random.randrange(count) returns an integer chosen uniformly from 0 to count - 1.

So the condition is true with probability 1 / count.
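As a quick sanity check of that probability, the sketch below estimates how often the condition holds for a few values of count. The function replacement_rate, the trial count, and the fixed seed are all assumptions made for this illustration; the seed is there only so repeated runs print the same numbers.

```python
import random

def replacement_rate(count, trials=100_000, seed=0):
    """Estimate how often random.randrange(count) == 0 holds."""
    rng = random.Random(seed)  # seeded only to make the sketch reproducible
    hits = sum(1 for _ in range(trials) if rng.randrange(count) == 0)
    return hits / trials

for count in (1, 2, 4):
    # The estimates should be close to 1/1, 1/2, and 1/4 respectively.
    print(count, replacement_rate(count))
```

With count = 1 the condition is always true, which is why the first node always becomes the initial answer.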

Correctness

We need to show that every node has probability 1 / n of being returned, where n is the linked list length.

Consider the node at position i, counting from 1.

When the algorithm visits this node, it selects it with probability:

1 / i

After that, it must survive all later replacement chances.

At node i + 1, it is not replaced with probability:

i / (i + 1)

At node i + 2, it is not replaced with probability:

(i + 1) / (i + 2)

This continues until node n.

So the final probability that node i is returned is:

(1 / i) * (i / (i + 1)) * ((i + 1) / (i + 2)) * ... * ((n - 1) / n)

The product telescopes: each numerator cancels the denominator before it, leaving only:

1 / n

Therefore every node has the same probability of being selected.
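The cancellation can also be checked with exact arithmetic. The following is a minimal sketch using Python's fractions module; selection_probability is a hypothetical helper that multiplies out the same product as in the argument above.

```python
from fractions import Fraction

def selection_probability(i, n):
    """Exact probability that the node at 1-based position i is returned."""
    prob = Fraction(1, i)  # selected when it is visited
    for k in range(i + 1, n + 1):
        prob *= Fraction(k - 1, k)  # survives the replacement chance at node k
    return prob

n = 5
# Every position should come out as exactly 1/5.
print([selection_probability(i, n) for i in range(1, n + 1)])
```

Because Fraction keeps results exact, there is no floating-point rounding to explain away: every position yields exactly 1 / n.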

Complexity

Let n be the number of nodes.

| Operation | Time | Space |
| --- | --- | --- |
| Constructor | O(1) | O(1) |
| getRandom() | O(n) | O(1) |

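The trade-off is time per call: the array approach answers each getRandom() in O(1) after an O(n) constructor, while reservoir sampling rescans the list on every call but never stores it in memory.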

Implementation

import random
from typing import Optional

# Definition for singly-linked list.
# class ListNode:
#     def __init__(self, val=0, next=None):
#         self.val = val
#         self.next = next

class Solution:

    def __init__(self, head: Optional[ListNode]):
        self.head = head

    def getRandom(self) -> int:
        curr = self.head
        count = 0
        answer = None

        while curr:
            count += 1

            if random.randrange(count) == 0:
                answer = curr.val

            curr = curr.next

        return answer
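The same loop is not tied to linked lists. As a sketch of how it might apply to the follow-up case of an unknown length, the hypothetical function sample_one below takes any iterable, including a generator that can only be consumed once. The optional rng parameter and the ValueError on empty input are choices made for this illustration, not part of the LeetCode interface.

```python
import random
from typing import Iterable, Optional, TypeVar

T = TypeVar("T")

def sample_one(stream: Iterable[T], rng: Optional[random.Random] = None) -> T:
    """Return one uniformly chosen item from a stream of unknown length."""
    if rng is None:
        rng = random.Random()

    count = 0
    answer = None
    for item in stream:
        count += 1
        # Keep the new item with probability 1 / count.
        if rng.randrange(count) == 0:
            answer = item

    if count == 0:
        raise ValueError("cannot sample from an empty stream")
    return answer

# A generator has no len(), yet one pass is enough to sample from it.
print(sample_one(x * x for x in range(1, 6)))
```

Only count and answer are kept between items, so memory use stays constant no matter how long the stream is.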

Code Explanation

The constructor stores the head pointer:

self.head = head

We do not copy the list into an array.

Each call to getRandom() starts from the head:

curr = self.head

The variable count records how many nodes we have seen:

count = 0

The variable answer stores the current sampled value:

answer = None

For every node, we increase the count:

count += 1

Then we replace the answer with probability 1 / count:

if random.randrange(count) == 0:
    answer = curr.val

Finally, we move to the next node:

curr = curr.next

After the loop, answer is a uniformly selected node value.

Testing

A randomized algorithm cannot be tested by comparing its output with a single expected value.

We can still test important properties:

  1. The returned value must be one of the linked list values.
  2. With one node, the result must always be that node.
  3. Across many calls, all values should appear.

The test code below assumes the Solution class from the Implementation section is already defined.

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_list(values):
    dummy = ListNode()
    curr = dummy

    for value in values:
        curr.next = ListNode(value)
        curr = curr.next

    return dummy.next

def test_solution():
    head = build_list([1, 2, 3])
    s = Solution(head)

    seen = set()
    for _ in range(1000):
        value = s.getRandom()
        assert value in {1, 2, 3}
        seen.add(value)

    assert seen == {1, 2, 3}

    single = Solution(build_list([42]))
    for _ in range(100):
        assert single.getRandom() == 42

    negative = Solution(build_list([-10, 0, 10]))
    for _ in range(100):
        assert negative.getRandom() in {-10, 0, 10}

    print("all tests passed")

test_solution()

Test meaning:

| Test | Why |
| --- | --- |
| Values always in {1, 2, 3} | Confirms returned value comes from the list |
| All values appear after many calls | Basic sanity check for randomness |
| Single-node list | Must always return the only value |
| Negative values | Confirms node values do not affect sampling |
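The checks above confirm that every value can appear, but not that the values appear equally often. One possible extension, sketched here, compares observed frequencies with 1 / n within a tolerance. The helper check_frequencies, the number of calls, and the tolerance of 0.02 are assumptions chosen for this illustration, and the sketch assumes the list values are distinct. ListNode, Solution, and build_list are repeated so the block runs on its own.

```python
import random
from collections import Counter

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

class Solution:
    def __init__(self, head):
        self.head = head

    def getRandom(self):
        curr, count, answer = self.head, 0, None
        while curr:
            count += 1
            if random.randrange(count) == 0:
                answer = curr.val
            curr = curr.next
        return answer

def build_list(values):
    dummy = ListNode()
    curr = dummy
    for value in values:
        curr.next = ListNode(value)
        curr = curr.next
    return dummy.next

def check_frequencies(values, calls=30_000, tolerance=0.02):
    """Assert each distinct value is returned in roughly 1 / n of the calls."""
    s = Solution(build_list(values))
    counts = Counter(s.getRandom() for _ in range(calls))
    expected = 1 / len(values)
    for value in values:
        observed = counts[value] / calls
        assert abs(observed - expected) < tolerance, (value, observed)
    return counts

random.seed(0)  # fixed seed so the check behaves the same on every run
print(check_frequencies([1, 2, 3]))
```

The tolerance is deliberately loose compared with the expected sampling noise at 30,000 calls, so a correct implementation should pass comfortably while a badly biased one fails.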