# LeetCode 677: Map Sum Pairs

## Problem Restatement

We need to design a data structure called `MapSum`.

It supports two operations:

| Operation | Meaning |
|---|---|
| `insert(key, val)` | Store the value `val` for the string `key` |
| `sum(prefix)` | Return the total value of all keys that start with `prefix` |

If a key already exists, the new value replaces the old value. It does not add to the old value.

For example:

```python
insert("apple", 3)
sum("ap")          # 3
insert("app", 2)
sum("ap")          # 5
```

After inserting `"apple"` and `"app"`, both keys start with `"ap"`, so the sum is:

```text
3 + 2 = 5
```

## Input and Output

| Method | Input | Output |
|---|---|---|
| `MapSum()` | None | Creates an empty object |
| `insert(key, val)` | A string `key` and integer `val` | `None` |
| `sum(prefix)` | A string `prefix` | Integer sum |

Important constraints:

| Constraint | Meaning |
|---|---|
| `1 <= key.length, prefix.length <= 50` | Strings are short |
| `key` and `prefix` contain lowercase English letters | Trie is natural here |
| `1 <= val <= 1000` | Values are positive |
| At most `50` calls to `insert` and `sum` | Even a brute-force scan is fast enough |

## Examples

Example:

```python
mapSum = MapSum()
mapSum.insert("apple", 3)
mapSum.sum("ap")        # 3
mapSum.insert("app", 2)
mapSum.sum("ap")        # 5
```

Explanation:

After:

```python
insert("apple", 3)
```

we have:

| Key | Value |
|---|---:|
| `"apple"` | `3` |

So:

```python
sum("ap") == 3
```

Then after:

```python
insert("app", 2)
```

we have:

| Key | Value |
|---|---:|
| `"apple"` | `3` |
| `"app"` | `2` |

Both keys start with `"ap"`.

So:

```python
sum("ap") == 3 + 2 == 5
```

Now consider an update on a fresh object:

```python
mapSum = MapSum()
mapSum.insert("apple", 3)
mapSum.insert("apple", 2)
mapSum.sum("ap")        # 2
```

The second insert replaces the value of `"apple"`.

It does not make the value `5`.
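In dictionary terms, the rule is plain assignment rather than accumulation. Here is a tiny sketch with ordinary `dict` objects; the variable names are illustrative only:

```python
# Replacement semantics (what MapSum wants): assignment overwrites the old value.
replaced = {}
replaced["apple"] = 3
replaced["apple"] = 2
print(replaced["apple"])  # 2

# Accumulation (NOT what MapSum wants): each insert adds to the old value.
accumulated = {}
accumulated["apple"] = accumulated.get("apple", 0) + 3
accumulated["apple"] = accumulated.get("apple", 0) + 2
print(accumulated["apple"])  # 5
```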

## First Thought: Store All Keys in a Hash Map

The simplest solution is to store each key and its value in a dictionary.

For `insert`, we assign:

```python
self.values[key] = val
```

For `sum(prefix)`, we scan every key and check whether it starts with the prefix.

```python
class MapSum:

    def __init__(self):
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        self.values[key] = val

    def sum(self, prefix: str) -> int:
        total = 0

        for key, val in self.values.items():
            if key.startswith(prefix):
                total += val

        return total
```

This is correct and simple.
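The same scan fits in a single generator expression. Below is a minimal standalone sketch; the function name `brute_force_sum` is hypothetical, and it takes the key-to-value dictionary directly instead of living inside a class:

```python
def brute_force_sum(values: dict[str, int], prefix: str) -> int:
    # Add up the value of every key that starts with the prefix.
    return sum(val for key, val in values.items() if key.startswith(prefix))


values = {"apple": 3, "app": 2}
print(brute_force_sum(values, "ap"))    # 5
print(brute_force_sum(values, "appl"))  # 3
print(brute_force_sum(values, "b"))     # 0
```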

## Problem With the Direct Solution

The direct solution makes `sum(prefix)` expensive.

If we have `n` keys and each key has length up to `m`, then one sum operation may need to inspect all keys.

The time complexity is:

```text
O(n * m)
```

For this problem's small constraints, that can still pass. But the intended design is to make prefix queries fast.

A trie is a better structure because each node naturally represents a prefix.

## Key Insight

A trie stores strings character by character.

Every node represents a prefix.

For example, after inserting `"apple"`:

```text
a -> p -> p -> l -> e
```

The node for `"a"` represents all words starting with `"a"`.

The node for `"ap"` represents all words starting with `"ap"`.

The node for `"app"` represents all words starting with `"app"`.
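To see that each node stands for exactly one prefix, here is a sketch that builds a trie from nested dictionaries and lists the prefix spelled by the path to every node. The nested-dict representation and the helper names `build_trie` and `node_prefixes` are simplifications assumed for illustration; the solution below uses a `TrieNode` class instead.

```python
def build_trie(words):
    # Each node is a dict mapping a character to its child node.
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})
    return root


def node_prefixes(node, path=""):
    # Depth-first walk: the path from the root to a node is that node's prefix.
    prefixes = []
    for ch, child in node.items():
        prefixes.append(path + ch)
        prefixes.extend(node_prefixes(child, path + ch))
    return prefixes


print(node_prefixes(build_trie(["apple"])))
# ['a', 'ap', 'app', 'appl', 'apple']
```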

If each trie node stores the sum of all values under that prefix, then `sum(prefix)` only needs to walk down the prefix and return the stored value.

For example, after inserting:

```python
insert("apple", 3)
insert("app", 2)
```

the trie prefix sums include:

| Prefix | Sum |
|---|---:|
| `"a"` | `5` |
| `"ap"` | `5` |
| `"app"` | `5` |
| `"appl"` | `3` |
| `"apple"` | `3` |

Then:

```python
sum("ap") == 5
```

can be answered immediately after reaching the `"ap"` node.
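The numbers in the table can be checked without a trie by adding each value to every non-empty prefix of its key in a plain dictionary. This is only a sketch for verifying the table: a `Counter` keyed by prefix strings stands in for the trie nodes, and the helper name `prefix_sums` is hypothetical.

```python
from collections import Counter


def prefix_sums(values):
    # Add each key's value to every non-empty prefix of that key.
    sums = Counter()
    for key, val in values.items():
        for end in range(1, len(key) + 1):
            sums[key[:end]] += val
    return sums


sums = prefix_sums({"apple": 3, "app": 2})
for prefix in ["a", "ap", "app", "appl", "apple"]:
    print(prefix, sums[prefix])
# a 5
# ap 5
# app 5
# appl 3
# apple 3
```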

## Handling Updates

The main detail is key replacement.

Suppose we do:

```python
insert("apple", 3)
insert("apple", 2)
```

The value of `"apple"` changes from `3` to `2`.

The prefix sums should decrease by `1`, not increase by `2`.

So, alongside the trie, we keep a hash map named `values` from each key to its current value:

```python
values = {}  # key -> current value of that key
```

When inserting a key, compute:

```python
delta = new_value - old_value
```

If the key is new, the old value is `0`.

Then add `delta` to every trie node on the path of that key.

Example:

```python
insert("apple", 3)
```

Old value is `0`, so:

```text
delta = 3 - 0 = 3
```

Every prefix node for `"apple"` gets `+3`.

Now update:

```python
insert("apple", 2)
```

Old value is `3`, so:

```text
delta = 2 - 3 = -1
```

Every prefix node for `"apple"` gets `-1`.

This keeps all prefix sums correct.
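Here is a sketch of the delta rule, again with a dictionary keyed by prefix strings standing in for the trie nodes. The names `current`, `prefix_sum`, and `insert` are illustrative; `insert` returns the delta so the trace is visible.

```python
from collections import defaultdict

current = {}                   # key -> current value
prefix_sum = defaultdict(int)  # prefix -> sum of values of keys with that prefix


def insert(key, val):
    # A missing key is treated as having old value 0.
    delta = val - current.get(key, 0)
    current[key] = val
    # Apply the same delta to every non-empty prefix of the key.
    for end in range(1, len(key) + 1):
        prefix_sum[key[:end]] += delta
    return delta


print(insert("apple", 3))  # 3
print(insert("apple", 2))  # -1
print(prefix_sum["ap"])    # 2
```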

## Algorithm

Store two things:

| Structure | Purpose |
|---|---|
| `root` | Root of the trie |
| `values` | Hash map from key to its current value |

Each trie node stores:

| Field | Meaning |
|---|---|
| `children` | Map from character to child node |
| `score` | Sum of the current values of all keys that start with this prefix |

For `insert(key, val)`:

1. Get the previous value of `key`, or `0` if it is new.
2. Compute `delta = val - previous`.
3. Store the new value in `values`.
4. Walk through the trie using the characters of `key`.
5. Create missing nodes as needed.
6. Add `delta` to every node on the path below the root.

For `sum(prefix)`:

1. Start at the root.
2. Walk through the trie using the characters of `prefix`.
3. If a character is missing, return `0`.
4. After consuming the prefix, return the current node's `score`.

## Correctness

Each trie node corresponds to one prefix.

When we insert or update a key, the node for each of its prefixes lies on the path that starts at the root and follows the key's characters.

The algorithm adds `delta` to exactly those prefix nodes.

If the key is new, `delta` is the inserted value, so every prefix of that key receives that value.

If the key already exists, `delta` is the difference between the new value and the old value. Adding this difference changes every affected prefix sum from the old contribution to the new contribution.

No unrelated prefix is changed because the algorithm only walks along the characters of the inserted key.

Therefore, after every insertion, each trie node's `score` equals the sum of all values whose keys have that node's prefix.

For `sum(prefix)`, the algorithm walks to the trie node representing `prefix`.

If the node does not exist, no stored key has that prefix, so the correct answer is `0`.

If the node exists, its `score` is exactly the sum of all values whose keys start with `prefix`.

So `sum(prefix)` returns the correct value.
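The argument can also be spot-checked empirically. The sketch below compares the trie against the brute-force scan on random operations. The alphabet `"ab"`, the word lengths, the seed, and the number of steps are arbitrary assumptions, chosen so that prefixes collide and keys are updated often.

```python
import random


class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


class MapSum:
    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


def random_word(rng):
    # Short words over a tiny alphabet so prefixes are shared often.
    return "".join(rng.choice("ab") for _ in range(rng.randint(1, 4)))


def cross_check(seed=0, steps=2000):
    rng = random.Random(seed)
    trie = MapSum()
    reference = {}  # brute-force model: key -> current value
    for _ in range(steps):
        if rng.random() < 0.5:
            key, val = random_word(rng), rng.randint(1, 1000)
            trie.insert(key, val)
            reference[key] = val
        else:
            prefix = random_word(rng)
            expected = sum(v for k, v in reference.items() if k.startswith(prefix))
            assert trie.sum(prefix) == expected, (prefix, expected)
    return True


print(cross_check())  # True
```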

## Complexity

Let:

| Symbol | Meaning |
|---|---|
| `k` | Length of the inserted key |
| `p` | Length of the searched prefix |
| `T` | Total number of characters inserted across all distinct keys |

| Operation | Time | Space |
|---|---:|---:|
| `MapSum()` | `O(1)` | `O(1)` |
| `insert(key, val)` | `O(k)` | `O(k)` in the worst case for new trie nodes |
| `sum(prefix)` | `O(p)` | `O(1)` |

Total trie space is:

```text
O(T)
```

The hash map `values` stores one entry per distinct key; its key strings also take `O(T)` space in total.
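Because keys that share a prefix also share nodes, `O(T)` is an upper bound rather than an exact node count. A small sketch that counts how many nodes each insertion creates; the helper `count_new_nodes` is hypothetical and exists only for this illustration:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


def count_new_nodes(root, key):
    # Walk the key, creating missing nodes, and report how many were created.
    created = 0
    node = root
    for ch in key:
        if ch not in node.children:
            node.children[ch] = TrieNode()
            created += 1
        node = node.children[ch]
    return created


root = TrieNode()
keys = ["apple", "app", "ape"]
created = [count_new_nodes(root, key) for key in keys]
print(created)                            # [5, 0, 1]
print(sum(created), sum(map(len, keys)))  # 6 11
```

Here `"app"` reuses the existing path completely, and `"ape"` only adds its final `e`, so 6 nodes are created for 11 inserted characters.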

## Implementation

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0

class MapSum:

    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        old_val = self.values.get(key, 0)
        delta = val - old_val
        self.values[key] = val

        node = self.root

        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()

            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root

        for ch in prefix:
            if ch not in node.children:
                return 0

            node = node.children[ch]

        return node.score
```

## Code Explanation

The trie node is small:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0
```

`children` stores the outgoing edges from this node.

`score` stores the sum of all values for keys that share this prefix.

The `MapSum` object keeps the root trie node and the current value of each key:

```python
self.root = TrieNode()
self.values = {}
```

The `values` hash map is required because `insert` replaces old values: to adjust the prefix sums correctly, we must know the value being replaced.
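To see what goes wrong without it, consider a sketch of a trie that skips the `values` map and simply adds `val` on every insert. The class name `NaiveMapSum` is hypothetical; it is deliberately incorrect and accumulates instead of replacing:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


class NaiveMapSum:
    # Deliberately wrong: no record of old values, so updates accumulate.
    def __init__(self):
        self.root = TrieNode()

    def insert(self, key: str, val: int) -> None:
        node = self.root
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += val  # adds val instead of (val - old value)

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


naive = NaiveMapSum()
naive.insert("apple", 3)
naive.insert("apple", 2)
print(naive.sum("ap"))  # 5, but the correct answer is 2
```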

In `insert`, we compute the difference:

```python
old_val = self.values.get(key, 0)
delta = val - old_val
```

Then we update the stored value:

```python
self.values[key] = val
```

After that, we walk through the trie:

```python
for ch in key:
```

If the next node does not exist, we create it:

```python
if ch not in node.children:
    node.children[ch] = TrieNode()
```

Then we move to the child and update its prefix sum:

```python
node = node.children[ch]
node.score += delta
```

In `sum`, we walk through the prefix.

If any character is missing:

```python
return 0
```

because no key starts with that prefix.

If the full prefix exists, we return:

```python
return node.score
```
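Because `insert` adds `delta` only after stepping into a child, `root.score` stays `0`, so `sum("")` returns `0`. The constraints guarantee `prefix.length >= 1`, so this never matters for the problem as stated. If an empty prefix should instead return the total of all values (an assumption beyond the problem statement), a sketch of the change is to add `delta` to the root as well; the class name `MapSumWithRoot` is hypothetical:

```python
class TrieNode:
    def __init__(self):
        self.children = {}
        self.score = 0


class MapSumWithRoot:
    # Hypothetical variant: the root also tracks the total, so sum("") works.
    def __init__(self):
        self.root = TrieNode()
        self.values = {}

    def insert(self, key: str, val: int) -> None:
        delta = val - self.values.get(key, 0)
        self.values[key] = val
        node = self.root
        node.score += delta  # the root represents the empty prefix
        for ch in key:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.score += delta

    def sum(self, prefix: str) -> int:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.score


m = MapSumWithRoot()
m.insert("apple", 3)
m.insert("bat", 4)
print(m.sum(""), m.sum("a"))  # 7 3
```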

## Testing

```python
def run_tests():
    map_sum = MapSum()

    map_sum.insert("apple", 3)
    assert map_sum.sum("ap") == 3

    map_sum.insert("app", 2)
    assert map_sum.sum("ap") == 5
    assert map_sum.sum("app") == 5
    assert map_sum.sum("apple") == 3
    assert map_sum.sum("b") == 0

    map_sum.insert("apple", 2)
    assert map_sum.sum("ap") == 4
    assert map_sum.sum("app") == 4
    assert map_sum.sum("apple") == 2

    map_sum.insert("ape", 4)
    assert map_sum.sum("ap") == 8
    assert map_sum.sum("ape") == 4

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Expected | Why |
|---|---:|---|
| Insert `"apple" = 3`, sum `"ap"` | `3` | One matching key |
| Insert `"app" = 2`, sum `"ap"` | `5` | Two matching keys |
| Sum `"app"` | `5` | `"app"` and `"apple"` both match |
| Sum `"apple"` | `3` | Only `"apple"` matches |
| Sum `"b"` | `0` | No key has that prefix |
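| Update `"apple"` from `3` to `2`, then sum `"ap"`, `"app"`, `"apple"` | `4`, `4`, `2` | Insert replaces the old value, so each affected prefix sum drops by `1` |
| Insert `"ape" = 4`, sum `"ap"` | `8` | `"apple"`, `"app"`, and `"ape"` all share the prefix: `2 + 2 + 4` |
| Sum `"ape"` | `4` | Only `"ape"` matches |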

