A clear explanation of finding the most frequent subtree sum in a binary tree using postorder DFS and a frequency map.
Problem Restatement
We are given the root of a binary tree.
For every node, define its subtree sum as the sum of all values in the subtree rooted at that node. This includes the node itself and all of its descendants.
We need to return the subtree sum value or values that appear most frequently.
If several subtree sums share the highest frequency, return all of them in any order.
Input and Output
| Item | Meaning |
|---|---|
| Input | The root of a binary tree |
| Output | A list of integer subtree sums |
| Goal | Return every subtree sum with maximum frequency |
| Order | Any order is accepted |
Function shape:

    class Solution:
        def findFrequentTreeSum(self, root: Optional[TreeNode]) -> List[int]:
            ...

Examples
Example 1:

        5
       / \
      2  -3

The subtree sums are:
| Node | Subtree sum |
|---|---|
| 2 | 2 |
| -3 | -3 |
| 5 | 5 + 2 + (-3) = 4 |
Each sum appears once.
So the answer can be:

    [2, -3, 4]

Example 2:

        5
       / \
      2  -5

The subtree sums are:
| Node | Subtree sum |
|---|---|
| 2 | 2 |
| -5 | -5 |
| 5 | 5 + 2 + (-5) = 2 |
The sum 2 appears twice.
So the answer is:

    [2]

First Thought: Compute Every Subtree Sum
We need the subtree sum for every node.
For a node, the formula is:

    subtree_sum = node.val + left_subtree_sum + right_subtree_sum

This means we need both child sums before we can compute the parent sum.
So the natural traversal is postorder DFS:
- Visit the left subtree.
- Visit the right subtree.
- Process the current node.
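A minimal sketch of that ordering, offered as an illustration rather than the final solution. It assumes a small stand-in `Node` class (hypothetical, used here only so the snippet runs on its own) and a hypothetical helper `postorder_sums` that records each subtree sum at the moment the node is processed:

```python
from typing import List, Optional


class Node:
    """Hypothetical stand-in for TreeNode so the sketch is self-contained."""

    def __init__(self, val: int = 0,
                 left: Optional["Node"] = None,
                 right: Optional["Node"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def postorder_sums(root: Optional[Node]) -> List[int]:
    """Return subtree sums in the order postorder DFS produces them."""
    sums: List[int] = []

    def dfs(node: Optional[Node]) -> int:
        if node is None:
            return 0
        left_sum = dfs(node.left)    # 1. visit the left subtree
        right_sum = dfs(node.right)  # 2. visit the right subtree
        total = node.val + left_sum + right_sum
        sums.append(total)           # 3. process the current node
        return total

    dfs(root)
    return sums


example = Node(5, Node(2), Node(-3))
print(postorder_sums(example))  # [2, -3, 4]
```

The root's sum is appended last because both children must return before the parent can be processed.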
Key Insight
DFS can return information upward.
Each recursive call dfs(node) returns the sum of the subtree rooted at node.
Then the parent can use it directly.
At the same time, we count how often each subtree sum appears:

    count[subtree_sum] += 1

After visiting the whole tree, we scan the frequency map and return every sum whose frequency equals the maximum frequency.
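The counting and selection step can be looked at in isolation. The sketch below assumes the subtree sums have already been collected into a plain list; `most_frequent` is a hypothetical helper name, and `collections.Counter` is used here as a convenient alternative to a hand-maintained map:

```python
from collections import Counter
from typing import List


def most_frequent(values: List[int]) -> List[int]:
    """Return every value whose frequency equals the maximum frequency."""
    if not values:
        return []  # nothing to count, so there is no most frequent value
    count = Counter(values)
    max_freq = max(count.values())
    return [value for value, freq in count.items() if freq == max_freq]


print(most_frequent([2, -5, 2]))  # [2]
print(most_frequent([2, -3, 4]))  # [2, -3, 4]
```

The two calls mirror the examples: in the first, the sum 2 occurs twice; in the second, every sum occurs once, so all of them tie.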
Algorithm
Create a frequency map:

    count = defaultdict(int)

Define a DFS function.
For each node:
- If the node is None, return 0.
- Recursively compute the left subtree sum.
- Recursively compute the right subtree sum.
- Compute the current subtree sum.
- Increment its frequency.
- Return the current subtree sum.
After DFS finishes:
- Find the maximum frequency.
- Return all sums with that frequency.
Correctness
For a null child, the subtree sum is 0, which contributes nothing to its parent.
For a real node, the algorithm first computes the exact sum of the left subtree and the exact sum of the right subtree by recursion.
It then adds those two sums to the current node value. Therefore, the computed value is exactly the subtree sum rooted at that node.
The algorithm records this sum once for every node. Since every node is visited exactly once, the frequency map stores the exact number of times each subtree sum occurs.
Finally, the algorithm returns all sums whose frequency equals the largest frequency in the map. These are exactly the most frequent subtree sums.
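One way to build confidence in this argument is to compare against a deliberately slow reference that recomputes every subtree sum from scratch, straight from the definition. The sketch below is such a reference, not the intended solution; `Node`, `subtree_sum`, `all_nodes`, and `brute_force` are hypothetical names, and the running time is O(n^2) in the worst case:

```python
from collections import Counter
from typing import Iterator, List, Optional


class Node:
    """Hypothetical stand-in for TreeNode so the sketch is self-contained."""

    def __init__(self, val: int = 0,
                 left: Optional["Node"] = None,
                 right: Optional["Node"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def subtree_sum(node: Optional[Node]) -> int:
    """Sum of the node and all its descendants, directly from the definition."""
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def all_nodes(node: Optional[Node]) -> Iterator[Node]:
    """Yield every node in the tree exactly once."""
    if node is None:
        return
    yield node
    yield from all_nodes(node.left)
    yield from all_nodes(node.right)


def brute_force(root: Optional[Node]) -> List[int]:
    """Recompute each subtree sum independently, then pick the most frequent."""
    count = Counter(subtree_sum(node) for node in all_nodes(root))
    if not count:
        return []
    max_freq = max(count.values())
    return sorted(s for s, freq in count.items() if freq == max_freq)


print(brute_force(Node(5, Node(2), Node(-5))))  # [2]
```

If the single-pass algorithm and this reference agree (after sorting) on a range of small trees, that supports the claim that each node's sum is computed and counted exactly once.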
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(n) | The frequency map can store up to n distinct sums |
The recursion stack uses O(h) space, where h is the tree height. In the worst case, h = n.
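On a degenerate, list-like tree that depth can exceed CPython's default recursion limit of 1000 frames. If that is a concern, the same postorder idea can be written with an explicit stack. The following is a sketch under assumptions, not the reference solution: `Node` and `frequent_sums_iterative` are hypothetical names, and finished sums are kept in a dictionary keyed by node object.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple


class Node:
    """Hypothetical stand-in for TreeNode so the sketch is self-contained."""

    def __init__(self, val: int = 0,
                 left: Optional["Node"] = None,
                 right: Optional["Node"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def frequent_sums_iterative(root: Optional[Node]) -> List[int]:
    """Most frequent subtree sums using an explicit stack instead of recursion."""
    if root is None:
        return []

    count: Dict[int, int] = defaultdict(int)
    # Finished subtree sums; an empty child (None) contributes 0.
    sums: Dict[Optional[Node], int] = {None: 0}
    # Each entry is (node, children_done).
    stack: List[Tuple[Optional[Node], bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if children_done:
            # Both children were finished earlier, so their sums are known.
            total = node.val + sums[node.left] + sums[node.right]
            sums[node] = total
            count[total] += 1
        else:
            # Revisit this node after its children have been processed.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))

    max_freq = max(count.values())
    return [s for s, freq in count.items() if freq == max_freq]


print(frequent_sums_iterative(Node(5, Node(2), Node(-5))))  # [2]
```

The trade-off is that the `sums` dictionary holds one entry per node, so auxiliary space is O(n) regardless of tree height, in exchange for not depending on the call stack.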
Implementation

    from collections import defaultdict
    from typing import List, Optional

    # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right

    class Solution:
        def findFrequentTreeSum(self, root: Optional[TreeNode]) -> List[int]:
            # Maps each subtree sum to the number of nodes that produce it.
            count = defaultdict(int)

            def dfs(node: Optional[TreeNode]) -> int:
                # An empty subtree contributes 0 and is not counted.
                if not node:
                    return 0

                # Postorder: both children are summed before the current node.
                left_sum = dfs(node.left)
                right_sum = dfs(node.right)

                subtree_sum = node.val + left_sum + right_sum
                count[subtree_sum] += 1
                return subtree_sum

            dfs(root)

            # An empty tree has no subtree sums; max() would raise on an empty map.
            if not count:
                return []

            max_freq = max(count.values())
            return [
                subtree_sum
                for subtree_sum, freq in count.items()
                if freq == max_freq
            ]

Code Explanation
The frequency map stores how often each subtree sum appears:

    count = defaultdict(int)

The DFS returns an integer:

    def dfs(node: Optional[TreeNode]) -> int:

For an empty child, the sum is 0:

    if not node:
        return 0

Compute child sums first:

    left_sum = dfs(node.left)
    right_sum = dfs(node.right)

Then compute the current subtree sum:

    subtree_sum = node.val + left_sum + right_sum

Record it:

    count[subtree_sum] += 1

Return it to the parent:

    return subtree_sum

After traversal, find the highest frequency:

    max_freq = max(count.values())

Then return every subtree sum with that frequency.
Testing

    def sorted_result(root):
        # Sort so the assertions do not depend on the order of the returned sums.
        return sorted(Solution().findFrequentTreeSum(root))

    def test_find_frequent_tree_sum():
        # Tree:
        #     5
        #    / \
        #   2  -3
        root = TreeNode(5)
        root.left = TreeNode(2)
        root.right = TreeNode(-3)
        assert sorted_result(root) == [-3, 2, 4]

        # Tree:
        #     5
        #    / \
        #   2  -5
        root = TreeNode(5)
        root.left = TreeNode(2)
        root.right = TreeNode(-5)
        assert sorted_result(root) == [2]

        # Single node
        root = TreeNode(7)
        assert sorted_result(root) == [7]

        # Tree:
        #     0
        #    / \
        #   0   0
        root = TreeNode(0)
        root.left = TreeNode(0)
        root.right = TreeNode(0)
        assert sorted_result(root) == [0]

        print("all tests passed")

Test meaning:
| Test | Why |
|---|---|
| [5, 2, -3] tree | All subtree sums tie |
| [5, 2, -5] tree | One sum appears twice |
| Single node | Minimum tree case |
| All zero values | Repeated identical subtree sums |
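The table labels trees with a bracketed list such as [5, 2, -3]. Assuming that notation means level order with None marking a missing child (an assumption; the text does not define it), a small builder can turn such lists into trees and shorten the tests. `Node` and `build_tree` are hypothetical names introduced only for this sketch:

```python
from collections import deque
from typing import List, Optional


class Node:
    """Hypothetical stand-in for TreeNode so the sketch is self-contained."""

    def __init__(self, val: int = 0,
                 left: Optional["Node"] = None,
                 right: Optional["Node"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[Node]:
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None

    root = Node(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # Next value is the left child, if present.
        if values[i] is not None:
            parent.left = Node(values[i])
            queue.append(parent.left)
        i += 1
        # The value after that is the right child, if present.
        if i < len(values) and values[i] is not None:
            parent.right = Node(values[i])
            queue.append(parent.right)
        i += 1
    return root


tree = build_tree([5, 2, -3])
print(tree.val, tree.left.val, tree.right.val)  # 5 2 -3
```

With a builder like this, each test case reduces to one list literal plus the expected sorted result.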