Design a map that supports key-value insertion and prefix-sum queries using a hash map and a trie.
Problem Restatement
We need to design a data structure called MapSum.
It supports two operations:
| Operation | Meaning |
|---|---|
| insert(key, val) | Store the value val for the string key |
| sum(prefix) | Return the total value of all keys that start with prefix |
If a key already exists, the new value replaces the old value. It does not add to the old value.
For example:
```python
insert("apple", 3)
sum("ap")  # 3
insert("app", 2)
sum("ap")  # 5
```

After inserting "apple" and "app", both keys start with "ap", so the sum is 3 + 2 = 5.

Input and Output
| Method | Input | Output |
|---|---|---|
| MapSum() | None | Creates an empty object |
| insert(key, val) | A string key and integer val | None |
| sum(prefix) | A string prefix | Integer sum |
Important constraints:
| Constraint | Meaning |
|---|---|
| 1 <= key.length, prefix.length <= 50 | Strings are short |
| key and prefix consist only of lowercase English letters | A trie is a natural fit |
| 1 <= val <= 1000 | Values are positive |
| At most 50 calls to insert and sum in total | Even simple solutions can pass |
Examples
Example:
```python
mapSum = MapSum()
mapSum.insert("apple", 3)
mapSum.sum("ap")  # 3
mapSum.insert("app", 2)
mapSum.sum("ap")  # 5
```

Explanation:
After insert("apple", 3), we have:
| Key | Value |
|---|---|
| "apple" | 3 |
So sum("ap") == 3.

Then, after insert("app", 2), we have:
| Key | Value |
|---|---|
| "apple" | 3 |
| "app" | 2 |
Both keys start with "ap".
So sum("ap") == 3 + 2 == 5.

Now consider an update:
```python
mapSum = MapSum()
mapSum.insert("apple", 3)
mapSum.insert("apple", 2)
mapSum.sum("ap")  # 2
```

The second insert replaces the value of "apple" instead of adding to it, so the result is 2, not 5.
First Thought: Store All Keys in a Hash Map
The simplest solution is to store each key and its value in a dictionary.
For insert, we assign:

```python
self.values[key] = val
```

For sum(prefix), we scan every key and check whether it starts with the prefix.
```python
class MapSum:
    def __init__(self):
        self.values = {}  # key -> current value

    def insert(self, key: str, val: int) -> None:
        self.values[key] = val  # assignment replaces any old value

    def sum(self, prefix: str) -> int:
        total = 0
        # Scan every stored key; only keys with the prefix contribute.
        for key, val in self.values.items():
            if key.startswith(prefix):
                total += val
        return total
```

This is correct and simple.
Problem With the Direct Solution
The direct solution makes sum(prefix) expensive.
If we have n keys and each key has length up to m, then one sum operation may need to inspect all keys.
The time complexity of a single sum call is O(n * m). For this problem’s small constraints, that can still pass, but the intended design is to make prefix queries fast.
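As a rough upper bound on the brute-force cost under the stated constraints, suppose all 50 allowed calls were sum calls over 50 stored keys. This overcounts, because inserts and sums share the 50-call limit. Each startswith check compares at most 50 characters, so:

$$
50 \ \text{calls} \times 50 \ \text{keys} \times 50 \ \text{characters} = 125{,}000 \ \text{character comparisons}
$$

At this input size, the scan is not a bottleneck.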
A trie is a better structure because each node naturally represents a prefix.
Key Insight
A trie stores strings character by character.
Every node represents a prefix.
For example, after inserting "apple":
```text
a -> p -> p -> l -> e
```

The node for "a" represents all words starting with "a".
The node for "ap" represents all words starting with "ap".
The node for "app" represents all words starting with "app".
If each trie node stores the sum of all values under that prefix, then sum(prefix) only needs to walk down the prefix and return the stored value.
For example, after inserting:
```python
insert("apple", 3)
insert("app", 2)
```

the trie prefix sums include:
| Prefix | Sum |
|---|---|
| "a" | 5 |
| "ap" | 5 |
| "app" | 5 |
| "appl" | 3 |
| "apple" | 3 |
Then sum("ap") == 5 can be answered immediately after reaching the "ap" node.
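You can check the prefix-sum table above by recomputing it directly from the key-value pairs. The sketch below is only an illustration and is not part of the trie solution. It uses a plain dictionary keyed by prefix strings (prefix_sums is a hypothetical helper name) and rebuilds every prefix total from scratch. That repeated work is exactly what the trie's per-node scores avoid.

```python
from collections import defaultdict


def prefix_sums(values: dict[str, int]) -> dict[str, int]:
    """Map each non-empty prefix of each key to the total value under it."""
    sums = defaultdict(int)
    for key, val in values.items():
        for end in range(1, len(key) + 1):
            sums[key[:end]] += val  # every prefix of key gets key's value
    return dict(sums)


table = prefix_sums({"apple": 3, "app": 2})
print(table)  # {'a': 5, 'ap': 5, 'app': 5, 'appl': 3, 'apple': 3}
```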
Handling Updates
The main detail is key replacement.
Suppose we do:
```python
insert("apple", 3)
insert("apple", 2)
```

The value of "apple" changes from 3 to 2.
The prefix sums should decrease by 1, not increase by 2.
So, alongside the trie, we keep a hash map from each key to its current value (called values in the implementation).
When inserting a key, compute:

```python
delta = new_value - old_value
```

If the key is new, the old value is 0.
Then add delta to every trie node on the path of that key.
Example:
```python
insert("apple", 3)
```

The old value is 0, so:

```text
delta = 3 - 0 = 3
```

Every prefix node for "apple" gets +3.
Now update:
```python
insert("apple", 2)
```

The old value is 3, so:

```text
delta = 2 - 3 = -1
```

Every prefix node for "apple" gets -1.
This keeps all prefix sums correct.
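A small sketch shows the difference between adding the new value and adding the delta. To keep the example short, it assumes a Counter keyed by prefix strings can stand in for the trie nodes' score fields. The names naive_scores and delta_scores are hypothetical.

```python
from collections import Counter


def prefixes(key: str):
    """Yield every non-empty prefix of key: "a", "ap", ..."""
    for end in range(1, len(key) + 1):
        yield key[:end]


def naive_scores(inserts):
    """Wrong for updates: always adds the new value to every prefix."""
    scores = Counter()
    for key, val in inserts:
        for p in prefixes(key):
            scores[p] += val
    return scores


def delta_scores(inserts):
    """Adds val - old_val, so a repeated key replaces its old contribution."""
    scores = Counter()
    current = {}  # key -> current value
    for key, val in inserts:
        delta = val - current.get(key, 0)  # old value is 0 for a new key
        current[key] = val
        for p in prefixes(key):
            scores[p] += delta
    return scores


ops = [("apple", 3), ("apple", 2)]
print(naive_scores(ops)["ap"])  # 5 (double-counts "apple")
print(delta_scores(ops)["ap"])  # 2 (replace semantics)
```

Only the delta version reproduces the replace semantics the problem requires.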
Algorithm
Store two things:
| Structure | Purpose |
|---|---|
| root | Root of the trie |
| values | Hash map from key to its current value |
Each trie node stores:
| Field | Meaning |
|---|---|
| children | Map from character to child node |
| score | Sum of all key values with this prefix |
For insert(key, val):
- Get the previous value of key, or 0 if it is new.
- Compute delta = val - previous.
- Store the new value in values.
- Walk through the trie using the characters of key.
- Create missing nodes as needed.
- Add delta to every node on the path.
For sum(prefix):
- Start at the root.
- Walk through the trie using the characters of prefix.
- If a character is missing, return 0.
- After consuming the prefix, return the current node’s score.
Correctness
Each trie node corresponds to one prefix.
When we insert or update a key, every prefix of that key lies on the path from the root through the key’s characters.
The algorithm adds delta to exactly those prefix nodes.
If the key is new, delta is the inserted value, so every prefix of that key receives that value.
If the key already exists, delta is the difference between the new value and the old value. Adding this difference changes every affected prefix sum from the old contribution to the new contribution.
No unrelated prefix is changed because the algorithm only walks along the characters of the inserted key.
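Initially no keys are stored and every score is 0, so this property holds from the start. Each insertion preserves it. Therefore, after every insertion, each non-root trie node’s score equals the sum of all values whose keys have that node’s prefix. The root’s score is never updated, which is harmless because every queried prefix has length at least 1.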
For sum(prefix), the algorithm walks to the trie node representing prefix.
If the node does not exist, no stored key has that prefix, so the correct answer is 0.
If the node exists, its score is exactly the sum of all values whose keys start with prefix.
So sum(prefix) returns the correct value.
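The argument above is the actual justification. A randomized comparison against the direct dictionary scan is only a cheap sanity check, but it heavily exercises repeated keys and shared prefixes. This sketch restates the trie design from this section compactly so the block runs on its own. The function check_against_brute_force, its rounds and seed parameters, and the two-letter alphabet are illustrative choices and are not part of the problem.

```python
import random


class TrieNode:
    def __init__(self):
        self.children = {}  # character -> child TrieNode
        self.score = 0      # total value of keys with this prefix


class MapSum:
    # Compact restatement of the trie-with-delta design from this section.
    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


def check_against_brute_force(rounds: int = 2000, seed: int = 0) -> None:
    """Run random inserts and queries, comparing the trie with a dict scan."""
    rng = random.Random(seed)
    trie = MapSum()
    reference = {}  # key -> current value, used by the brute-force scan
    for _ in range(rounds):
        # A two-letter alphabet and short strings force shared prefixes
        # and repeated keys, which exercises the update (delta) path.
        s = "".join(rng.choice("ab") for _ in range(rng.randint(1, 4)))
        if rng.random() < 0.5:
            val = rng.randint(1, 1000)
            trie.insert(s, val)
            reference[s] = val
        else:
            expected = sum(v for k, v in reference.items() if k.startswith(s))
            got = trie.sum(s)
            assert got == expected, (s, got, expected)
    print("trie matches brute force")


check_against_brute_force()
```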
Complexity
Let:
| Symbol | Meaning |
|---|---|
| k | Length of the inserted key |
| p | Length of the searched prefix |
| T | Total number of characters inserted across all distinct keys |
| Operation | Time | Space |
|---|---|---|
| MapSum() | O(1) | O(1) |
| insert(key, val) | O(k) | O(k) in the worst case for new trie nodes |
| sum(prefix) | O(p) | O(1) |
Total trie space is O(T). The values hash map stores one entry per distinct key, and those keys also total O(T) characters, so the overall space is O(T).
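Keys with a common prefix share nodes, so T is only an upper bound on the number of trie nodes. A small sketch builds a trie and counts the nodes it creates. The helper count_nodes is hypothetical.

```python
class TrieNode:
    def __init__(self):
        self.children = {}  # character -> child TrieNode
        self.score = 0


def count_nodes(keys: list[str]) -> int:
    """Build a trie from keys and return how many non-root nodes it creates."""
    root = TrieNode()
    created = 0
    for key in keys:
        node = root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
                created += 1  # only new edges allocate memory
            node = node.children[ch]
    return created


keys = ["apple", "app", "ape"]
print(sum(len(k) for k in keys))  # 11 characters in total
print(count_nodes(keys))          # 6 nodes: a, ap, app, appl, apple, ape
```

Here the three keys contain 11 characters but need only 6 nodes. "app" reuses the path of "apple", and "ape" reuses the "a" and "ap" nodes.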
Implementation
```python
class TrieNode:
    def __init__(self):
        self.children = {}  # character -> child TrieNode
        self.score = 0      # sum of values of all keys with this node's prefix


class MapSum:
    def __init__(self):
        self.root = TrieNode()
        self.values = {}  # key -> current value, needed to compute deltas

    def insert(self, key: str, val: int) -> None:
        old_val = self.values.get(key, 0)  # 0 if the key is new
        delta = val - old_val
        self.values[key] = val

        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta  # every prefix of key absorbs the change

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0  # no stored key starts with prefix
            node = node.children[ch]
        return node.score
```

Code Explanation
The trie node is small:
```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0
```

children stores the outgoing edges from this node.
score stores the sum of all values for keys that share this prefix.
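The constraints limit keys and prefixes to lowercase English letters, so a fixed list of 26 slots indexed by letter is a common alternative to the children dictionary. This is a variant and not the implementation described here. ArrayTrieNode and ArrayMapSum are hypothetical names, and the indexing assumes every character is between 'a' and 'z'.

```python
class ArrayTrieNode:
    def __init__(self):
        # One slot per letter 'a'..'z'; None means "no child for that letter".
        self.children = [None] * 26
        self.score = 0


class ArrayMapSum:
    def __init__(self):
        self.root = ArrayTrieNode()
        self.values = {}  # key -> current value

    def insert(self, key: str, val: int) -> None:
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            i = ord(ch) - ord("a")  # assumes ch is a lowercase ASCII letter
            if node.children[i] is None:
                node.children[i] = ArrayTrieNode()
            node = node.children[i]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            node = node.children[ord(ch) - ord("a")]
            if node is None:
                return 0  # no stored key continues with this letter
        return node.score


m = ArrayMapSum()
m.insert("apple", 3)
m.insert("app", 2)
print(m.sum("ap"))  # 5
```

The dictionary version stores only the edges that exist. The list version allocates 26 slots per node in exchange for indexing by letter position. At these input sizes, either choice works.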
The MapSum object keeps the root trie node and the current values of keys:
```python
self.root = TrieNode()
self.values = {}
```

The values hash map is required because insert replaces old values, and computing the delta needs the value the key had before.
In insert, we compute the difference:
```python
old_val = self.values.get(key, 0)
delta = val - old_val
```

Then we update the stored value:

```python
self.values[key] = val
```

After that, we walk through the trie:

```python
for ch in key:
```

If the next node does not exist, we create it:
```python
if ch not in node.children:
    node.children[ch] = TrieNode()
```

Then we move to the child and update its prefix sum:

```python
node = node.children[ch]
node.score += delta
```

In sum, we walk through the prefix.
If any character is missing, no key starts with that prefix, so we return:

```python
return 0
```

If the full prefix exists, we return:

```python
return node.score
```

Testing
```python
def run_tests():
    map_sum = MapSum()

    map_sum.insert("apple", 3)
    assert map_sum.sum("ap") == 3

    map_sum.insert("app", 2)
    assert map_sum.sum("ap") == 5
    assert map_sum.sum("app") == 5
    assert map_sum.sum("apple") == 3
    assert map_sum.sum("b") == 0

    map_sum.insert("apple", 2)
    assert map_sum.sum("ap") == 4
    assert map_sum.sum("app") == 4
    assert map_sum.sum("apple") == 2

    map_sum.insert("ape", 4)
    assert map_sum.sum("ap") == 8
    assert map_sum.sum("ape") == 4

    print("all tests passed")


run_tests()
```

Test meaning:
| Test | Expected | Why |
|---|---|---|
| Insert "apple" = 3, sum "ap" | 3 | One matching key |
| Insert "app" = 2, sum "ap" | 5 | Two matching keys |
| Sum "app" | 5 | "app" and "apple" both match |
| Sum "apple" | 3 | Only "apple" matches |
| Sum "b" | 0 | No key has that prefix |
| Update "apple" from 3 to 2 | Prefix sums decrease by 1 | Insert replaces old value |
| Insert "ape" = 4, sum "ap" | 8 | Three keys share the prefix |