# LeetCode 563: Binary Tree Tilt

## Problem Restatement

We are given the root of a binary tree.

For each node, its tilt is the absolute difference between:

| Side | Meaning |
|---|---|
| Left subtree sum | Sum of all values in the left subtree |
| Right subtree sum | Sum of all values in the right subtree |

If a node has no left child, its left subtree sum is `0`.

If a node has no right child, its right subtree sum is `0`.

Return the sum of the tilt of every node in the tree.

The number of nodes is in the range `[0, 10^4]`, and each node value is between `-1000` and `1000`.

## Input and Output

| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | Sum of all node tilts |
| Empty tree | Return `0` |
| Traversal needed | Bottom-up DFS |

Example function shape:

```python
def findTilt(root: Optional[TreeNode]) -> int:
    ...
```
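The examples that follow describe trees in LeetCode's level-order list notation, where `None` marks a missing child. For local experiments, a minimal sketch of turning such a list into linked `TreeNode` objects might look like this. `build_tree` is a hypothetical helper, not part of the required `findTilt` interface.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from level-order values (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next value (if any) is the left child of `parent`.
        if i < len(values) and values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        # The value after that (if any) is the right child of `parent`.
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


root = build_tree([4, 2, 9, 3, 5, None, 7])
print(root.right.right.val)  # 7
```

Missing children are simply skipped, so they never enter the queue and never receive children of their own.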

## Examples

Example 1:

```python
root = [1, 2, 3]
```

The tree is:

```text
    1
   / \
  2   3
```

For node `2`:

```python
left subtree sum = 0
right subtree sum = 0
tilt = abs(0 - 0) = 0
```

For node `3`:

```python
left subtree sum = 0
right subtree sum = 0
tilt = abs(0 - 0) = 0
```

For node `1`:

```python
left subtree sum = 2
right subtree sum = 3
tilt = abs(2 - 3) = 1
```

So the answer is:

```python
1
```

Example 2:

```python
root = [4, 2, 9, 3, 5, None, 7]
```

The tree is:

```text
      4
     / \
    2   9
   / \   \
  3   5   7
```

Node tilts:

| Node | Left sum | Right sum | Tilt |
|---|---:|---:|---:|
| `3` | `0` | `0` | `0` |
| `5` | `0` | `0` | `0` |
| `7` | `0` | `0` | `0` |
| `2` | `3` | `5` | `2` |
| `9` | `0` | `7` | `7` |
| `4` | `10` | `16` | `6` |

The total tilt is:

```python
0 + 0 + 0 + 2 + 7 + 6 = 15
```

So the answer is:

```python
15
```
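To double-check the table, a small instrumented DFS can record each node's sums at the moment the node finishes. In this sketch (`tilt_table` is an illustrative name), rows come out in postorder, so the order is `3, 5, 2, 7, 9, 4` rather than the table's order, but every number matches.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tilt_table(root: Optional[TreeNode]) -> List[Tuple[int, int, int, int]]:
    """Return (node value, left sum, right sum, tilt) rows in postorder."""
    rows: List[Tuple[int, int, int, int]] = []

    def subtree_sum(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        left_sum = subtree_sum(node.left)
        right_sum = subtree_sum(node.right)
        # Both child sums are known here, so this node's row is complete.
        rows.append((node.val, left_sum, right_sum, abs(left_sum - right_sum)))
        return left_sum + right_sum + node.val

    subtree_sum(root)
    return rows


root = TreeNode(4, TreeNode(2, TreeNode(3), TreeNode(5)), TreeNode(9, None, TreeNode(7)))
for val, left_sum, right_sum, tilt in tilt_table(root):
    print(val, left_sum, right_sum, tilt)
# 3 0 0 0
# 5 0 0 0
# 2 3 5 2
# 7 0 0 0
# 9 0 7 7
# 4 10 16 6
```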

## First Thought: Compute Each Subtree Sum Separately

A direct idea is:

1. Visit every node.
2. For each node, compute the sum of its left subtree.
3. Compute the sum of its right subtree.
4. Add the absolute difference to the answer.

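This works, but it recomputes the same subtree sums many times.

Each node's value is re-added once for every ancestor whose subtree sum is computed, so the total work is `O(n * h)`, where `h` is the tree height. For a skewed tree that becomes `O(n^2)`.

We want to compute each subtree sum only once.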
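For reference, the direct idea can be sketched as follows (function names are illustrative). It is correct, but `subtree_sum` walks an entire subtree every time it is called, which is where the repeated work comes from.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sum(node: Optional[TreeNode]) -> int:
    # Walks the whole subtree on every call: O(size of subtree).
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def find_tilt_naive(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    # Tilt of this node, recomputing both child subtree sums from scratch.
    own_tilt = abs(subtree_sum(root.left) - subtree_sum(root.right))
    # Repeat for every node below; deep nodes get summed once per ancestor.
    return own_tilt + find_tilt_naive(root.left) + find_tilt_naive(root.right)


root = TreeNode(4, TreeNode(2, TreeNode(3), TreeNode(5)), TreeNode(9, None, TreeNode(7)))
print(find_tilt_naive(root))  # 15
```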

## Key Insight

The tilt of a node depends on information from its children.

Specifically, before we can compute the tilt of a node, we need:

```python
left_subtree_sum
right_subtree_sum
```

That means this is a bottom-up problem.

Use postorder DFS:

1. Compute the left subtree sum.
2. Compute the right subtree sum.
3. Compute the current node tilt.
4. Return the total sum of the current subtree to the parent.

Each recursive call should do two jobs:

| Job | Purpose |
|---|---|
| Return subtree sum | Needed by the parent |
| Add node tilt to answer | Needed for final result |

## Algorithm

Use a helper function `dfs(node)`.

The helper returns the sum of all node values in the subtree rooted at `node`.

1. If `node` is `None`, return `0`.
2. Recursively compute `left_sum = dfs(node.left)`.
3. Recursively compute `right_sum = dfs(node.right)`.
4. Add `abs(left_sum - right_sum)` to the answer.
5. Return `left_sum + right_sum + node.val`.

After DFS finishes, return the accumulated answer.

## Correctness

For a `None` node, the subtree sum is `0`, which matches the problem rule for missing children.

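For a real node, argue by induction on subtree height, with the `None` case above as the base case. By the induction hypothesis, the recursive call on the left child returns the sum of all values in the left subtree, and the recursive call on the right child returns the sum of all values in the right subtree.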

Therefore, the algorithm computes the node tilt as:

```python
abs(left_sum - right_sum)
```

This is exactly the definition of that node's tilt.

Then the helper returns:

```python
left_sum + right_sum + node.val
```

which is exactly the sum of the subtree rooted at the current node.

Because DFS visits every node once and adds that node's tilt exactly once, the accumulated answer is the sum of every tree node's tilt.
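One way to back the argument up empirically is to compare the one-pass DFS against the definition applied directly at every node, on many random trees. This is a sketch: `random_tree` is a hypothetical generator that picks a random shape and values in the problem's range `[-1000, 1000]`, and the seed `563` is arbitrary.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_tilt(root: Optional[TreeNode]) -> int:
    # One-pass postorder DFS, as in the algorithm above.
    total = 0

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal total
        if node is None:
            return 0
        left_sum = dfs(node.left)
        right_sum = dfs(node.right)
        total += abs(left_sum - right_sum)
        return left_sum + right_sum + node.val

    dfs(root)
    return total


def subtree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def tilt_by_definition(node: Optional[TreeNode]) -> int:
    # Apply the definition literally at every node (slow but obviously faithful).
    if node is None:
        return 0
    return (abs(subtree_sum(node.left) - subtree_sum(node.right))
            + tilt_by_definition(node.left) + tilt_by_definition(node.right))


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    # Hypothetical generator: split the remaining nodes randomly between children.
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(rng.randint(-1000, 1000),
                    random_tree(left_size, rng),
                    random_tree(size - 1 - left_size, rng))


rng = random.Random(563)
for _ in range(500):
    root = random_tree(rng.randint(0, 30), rng)
    assert find_tilt(root) == tilt_by_definition(root)
print("one-pass DFS matches the definition")
```

A passing run is evidence, not proof; the inductive argument above is what guarantees correctness for every tree.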

## Complexity

Let `n` be the number of nodes in the tree.

| Metric | Value | Why |
|---|---|---|
| Time | `O(n)` | Each node is visited once |
| Space | `O(h)` | The recursion stack stores one call per tree level |

Here `h` is the height of the tree.

In the worst case, a skewed tree has `h = n`, so the extra space is `O(n)`. A balanced tree has `h = O(log n)`.
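A practical consequence of `h = n`: CPython's default recursion limit is 1000 (see `sys.getrecursionlimit()`), while the problem allows up to `10^4` nodes, so a skewed input can raise `RecursionError` in plain Python unless the limit is raised. Whether a particular judge raises it is an assumption this article does not settle. A sketch that sidesteps the issue replaces the call stack with an explicit stack and visits each node twice: once to schedule its children, once to combine their sums. `find_tilt_iterative` is an illustrative name.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_tilt_iterative(root: Optional[TreeNode]) -> int:
    total_tilt = 0
    # Subtree sums of finished nodes whose parent has not been finished yet.
    sums: Dict[TreeNode, int] = {}
    # Each entry is (node, children_done); True means both children are finished.
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit this node after both children are finished (postorder).
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
        else:
            # Missing children are never stored in `sums`, so they default to 0.
            left_sum = sums.pop(node.left, 0)
            right_sum = sums.pop(node.right, 0)
            total_tilt += abs(left_sum - right_sum)
            sums[node] = left_sum + right_sum + node.val
    return total_tilt


# A left-skewed chain of 10,000 nodes, deeper than CPython's default recursion limit.
root = None
for _ in range(10_000):
    root = TreeNode(1, root)  # the previous chain becomes the new node's left child
print(find_tilt_iterative(root))  # 49995000
```

Popping child sums as soon as they are used keeps the dictionary small; the time stays `O(n)` and the explicit stack still holds `O(h)` entries.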

## Implementation

```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val: int = 0, left: Optional['TreeNode'] = None, right: Optional['TreeNode'] = None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findTilt(self, root: Optional['TreeNode']) -> int:
        total_tilt = 0

        def dfs(node: Optional['TreeNode']) -> int:
            nonlocal total_tilt

            if node is None:
                return 0

            left_sum = dfs(node.left)
            right_sum = dfs(node.right)

            total_tilt += abs(left_sum - right_sum)

            return left_sum + right_sum + node.val

        dfs(root)
        return total_tilt
```

## Code Explanation

We keep the answer outside the helper:

```python
total_tilt = 0
```

The helper returns subtree sums:

```python
def dfs(node):
```

The base case handles missing children:

```python
if node is None:
    return 0
```

Then we compute both subtree sums:

```python
left_sum = dfs(node.left)
right_sum = dfs(node.right)
```

Now we can compute the tilt of the current node:

```python
total_tilt += abs(left_sum - right_sum)
```

Finally, we return the current subtree sum:

```python
return left_sum + right_sum + node.val
```

The outer function calls DFS and returns the accumulated total.

## Testing

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    assert s.findTilt(None) == 0

    root = TreeNode(1, TreeNode(2), TreeNode(3))
    assert s.findTilt(root) == 1

    root = TreeNode(
        4,
        TreeNode(2, TreeNode(3), TreeNode(5)),
        TreeNode(9, None, TreeNode(7)),
    )
    assert s.findTilt(root) == 15

    root = TreeNode(1)
    assert s.findTilt(root) == 0

    root = TreeNode(1, TreeNode(2, TreeNode(3)), None)
    # Tilts: node 3 -> 0, node 2 -> abs(3 - 0) = 3, node 1 -> abs(5 - 0) = 5.
    assert s.findTilt(root) == 8

    print("all tests passed")

run_tests()
```

| Test | Why |
|---|---|
| `None` | Empty tree |
| `[1, 2, 3]` | Basic sample |
| `[4, 2, 9, 3, 5, None, 7]` | Larger sample |
| Single node | Leaf tilt is zero |
| Left-skewed tree | Checks missing right subtree sums |

