LeetCode 432: All O'one Data Structure

Design a data structure that supports increment, decrement, get minimum key, and get maximum key in average O(1) time.

Problem Restatement

We need to design a data structure that stores a count for each string key.

It must support these operations:

| Operation | Meaning |
| --- | --- |
| inc(key) | Increase key count by 1; insert it with count 1 if missing |
| dec(key) | Decrease key count by 1; remove it if count becomes 0 |
| getMaxKey() | Return any key with the largest count |
| getMinKey() | Return any key with the smallest count |

If the data structure is empty, both getMaxKey() and getMinKey() should return an empty string.

Each function must run in O(1) average time. The official statement also guarantees that key exists before dec(key) is called.

Input and Output

| Method | Input | Output |
| --- | --- | --- |
| AllOne() | none | initializes the object |
| inc(key) | string key | none |
| dec(key) | string key | none |
| getMaxKey() | none | any key with maximum count, or "" |
| getMinKey() | none | any key with minimum count, or "" |

Correctness alone is not enough: every operation must also run in average constant time.

Examples

Suppose we run:

obj = AllOne()
obj.inc("hello")
obj.inc("hello")
obj.inc("world")

Now the counts are:

| Key | Count |
| --- | --- |
| hello | 2 |
| world | 1 |

So:

obj.getMaxKey()

can return:

"hello"

and:

obj.getMinKey()

can return:

"world"

Now run:

obj.dec("hello")

The counts become:

| Key | Count |
| --- | --- |
| hello | 1 |
| world | 1 |

Now either key can be returned by getMaxKey() or getMinKey() because both have the same count.

First Thought: Hash Map Only

The simplest data structure is:

count = {
    "hello": 2,
    "world": 1,
}

Then:

inc(key)
dec(key)

are easy.

But getMaxKey() and getMinKey() become slow because we need to scan every key to find the largest or smallest count.

That takes:

O(n)

where n is the number of keys.

The problem requires O(1) average time for every operation, so a hash map alone is not enough.
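For comparison, here is a minimal sketch of the hash-map-only baseline described above. The class name NaiveAllOne is a hypothetical label used only for this illustration. It is correct, but both query methods scan every key.

```python
class NaiveAllOne:
    """Hash-map-only baseline: correct, but min/max queries are O(n)."""

    def __init__(self):
        # Maps each key to its current positive count.
        self.count = {}

    def inc(self, key: str) -> None:
        self.count[key] = self.count.get(key, 0) + 1

    def dec(self, key: str) -> None:
        # Assumes the key is present, as the problem guarantees.
        self.count[key] -= 1
        if self.count[key] == 0:
            del self.count[key]

    def getMaxKey(self) -> str:
        # max() visits every key, which is the O(n) scan we want to avoid.
        return max(self.count, key=self.count.get, default="")

    def getMinKey(self) -> str:
        # Same linear scan for the minimum.
        return min(self.count, key=self.count.get, default="")
```

A baseline like this is still useful later as a reference model when testing the faster structure.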

Key Insight

We need two things at the same time:

| Need | Data Structure |
| --- | --- |
| Find a key’s current count quickly | Hash map |
| Find current minimum and maximum count quickly | Ordered count buckets |

The standard solution uses a doubly linked list of buckets.

Each bucket stores:

| Field | Meaning |
| --- | --- |
| count | The count represented by this bucket |
| keys | A set of keys that currently have this count |
| prev | Previous smaller count bucket |
| next | Next larger count bucket |

The linked list is sorted by count:

head <-> count 1 <-> count 2 <-> count 5 <-> tail

We also keep:

key_to_bucket[key] = bucket

So when a key changes from count c to count c + 1, we already know where it is. We only need to move it to the adjacent count bucket.

If that adjacent bucket does not exist, we create it.

Because a key's count only changes by 1 per operation, the key never needs to jump far across the list.

That is the reason inc and dec can be O(1).

Data Structure

We use two sentinel nodes:

head <-> ... <-> tail

head.next is always the minimum count bucket.

tail.prev is always the maximum count bucket.

The sentinels avoid special cases when inserting or deleting buckets at the ends.
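To make the sentinel claim concrete, here is a small standalone sketch. The names Node, make_list, insert_after, unlink, and counts are hypothetical helpers for this illustration only. The same two link operations handle an empty list, either end, and the middle without any None checks.

```python
class Node:
    def __init__(self, count):
        self.count = count
        self.prev = None
        self.next = None


def make_list():
    # Two sentinels linked to each other represent an empty list.
    head, tail = Node(0), Node(0)
    head.next, tail.prev = tail, head
    return head, tail


def insert_after(node, new_node):
    # node.next always exists because tail is never passed as node.
    new_node.prev, new_node.next = node, node.next
    node.next.prev = new_node
    node.next = new_node


def unlink(node):
    # A real node always has both neighbours, possibly sentinels.
    node.prev.next = node.next
    node.next.prev = node.prev


def counts(head, tail):
    # Collect the counts of the real nodes in list order.
    out = []
    node = head.next
    while node is not tail:
        out.append(node.count)
        node = node.next
    return out


head, tail = make_list()
one, two, five = Node(1), Node(2), Node(5)
insert_after(head, one)        # insert into an empty list
insert_after(tail.prev, five)  # append at the maximum end
insert_after(one, two)         # insert in the middle
unlink(one)                    # remove the current minimum
print(counts(head, tail))      # [2, 5]
```

After these steps head.next is the bucket for count 2 and tail.prev is the bucket for count 5, which is exactly what the min and max queries read.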

Bucket node:

class Bucket:
    def __init__(self, count: int):
        self.count = count
        self.keys = set()
        self.prev = None
        self.next = None

Main structure:

self.key_to_bucket = {}
self.head = Bucket(0)
self.tail = Bucket(0)
self.head.next = self.tail
self.tail.prev = self.head

Algorithm

inc(key)

If the key is new:

  1. Its count becomes 1.
  2. If the first bucket already has count 1, use it.
  3. Otherwise insert a new count 1 bucket after head.
  4. Add the key to that bucket.
  5. Store key_to_bucket[key].

If the key already exists:

  1. Find its current bucket.
  2. Its new count is bucket.count + 1.
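  3. If the next bucket has that count, use it as the destination.
  4. Otherwise create a new bucket with that count immediately after the current bucket and use it as the destination.
  5. Add the key to the destination bucket and update key_to_bucket[key].
  6. Remove the key from the old bucket.
  7. Delete the old bucket if it becomes empty.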

dec(key)

The problem guarantees the key exists.

If its current count is 1:

  1. Remove the key from its bucket.
  2. Delete the key from key_to_bucket.
  3. Remove the bucket if empty.

If its current count is greater than 1:

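  1. Its new count is bucket.count - 1.
  2. If the previous bucket has that count, use it as the destination.
  3. Otherwise create a new bucket with that count immediately before the current bucket and use it as the destination.
  4. Add the key to the destination bucket and update key_to_bucket[key].
  5. Remove the key from the old bucket.
  6. Delete the old bucket if it becomes empty.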

getMaxKey()

If the structure is empty, return "".

Otherwise, return any key from:

self.tail.prev.keys

getMinKey()

If the structure is empty, return "".

Otherwise, return any key from:

self.head.next.keys

Correctness

The linked list stores one bucket for each count that currently appears among the keys. The buckets are kept in increasing count order.

When inc(key) is called for a new key, the key is placed into the count 1 bucket. If that bucket does not exist, the algorithm creates it immediately after head, which is the correct position for the smallest positive count.

When inc(key) is called for an existing key in bucket c, the key’s count becomes c + 1. The list is sorted and no integer lies strictly between c and c + 1, so a bucket for c + 1, if it exists, must be the one immediately after the current bucket. The correct destination is therefore either that next bucket or a newly inserted bucket immediately after the current bucket.

The same reasoning applies to dec(key). A key with count c moves to count c - 1. The correct destination is either the previous bucket if it has count c - 1, or a newly inserted bucket immediately before the current bucket. If the count becomes 0, the key is removed.

Whenever a bucket loses its last key, the algorithm removes that bucket. Therefore every remaining bucket represents at least one existing key.

Since the bucket list is always sorted by count, the first real bucket contains keys with the minimum count, and the last real bucket contains keys with the maximum count. Thus getMinKey() and getMaxKey() return correct keys.
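The invariants used in this argument can also be checked mechanically. The following is a minimal sketch, not part of the solution: check_invariants is a hypothetical helper, and the Bucket class here only mirrors the article's fields so the snippet runs on its own. It assumes the object exposes head, tail, and key_to_bucket as in the design above.

```python
class Bucket:
    # Mirrors the article's Bucket so this sketch is standalone.
    def __init__(self, count):
        self.count = count
        self.keys = set()
        self.prev = None
        self.next = None


def check_invariants(all_one) -> None:
    """Raise AssertionError if a stated invariant is violated."""
    seen = {}
    prev_count = 0
    node = all_one.head.next
    while node is not all_one.tail:
        assert node.keys, "empty bucket left in the list"
        assert node.count > prev_count, "counts not strictly increasing"
        assert node.prev.next is node and node.next.prev is node
        for key in node.keys:
            assert key not in seen, "key stored in two buckets"
            seen[key] = node
        prev_count = node.count
        node = node.next
    # The hash map must point each key at the bucket that holds it.
    assert seen.keys() == all_one.key_to_bucket.keys()
    for key, bucket in seen.items():
        assert all_one.key_to_bucket[key] is bucket


class State:
    # Hypothetical container standing in for an AllOne instance.
    pass


# Hand-built state: head <-> count 1 {"world"} <-> count 2 {"hello"} <-> tail
state = State()
state.head, state.tail = Bucket(0), Bucket(0)
b1, b2 = Bucket(1), Bucket(2)
b1.keys.add("world")
b2.keys.add("hello")
chain = [state.head, b1, b2, state.tail]
for left, right in zip(chain, chain[1:]):
    left.next, right.prev = right, left
state.key_to_bucket = {"world": b1, "hello": b2}

check_invariants(state)  # passes silently on a valid state
```

Calling a checker like this after every inc and dec in a test run catches ordering or cleanup mistakes at the operation that caused them.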

Complexity

| Operation | Time | Why |
| --- | --- | --- |
| inc | O(1) average | Hash lookup, set add/remove, constant linked-list edits |
| dec | O(1) average | Hash lookup, set add/remove, constant linked-list edits |
| getMaxKey | O(1) average | Read from last bucket |
| getMinKey | O(1) average | Read from first bucket |

Space complexity is:

O(n)

where n is the number of keys.

Implementation

class Bucket:
    def __init__(self, count: int):
        self.count = count
        self.keys = set()
        self.prev = None
        self.next = None

class AllOne:
    def __init__(self):
        self.key_to_bucket = {}

        self.head = Bucket(0)
        self.tail = Bucket(0)

        self.head.next = self.tail
        self.tail.prev = self.head

    def _insert_after(self, node: Bucket, new_node: Bucket) -> None:
        nxt = node.next

        node.next = new_node
        new_node.prev = node

        new_node.next = nxt
        nxt.prev = new_node

    def _remove_bucket(self, node: Bucket) -> None:
        prev_node = node.prev
        next_node = node.next

        prev_node.next = next_node
        next_node.prev = prev_node

    def inc(self, key: str) -> None:
        if key not in self.key_to_bucket:
            first = self.head.next

            if first is self.tail or first.count != 1:
                first = Bucket(1)
                self._insert_after(self.head, first)

            first.keys.add(key)
            self.key_to_bucket[key] = first
            return

        curr = self.key_to_bucket[key]
        new_count = curr.count + 1
        nxt = curr.next

        if nxt is self.tail or nxt.count != new_count:
            nxt = Bucket(new_count)
            self._insert_after(curr, nxt)

        nxt.keys.add(key)
        self.key_to_bucket[key] = nxt

        curr.keys.remove(key)

        if not curr.keys:
            self._remove_bucket(curr)

    def dec(self, key: str) -> None:
        curr = self.key_to_bucket[key]

        if curr.count == 1:
            curr.keys.remove(key)
            del self.key_to_bucket[key]

            if not curr.keys:
                self._remove_bucket(curr)

            return

        new_count = curr.count - 1
        prev_node = curr.prev

        if prev_node is self.head or prev_node.count != new_count:
            prev_node = Bucket(new_count)
            # Inserting after curr.prev places the new bucket just before curr.
            self._insert_after(curr.prev, prev_node)

        prev_node.keys.add(key)
        self.key_to_bucket[key] = prev_node

        curr.keys.remove(key)

        if not curr.keys:
            self._remove_bucket(curr)

    def getMaxKey(self) -> str:
        if self.tail.prev is self.head:
            return ""

        return next(iter(self.tail.prev.keys))

    def getMinKey(self) -> str:
        if self.head.next is self.tail:
            return ""

        return next(iter(self.head.next.keys))

Code Explanation

The Bucket class represents one count value:

self.count = count
self.keys = set()

All keys inside the same bucket have the same count.

The doubly linked list uses two sentinels:

self.head = Bucket(0)
self.tail = Bucket(0)

The real buckets live between them.

The helper _insert_after inserts a bucket in constant time:

def _insert_after(self, node, new_node):

The helper _remove_bucket removes an empty bucket in constant time.

In inc, a new key always goes to count 1.

if key not in self.key_to_bucket:

If no count 1 bucket exists, we create one after head.

For an existing key, we move it from count c to count c + 1.

new_count = curr.count + 1

The next bucket is the only possible existing destination.

In dec, a key with count 1 is removed completely.

if curr.count == 1:

Otherwise, it moves from count c to count c - 1.

The previous bucket is the only possible existing destination.

The min and max methods are direct because the linked list is sorted:

self.head.next
self.tail.prev

Testing

def run_tests():
    obj = AllOne()

    assert obj.getMaxKey() == ""
    assert obj.getMinKey() == ""

    obj.inc("hello")
    obj.inc("hello")
    obj.inc("world")

    assert obj.getMaxKey() == "hello"
    assert obj.getMinKey() == "world"

    obj.dec("hello")

    assert obj.getMaxKey() in {"hello", "world"}
    assert obj.getMinKey() in {"hello", "world"}

    obj.dec("hello")

    assert obj.getMaxKey() == "world"
    assert obj.getMinKey() == "world"

    obj.dec("world")

    assert obj.getMaxKey() == ""
    assert obj.getMinKey() == ""

    obj.inc("a")
    obj.inc("b")
    obj.inc("b")
    obj.inc("c")
    obj.inc("c")
    obj.inc("c")

    assert obj.getMinKey() == "a"
    assert obj.getMaxKey() == "c"

    obj.dec("c")
    obj.dec("c")

    assert obj.getMaxKey() == "b"
    assert obj.getMinKey() in {"a", "c"}

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
| --- | --- |
| Empty structure | Checks empty-string return |
| Repeated increment | Checks max bucket updates |
| Decrement to equal count | Checks ties |
| Decrement to zero | Checks key removal |
| Multiple buckets | Checks min and max at opposite ends |
| Bucket deletion | Checks empty bucket cleanup |