A clear explanation of Binary Tree Tilt using postorder DFS to compute subtree sums and accumulate tilt.
Problem Restatement
We are given the root of a binary tree.
For each node, its tilt is the absolute difference between:
| Side | Meaning |
|---|---|
| Left subtree sum | Sum of all values in the left subtree |
| Right subtree sum | Sum of all values in the right subtree |
If a node has no left child, its left subtree sum is 0.
If a node has no right child, its right subtree sum is 0.
Return the sum of the tilt of every node in the tree.
The number of nodes is in the range [0, 10^4], and each node value is between -1000 and 1000.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | Sum of all node tilts |
| Empty tree | Return 0 |
| Traversal needed | Bottom-up DFS |
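The examples below write trees as LeetCode-style level-order lists, where `None` marks a missing child. Here is a minimal sketch of a hypothetical `build_tree` helper that turns such a list into linked nodes. It is not part of the required interface, and it assumes the usual LeetCode serialization, where children are filled left to right from a queue.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if any) is the left child of the current node.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that is the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([4, 2, 9, 3, 5, None, 7])
```

Only non-`None` nodes enter the queue, so a `None` entry consumes a slot without creating children of its own.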
Example function shape:
```python
def findTilt(root: Optional[TreeNode]) -> int:
    ...
```
Examples
Example 1:
```
root = [1, 2, 3]
```
The tree is:
```
  1
 / \
2   3
```
For node 2:
```
left subtree sum = 0
right subtree sum = 0
tilt = abs(0 - 0) = 0
```
For node 3:
```
left subtree sum = 0
right subtree sum = 0
tilt = abs(0 - 0) = 0
```
For node 1:
```
left subtree sum = 2
right subtree sum = 3
tilt = abs(2 - 3) = 1
```
So the answer is:
```
1
```
Example 2:
```
root = [4, 2, 9, 3, 5, None, 7]
```
The tree is:
```
    4
   / \
  2   9
 / \   \
3   5   7
```
Node tilts:
| Node | Left sum | Right sum | Tilt |
|---|---|---|---|
| 3 | 0 | 0 | 0 |
| 5 | 0 | 0 | 0 |
| 7 | 0 | 0 | 0 |
| 2 | 3 | 5 | 2 |
| 9 | 0 | 7 | 7 |
| 4 | 10 | 16 | 6 |
The total tilt is:
```
0 + 0 + 0 + 2 + 7 + 6 = 15
```
So the answer is:
```
15
```
First Thought: Compute Each Subtree Sum Separately
A direct idea is:
- Visit every node.
- For each node, compute the sum of its left subtree.
- Compute the sum of its right subtree.
- Add the absolute difference to the answer.
This works, but it recomputes the same subtree sums many times.
The sum of a deep subtree is recomputed once for every one of its ancestors, so on a skewed tree of n nodes this approach takes O(n²) time.
We need to compute each subtree sum once.
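The direct idea can be sketched as follows. The names `subtree_sum` and `find_tilt_naive` are hypothetical, and `TreeNode` is redefined so the block runs on its own. Every call to `subtree_sum` walks the whole subtree from scratch, which is where the repeated work comes from.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sum(node: Optional[TreeNode]) -> int:
    # Recomputed from scratch on every call: visits the entire subtree.
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def find_tilt_naive(root: Optional[TreeNode]) -> int:
    # Visit every node and recompute both child subtree sums for it.
    if root is None:
        return 0
    tilt = abs(subtree_sum(root.left) - subtree_sum(root.right))
    return tilt + find_tilt_naive(root.left) + find_tilt_naive(root.right)


example_2 = TreeNode(4, TreeNode(2, TreeNode(3), TreeNode(5)), TreeNode(9, None, TreeNode(7)))
print(find_tilt_naive(example_2))  # 15
```

The result is correct, but the sum of node 3's subtree, for example, is computed once while handling node 2 and again while handling node 4.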
Key Insight
The tilt of a node depends on information from its children.
Specifically, before we can compute the tilt of a node, we need:
```
left_subtree_sum
right_subtree_sum
```
That means this is a bottom-up problem.
Use postorder DFS:
- Compute the left subtree sum.
- Compute the right subtree sum.
- Compute the current node tilt.
- Return the total sum of the current subtree to the parent.
Each recursive call should do two jobs:
| Job | Purpose |
|---|---|
| Return subtree sum | Needed by the parent |
| Add node tilt to answer | Needed for final result |
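The two jobs can also travel together in the return value instead of going through an outer accumulator. Below is a minimal sketch of that variant. The function name `find_tilt_pair` is hypothetical, and `TreeNode` is redefined so the block is self-contained. Each call returns a `(subtree_sum, subtree_tilt)` pair, where the second item is the total tilt of all nodes inside that subtree.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_tilt_pair(root: Optional[TreeNode]) -> int:
    def dfs(node: Optional[TreeNode]) -> Tuple[int, int]:
        # Returns (sum of values in the subtree, total tilt of nodes in the subtree).
        if node is None:
            return 0, 0
        left_sum, left_tilt = dfs(node.left)
        right_sum, right_tilt = dfs(node.right)
        node_tilt = abs(left_sum - right_sum)
        return (left_sum + right_sum + node.val,
                left_tilt + right_tilt + node_tilt)

    # Only the accumulated tilt is needed for the answer.
    return dfs(root)[1]
```

This form avoids mutable shared state, at the cost of building a small tuple per node.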
Algorithm
Use a helper function dfs(node).
The helper returns the sum of all node values in the subtree rooted at node.
- If `node` is `None`, return `0`.
- Recursively compute `left_sum = dfs(node.left)`.
- Recursively compute `right_sum = dfs(node.right)`.
- Add `abs(left_sum - right_sum)` to the answer.
- Return `left_sum + right_sum + node.val`.
After DFS finishes, return the accumulated answer.
Correctness
For a None node, the subtree sum is 0, which matches the problem rule for missing children.
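For a real node, assume (by induction on subtree height) that the helper returns correct sums for smaller subtrees. Then the recursive call on the left child returns the sum of all values in the left subtree, and the recursive call on the right child returns the sum of all values in the right subtree.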
Therefore, the algorithm computes the node tilt as:
```
abs(left_sum - right_sum)
```
This is exactly the definition of that node’s tilt.
Then the helper returns:
```
left_sum + right_sum + node.val
```
which is exactly the sum of the subtree rooted at the current node.
Because DFS visits every node once and adds that node’s tilt exactly once, the accumulated answer is the sum of every tree node’s tilt.
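As an empirical complement to this argument, the postorder version can be compared against a brute force that follows the definition literally, using random trees. The following is a minimal sketch: the names `tilt_by_definition`, `tilt_postorder`, and `random_tree` are hypothetical, and node values are drawn from the problem's range [-1000, 1000].

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tilt_by_definition(root: Optional[TreeNode]) -> int:
    # Brute force straight from the definition: recompute subtree sums per node.
    def total(node):
        return 0 if node is None else node.val + total(node.left) + total(node.right)

    def walk(node):
        if node is None:
            return 0
        return abs(total(node.left) - total(node.right)) + walk(node.left) + walk(node.right)

    return walk(root)


def tilt_postorder(root: Optional[TreeNode]) -> int:
    # One postorder pass: children's sums are known before the parent's tilt.
    total_tilt = 0

    def dfs(node):
        nonlocal total_tilt
        if node is None:
            return 0
        left_sum = dfs(node.left)
        right_sum = dfs(node.right)
        total_tilt += abs(left_sum - right_sum)
        return left_sum + right_sum + node.val

    dfs(root)
    return total_tilt


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    # Split the remaining size - 1 nodes at random between the two children.
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(rng.randint(-1000, 1000),
                    random_tree(left_size, rng),
                    random_tree(size - 1 - left_size, rng))


rng = random.Random(0)
for _ in range(200):
    tree = random_tree(rng.randint(0, 30), rng)
    assert tilt_by_definition(tree) == tilt_postorder(tree)
print("postorder matches brute force")
```

A passing run does not replace the proof, but it catches slips such as forgetting `node.val` in the returned sum.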
Complexity
Let n be the number of nodes in the tree.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack stores one call per tree level |
Here h is the height of the tree.
In the worst case, a skewed tree has h = n.
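In CPython, the default recursion limit is typically 1000 (see `sys.getrecursionlimit()`). A skewed tree with up to 10^4 nodes can therefore overflow the recursive version unless the environment raises that limit. Below is a minimal sketch of an iterative postorder variant with an explicit stack. The name `find_tilt_iterative` is hypothetical, and `TreeNode` is redefined so the block runs on its own. Each node is pushed twice: once to schedule its children, and once to combine their sums after both are finished.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_tilt_iterative(root: Optional[TreeNode]) -> int:
    total_tilt = 0
    # Subtree sums of finished nodes whose parent has not been finished yet.
    sums: Dict[Optional[TreeNode], int] = {}
    # (node, children_done) pairs; children_done=True means both children are finished.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children have been processed.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # A missing child was never stored, so pop() falls back to 0.
            left_sum = sums.pop(node.left, 0)
            right_sum = sums.pop(node.right, 0)
            total_tilt += abs(left_sum - right_sum)
            sums[node] = left_sum + right_sum + node.val
    return total_tilt


# A left-skewed chain of 10,000 nodes, deeper than the default recursion limit.
chain = None
for _ in range(10_000):
    chain = TreeNode(1, chain, None)
print(find_tilt_iterative(chain))  # 49995000
```

Popping each child's sum as soon as its parent finishes keeps the dictionary, like the stack, proportional to the tree height rather than to n.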
Implementation
```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val: int = 0, left: Optional['TreeNode'] = None, right: Optional['TreeNode'] = None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findTilt(self, root: Optional['TreeNode']) -> int:
        total_tilt = 0

        def dfs(node: Optional['TreeNode']) -> int:
            # Returns the sum of the subtree rooted at node and
            # adds node's tilt to total_tilt as a side effect.
            nonlocal total_tilt
            if node is None:
                return 0
            left_sum = dfs(node.left)
            right_sum = dfs(node.right)
            total_tilt += abs(left_sum - right_sum)
            return left_sum + right_sum + node.val

        dfs(root)
        return total_tilt
```
Code Explanation
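We keep the answer outside the helper:
```python
total_tilt = 0
```
Because `dfs` assigns to `total_tilt`, it declares `nonlocal total_tilt`. Without that declaration, Python would treat `total_tilt` as a new local variable inside `dfs`, and `+=` would raise `UnboundLocalError`.
The helper returns subtree sums:
```python
def dfs(node):
```
The base case handles missing children:
```python
if node is None:
    return 0
```
Then we compute both subtree sums:
```python
left_sum = dfs(node.left)
right_sum = dfs(node.right)
```
Now we can compute the tilt of the current node:
```python
total_tilt += abs(left_sum - right_sum)
```
Finally, we return the current subtree sum:
```python
return left_sum + right_sum + node.val
```
The outer function calls DFS and returns the accumulated total.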
Testing
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    assert s.findTilt(None) == 0

    root = TreeNode(1, TreeNode(2), TreeNode(3))
    assert s.findTilt(root) == 1

    root = TreeNode(
        4,
        TreeNode(2, TreeNode(3), TreeNode(5)),
        TreeNode(9, None, TreeNode(7)),
    )
    assert s.findTilt(root) == 15

    root = TreeNode(1)
    assert s.findTilt(root) == 0

    # Left-skewed chain 1 -> 2 -> 3: tilts are 0 (node 3), 3 (node 2), 5 (node 1).
    root = TreeNode(1, TreeNode(2, TreeNode(3)), None)
    assert s.findTilt(root) == 8

    print("all tests passed")

run_tests()
```

| Test | Why |
|---|---|
| None | Empty tree |
| [1, 2, 3] | Basic sample |
| [4, 2, 9, 3, 5, None, 7] | Larger sample |
| Single node | Leaf tilt is zero |
| Left-skewed tree | Checks missing right subtree sums |