
LeetCode 677: Map Sum Pairs

Design a map that supports key-value insertion and prefix-sum queries using a hash map and trie.

Problem Restatement

We need to design a data structure called MapSum.

It supports two operations:

Operation | Meaning
insert(key, val) | Store the value val for the string key
sum(prefix) | Return the total value of all keys that start with prefix

If a key already exists, the new value replaces the old value. It does not add to the old value.

For example:

insert("apple", 3)
sum("ap")          # 3
insert("app", 2)
sum("ap")          # 5

After inserting "apple" and "app", both keys start with "ap", so the sum is:

3 + 2 = 5

Input and Output

Method | Input | Output
MapSum() | None | Creates an empty object
insert(key, val) | A string key and an integer val | None
sum(prefix) | A string prefix | Integer sum

Important constraints:

Constraint | Meaning
1 <= key.length, prefix.length <= 50 | Strings are short
key and prefix consist of only lowercase English letters | A trie is a natural fit
1 <= val <= 1000 | Values are positive
At most 50 calls are made to insert and sum | Even simple solutions can pass

Examples

Example:

mapSum = MapSum()
mapSum.insert("apple", 3)
mapSum.sum("ap")        # 3
mapSum.insert("app", 2)
mapSum.sum("ap")        # 5

Explanation:

After:

insert("apple", 3)

we have:

Key | Value
"apple" | 3

So:

sum("ap") == 3

Then after:

insert("app", 2)

we have:

Key | Value
"apple" | 3
"app" | 2

Both keys start with "ap".

So:

sum("ap") == 3 + 2 == 5

Now consider an update:

mapSum.insert("apple", 3)
mapSum.insert("apple", 2)
mapSum.sum("ap")        # 2

The second insert replaces the value of "apple".

It does not make the value 5.

First Thought: Store All Keys in a Hash Map

The simplest solution is to store each key and its value in a dictionary.

For insert, we assign:

self.values[key] = val

For sum(prefix), we scan every key and check whether it starts with the prefix.

class MapSum:

    def __init__(self):
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        self.values[key] = val

    def sum(self, prefix: str) -> int:
        total = 0

        for key, val in self.values.items():
            if key.startswith(prefix):
                total += val

        return total

This is correct and simple.

Problem With the Direct Solution

The direct solution makes sum(prefix) expensive.

If we have n keys and each key has length up to m, then one sum operation must inspect every key, and each startswith check may compare up to m characters.

The time complexity is:

O(n * m)

For this problem’s small constraints, that can still pass. But the intended design is to make prefix queries fast.

A trie is a better structure because each node naturally represents a prefix.

Key Insight

A trie stores strings character by character.

Every node represents a prefix.

For example, after inserting "apple":

a -> p -> p -> l -> e

The node for "a" represents all words starting with "a".

The node for "ap" represents all words starting with "ap".

The node for "app" represents all words starting with "app".
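The sketch below makes "every node represents a prefix" concrete. It builds a trie out of nested Python dicts and lists the prefix that each node stands for. The helpers build_trie and node_prefixes are hypothetical and exist only for this illustration. They are not part of the MapSum solution.

```python
def build_trie(words):
    """Build a nested-dict trie (hypothetical helper for illustration).

    Each node is a dict mapping a character to its child node.
    """
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
    return root


def node_prefixes(root):
    """Return the prefix represented by every non-root node, in DFS order."""
    result = []
    stack = [(root, "")]
    while stack:
        node, prefix = stack.pop()
        if prefix:  # the root represents the empty prefix; skip it
            result.append(prefix)
        # Push children in reverse order so they pop alphabetically.
        for ch in sorted(node, reverse=True):
            stack.append((node[ch], prefix + ch))
    return result


print(node_prefixes(build_trie(["apple"])))  # ['a', 'ap', 'app', 'appl', 'apple']
```

Building the trie from ["apple", "app"] lists the same five prefixes, because "app" reuses nodes that "apple" already created.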

If each trie node stores the sum of all values under that prefix, then sum(prefix) only needs to walk down the prefix and return the stored value.

For example, after inserting:

insert("apple", 3)
insert("app", 2)

the trie prefix sums include:

Prefix | Sum
"a" | 5
"ap" | 5
"app" | 5
"appl" | 3
"apple" | 3

Then:

sum("ap") == 5

can be answered immediately after reaching the "ap" node.
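You can reproduce the table above without a trie by storing every prefix string in a dictionary. This is only a sketch for checking the numbers. prefix_sums is a hypothetical helper, and it keys each total by the prefix string rather than by a trie node.

```python
from collections import defaultdict


def prefix_sums(entries):
    """Map every prefix to the total value of the keys that start with it.

    Flat-dictionary stand-in for the trie's per-node totals (illustration only).
    """
    sums = defaultdict(int)
    for key, val in entries.items():
        # key[:1], key[:2], ..., key[:len(key)] are all prefixes of key.
        for i in range(1, len(key) + 1):
            sums[key[:i]] += val
    return dict(sums)


table = prefix_sums({"apple": 3, "app": 2})
for prefix in ["a", "ap", "app", "appl", "apple"]:
    print(prefix, table[prefix])
```

This flat version creates up to k separate prefix strings per key, which is O(k^2) characters. The trie keeps the same totals along a single shared path.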

Handling Updates

The main detail is key replacement.

Suppose we do:

insert("apple", 3)
insert("apple", 2)

The value of "apple" changes from 3 to 2.

The prefix sums should decrease by 1, not increase by 2.

So we keep a second hash map, values, that maps each key to its current value.

When inserting a key, compute:

delta = new_value - old_value

If the key is new, the old value is 0.

Then add delta to every trie node on the path of that key.
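The delta rule fits in a few lines. This sketch assumes a plain dict for the key-to-value map. apply_insert is a hypothetical name used only here.

```python
def apply_insert(values, key, val):
    """Record key -> val and return the delta to add along key's trie path.

    A key that is not in `values` yet counts as having an old value of 0.
    """
    old_val = values.get(key, 0)
    values[key] = val
    return val - old_val


values = {}
print(apply_insert(values, "apple", 3))  # 3  (new key: 3 - 0)
print(apply_insert(values, "apple", 2))  # -1 (update: 2 - 3)
print(apply_insert(values, "apple", 2))  # 0  (same value again: nothing changes)
```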

Example:

insert("apple", 3)

Old value is 0, so:

delta = 3 - 0 = 3

Every prefix node for "apple" gets +3.

Now update:

insert("apple", 2)

Old value is 3, so:

delta = 2 - 3 = -1

Every prefix node for "apple" gets -1.

This keeps all prefix sums correct.
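You can trace the worked example directly. The sketch below uses a minimal node class and two hypothetical helpers, add_along_path and path_scores. It prints the score of each node on the "apple" path after each insert.

```python
class Node:
    def __init__(self):
        self.children = {}
        self.score = 0


def add_along_path(root, key, delta):
    """Add delta to every node on key's path, creating nodes as needed."""
    node = root
    for ch in key:
        if ch not in node.children:
            node.children[ch] = Node()
        node = node.children[ch]
        node.score += delta


def path_scores(root, key):
    """Return the score of each node on key's path (assumes the path exists)."""
    scores = []
    node = root
    for ch in key:
        node = node.children[ch]
        scores.append(node.score)
    return scores


root = Node()
add_along_path(root, "apple", 3 - 0)  # first insert: delta = 3
print(path_scores(root, "apple"))     # [3, 3, 3, 3, 3]
add_along_path(root, "apple", 2 - 3)  # update: delta = -1
print(path_scores(root, "apple"))     # [2, 2, 2, 2, 2]
```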

Algorithm

Store two things:

Structure | Purpose
root | Root of the trie
values | Hash map from each key to its current value

Each trie node stores:

Field | Meaning
children | Map from character to child node
score | Sum of the values of all keys that have this node's prefix

For insert(key, val):

  1. Get the previous value of key, or 0 if it is new.
  2. Compute delta = val - previous.
  3. Store the new value in values.
  4. Walk through the trie using the characters of key.
  5. Create missing nodes as needed.
  6. Add delta to every node on the path.

For sum(prefix):

  1. Start at the root.
  2. Walk through the trie using the characters of prefix.
  3. If a character is missing, return 0.
  4. After consuming the prefix, return the current node’s score.
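Keys and prefixes use only lowercase English letters, so the children map could also be a fixed array of 26 slots. The following sketches that alternative. It is not the version used in the Implementation section. ArrayTrieNode and ArrayMapSum are hypothetical names, and the code assumes every character is in 'a' to 'z'.

```python
class ArrayTrieNode:
    __slots__ = ("children", "score")

    def __init__(self):
        # One slot per lowercase letter; assumes characters are 'a'..'z' only.
        self.children = [None] * 26
        self.score = 0


class ArrayMapSum:
    """Hypothetical MapSum variant with fixed-size child arrays."""

    def __init__(self):
        self.root = ArrayTrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        # Steps 1-3: delta against the previous value (0 for a new key).
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        # Steps 4-6: walk the key, create missing nodes, add delta to each.
        node = self.root
        for ch in key:
            i = ord(ch) - ord("a")
            if node.children[i] is None:
                node.children[i] = ArrayTrieNode()
            node = node.children[i]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            node = node.children[ord(ch) - ord("a")]
            if node is None:
                return 0  # no stored key has this prefix
        return node.score


m = ArrayMapSum()
m.insert("apple", 3)
print(m.sum("ap"))  # 3
m.insert("app", 2)
print(m.sum("ap"))  # 5
```

The array avoids hashing characters but allocates 26 slots per node. The dictionary version stays correct for any character set.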

Correctness

Each trie node corresponds to one prefix.

When we insert or update a key, every prefix of that key lies on the path from the root through the key’s characters.

The algorithm adds delta to exactly those prefix nodes.

If the key is new, delta is the inserted value, so every prefix of that key receives that value.

If the key already exists, delta is the difference between the new value and the old value. Adding this difference changes every affected prefix sum from the old contribution to the new contribution.

No unrelated prefix is changed because the algorithm only walks along the characters of the inserted key.

Therefore, after every insertion, each trie node’s score equals the sum of all values whose keys have that node’s prefix.

For sum(prefix), the algorithm walks to the trie node representing prefix.

If the node does not exist, no stored key has that prefix, so the correct answer is 0.

If the node exists, its score is exactly the sum of all values whose keys start with prefix.

So sum(prefix) returns the correct value.
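A randomized check can complement the argument above, though it does not replace it. The sketch compares the trie against the brute-force scan from the First Thought section. It repeats the trie implementation so it runs on its own. The tiny alphabet, key lengths, trial count, and seed are arbitrary choices, made so that keys often share prefixes and get updated.

```python
import random


class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


class MapSum:
    # Same trie-with-delta design as the Implementation section.
    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key, val):
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta

    def sum(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


def brute_sum(values, prefix):
    """Reference answer: scan every stored key."""
    return sum(v for k, v in values.items() if k.startswith(prefix))


def random_word(rng, max_len=4, alphabet="abc"):
    # A small alphabet makes shared prefixes and repeated keys common.
    return "".join(rng.choice(alphabet) for _ in range(rng.randint(1, max_len)))


def cross_check(trials=2000, seed=0):
    rng = random.Random(seed)
    trie, reference = MapSum(), {}
    for _ in range(trials):
        if rng.random() < 0.5:
            key, val = random_word(rng), rng.randint(1, 1000)
            trie.insert(key, val)
            reference[key] = val  # replacement semantics, as in the problem
        else:
            prefix = random_word(rng)
            assert trie.sum(prefix) == brute_sum(reference, prefix), prefix
    return True


print(cross_check())  # True
```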

Complexity

Let:

Symbol | Meaning
k | Length of the inserted key
p | Length of the searched prefix
T | Total number of characters across all distinct inserted keys

Operation | Time | Space
MapSum() | O(1) | O(1)
insert(key, val) | O(k) | O(k) in the worst case for new trie nodes
sum(prefix) | O(p) | O(1)

Total trie space is:

O(T)

The values hash map stores one entry per distinct key, which is also O(T) because it holds the key strings themselves.
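O(T) is an upper bound. Shared prefixes reuse nodes, so the trie often has fewer nodes than T. The sketch below counts the nodes created for the keys used in the tests. count_nodes is a hypothetical helper, and its node class omits score because only the structure matters here.

```python
class TrieNode:
    def __init__(self):
        self.children = {}


def count_nodes(keys):
    """Insert the distinct keys into a bare trie.

    Returns (number of non-root nodes created, total characters in distinct keys).
    """
    distinct = set(keys)
    root = TrieNode()
    created = 0
    for key in distinct:
        node = root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
                created += 1
            node = node.children[ch]
    total_chars = sum(len(k) for k in distinct)
    return created, total_chars


print(count_nodes(["apple", "app", "ape"]))  # (6, 11)
```

Here 11 characters produce only 6 nodes, because "app" and "ape" share nodes with "apple". The node count does not depend on insertion order.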

Implementation

class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0

class MapSum:

    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        old_val = self.values.get(key, 0)
        delta = val - old_val
        self.values[key] = val

        node = self.root

        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()

            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root

        for ch in prefix:
            if ch not in node.children:
                return 0

            node = node.children[ch]

        return node.score

Code Explanation

The trie node is small:

class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0

children stores the outgoing edges from this node.

score stores the sum of all values for keys that share this prefix.

The MapSum object keeps the root trie node and the current values of keys:

self.root = TrieNode()
self.values = {}

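The values hash map is required because insert replaces old values. A node's score mixes the values of many keys, so the trie alone cannot tell us a key's previous value.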
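To see why the values map matters, here is a deliberately incorrect variant. NaiveMapSum is a hypothetical name. It adds val along the path without looking up the old value, and repeating the earlier update example exposes the bug.

```python
class Node:
    def __init__(self):
        self.children = {}
        self.score = 0


class NaiveMapSum:
    # DELIBERATELY INCORRECT: adds val directly and ignores any previous value.
    def __init__(self):
        self.root = Node()

    def insert(self, key, val):
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = Node()
            node = node.children[ch]
            node.score += val  # the correct version adds (val - old_val)

    def sum(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


m = NaiveMapSum()
m.insert("apple", 3)
m.insert("apple", 2)
print(m.sum("ap"))  # 5, but the correct answer is 2
```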

In insert, we compute the difference:

old_val = self.values.get(key, 0)
delta = val - old_val

Then we update the stored value:

self.values[key] = val

After that, we walk through the trie:

for ch in key:

If the next node does not exist, we create it:

if ch not in node.children:
    node.children[ch] = TrieNode()

Then we move to the child and update its prefix sum:

node = node.children[ch]
node.score += delta

In sum, we walk through the prefix.

If any character is missing:

return 0

because no key starts with that prefix.

If the full prefix exists, we return:

return node.score

Testing

def run_tests():
    map_sum = MapSum()

    map_sum.insert("apple", 3)
    assert map_sum.sum("ap") == 3

    map_sum.insert("app", 2)
    assert map_sum.sum("ap") == 5
    assert map_sum.sum("app") == 5
    assert map_sum.sum("apple") == 3
    assert map_sum.sum("b") == 0

    map_sum.insert("apple", 2)
    assert map_sum.sum("ap") == 4
    assert map_sum.sum("app") == 4
    assert map_sum.sum("apple") == 2

    map_sum.insert("ape", 4)
    assert map_sum.sum("ap") == 8
    assert map_sum.sum("ape") == 4

    print("all tests passed")

run_tests()

Test meaning:

Test | Expected | Why
Insert "apple" = 3, sum "ap" | 3 | One matching key
Insert "app" = 2, sum "ap" | 5 | Two matching keys
Sum "app" | 5 | "app" and "apple" both match
Sum "apple" | 3 | Only "apple" matches
Sum "b" | 0 | No key has that prefix
Update "apple" from 3 to 2 | Prefix sums decrease by 1 | Insert replaces the old value
Insert "ape" = 4, sum "ap" | 8 | Three keys share the prefix
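A few more edge cases are worth checking:

- re-inserting the same value;
- a prefix longer than every key;
- a key that is itself a prefix of an existing key.

The sketch restates the MapSum implementation so it runs standalone. run_edge_case_tests is a hypothetical name, and the expected values follow from the replacement semantics described above.

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


class MapSum:
    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


def run_edge_case_tests():
    m = MapSum()
    m.insert("apple", 3)
    m.insert("apple", 3)         # same value again: delta is 0
    assert m.sum("apple") == 3
    assert m.sum("apples") == 0  # prefix longer than every key
    m.insert("a", 1)             # key that is a prefix of an existing key
    assert m.sum("a") == 4
    assert m.sum("ap") == 3      # "a" does not start with "ap"
    m.insert("a", 10)            # update: delta is 9
    assert m.sum("a") == 13
    print("edge cases passed")


run_edge_case_tests()
```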