
LeetCode 508: Most Frequent Subtree Sum

A clear explanation of finding the most frequent subtree sum in a binary tree using postorder DFS and a frequency map.

Problem Restatement

We are given the root of a binary tree.

For every node, define its subtree sum as the sum of all values in the subtree rooted at that node. This includes the node itself and all of its descendants.

We need to return the subtree sum, or sums, that appear most frequently.

If several subtree sums tie for the highest frequency, return all of them in any order.

Input and Output

| Item | Meaning |
|---|---|
| Input | The root of a binary tree |
| Output | A list of integer subtree sums |
| Goal | Return every subtree sum with maximum frequency |
| Order | Any order is accepted |

Function shape:

class Solution:
    def findFrequentTreeSum(self, root: Optional[TreeNode]) -> List[int]:
        ...

Examples

Example 1:

    5
   / \
  2  -3

The subtree sums are:

| Node | Subtree sum |
|---|---|
| 2 | 2 |
| -3 | -3 |
| 5 | 5 + 2 + (-3) = 4 |

Each sum appears once.

So the answer can be:

[2, -3, 4]

Example 2:

    5
   / \
  2  -5

The subtree sums are:

| Node | Subtree sum |
|---|---|
| 2 | 2 |
| -5 | -5 |
| 5 | 5 + 2 + (-5) = 2 |

The sum 2 appears twice.

So the answer is:

[2]

First Thought: Compute Every Subtree Sum

We need the subtree sum for every node.

For a node, the formula is:

subtree_sum = node.val + left_subtree_sum + right_subtree_sum

This means the children's sums must be computed before the parent's sum.

So the natural traversal is postorder DFS:

  1. Visit the left subtree.
  2. Visit the right subtree.
  3. Process the current node.
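The sketch below makes the visiting order concrete. It records each node's value and subtree sum in the order postorder finishes them. The `TreeNode` dataclass and the helper name `postorder_sums` are stand-ins assumed for illustration. LeetCode supplies its own `TreeNode` class.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumed for this sketch).
    val: int = 0
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def postorder_sums(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    """Return (node value, subtree sum) pairs in the order postorder finishes them."""
    order: List[Tuple[int, int]] = []

    def visit(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        left_sum = visit(node.left)    # 1. left subtree first
        right_sum = visit(node.right)  # 2. then the right subtree
        total = node.val + left_sum + right_sum  # 3. finally the node itself
        order.append((node.val, total))
        return total

    visit(root)
    return order


root = TreeNode(5, TreeNode(2), TreeNode(-3))  # Example 1
print(postorder_sums(root))  # [(2, 2), (-3, -3), (5, 4)]
```

On Example 1 both leaves finish first. The root's sum of 4 is computed last, once both children's sums are available.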

Key Insight

DFS can return information upward.

Let each recursive call

dfs(node)

return the sum of the subtree rooted at node. The parent can then add that value to its own without recomputing anything.

At the same time, we count how often each subtree sum appears:

count[subtree_sum] += 1

After visiting the whole tree, we scan the frequency map and return every sum whose frequency equals the maximum frequency.
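The final scan can be sketched on its own. The helper name `most_frequent` is hypothetical. It takes the subtree sums as already computed and uses `collections.Counter` in place of `defaultdict(int)`, which counts the same way here.

```python
from collections import Counter
from typing import Iterable, List


def most_frequent(sums: Iterable[int]) -> List[int]:
    """Return every value whose count equals the maximum count, keeping ties."""
    count = Counter(sums)
    if not count:
        return []  # nothing was recorded, e.g. an empty tree
    max_freq = max(count.values())
    return [s for s, freq in count.items() if freq == max_freq]


# Example 2 subtree sums in postorder (node 2, node -5, root 5).
print(most_frequent([2, -5, 2]))  # [2]
# Example 1 subtree sums: all three tie at frequency 1.
print(most_frequent([2, -3, 4]))  # [2, -3, 4]
```

`Counter.most_common(1)` would not be enough by itself, because it returns a single entry and would lose ties like the one in Example 1. Taking the maximum frequency first and then filtering keeps every tied sum.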

Algorithm

Create a frequency map:

count = defaultdict(int)

Define a DFS function.

For each node:

  1. If the node is None, return 0.
  2. Recursively compute the left subtree sum.
  3. Recursively compute the right subtree sum.
  4. Compute the current subtree sum.
  5. Increment its frequency.
  6. Return the current subtree sum.

After DFS finishes:

  1. Find the maximum frequency.
  2. Return all sums with that frequency.

Correctness

For a null child, the subtree sum is 0, which contributes nothing to its parent. Null children are not recorded in the frequency map, because only real nodes have subtree sums.

For a real node, the algorithm first computes the exact sum of the left subtree and the exact sum of the right subtree by recursion.

It then adds those two sums to the current node value. Therefore, the computed value is exactly the subtree sum rooted at that node.

The algorithm records this sum once for every node. Since every node is visited exactly once, the frequency map stores the exact number of times each subtree sum occurs.

Finally, the algorithm returns all sums whose frequency equals the largest frequency in the map. These are exactly the most frequent subtree sums.
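The counting argument can also be checked mechanically. Exactly one sum is recorded per node, so the frequencies in the map must add up to the number of nodes. The sketch below assumes a minimal `TreeNode` class and uses the hypothetical helper names `subtree_sum_counts` and `count_nodes`.

```python
from collections import defaultdict
from typing import Dict, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumed for this sketch).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sum_counts(root: Optional[TreeNode]) -> Dict[int, int]:
    """Postorder DFS that returns the frequency map of subtree sums."""
    count: Dict[int, int] = defaultdict(int)

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0  # empty subtree: contributes 0 and is not recorded
        total = node.val + dfs(node.left) + dfs(node.right)
        count[total] += 1  # exactly one record per real node
        return total

    dfs(root)
    return dict(count)


def count_nodes(node: Optional[TreeNode]) -> int:
    return 0 if node is None else 1 + count_nodes(node.left) + count_nodes(node.right)


root = TreeNode(5, TreeNode(2), TreeNode(-5))  # Example 2
counts = subtree_sum_counts(root)
print(counts)  # {2: 2, -5: 1}
# One recorded sum per node: the frequencies add up to the node count.
assert sum(counts.values()) == count_nodes(root)
```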

Complexity

| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(n) | The frequency map can store up to n distinct sums |

The recursion stack uses O(h) space, where h is the tree height. In the worst case, h = n.
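In Python, the recursive DFS can hit the interpreter's recursion limit (`sys.getrecursionlimit()`, usually 1000) on a very deep, skewed tree. One workaround is a postorder traversal with an explicit stack. This is offered as an alternative, not the article's method. Each node is pushed twice: once to schedule its children and once to combine their sums. The function name and the `id(node)`-keyed dictionary of finished sums are choices made for this sketch, and the `TreeNode` class is a minimal stand-in.

```python
from collections import defaultdict
from typing import Dict, List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumed for this sketch).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_frequent_tree_sum_iterative(root: Optional[TreeNode]) -> List[int]:
    """Same counting as the recursive DFS, but driven by an explicit stack."""
    if root is None:
        return []

    count: Dict[int, int] = defaultdict(int)
    sums: Dict[int, int] = {}  # id(node) -> finished subtree sum, until the parent uses it
    stack = [(root, False)]    # (node, children_already_scheduled)

    while stack:
        node, expanded = stack.pop()
        if not expanded:
            # First visit: revisit this node after both children finish.
            # Right is pushed before left so the left subtree is processed first.
            stack.append((node, True))
            if node.right:
                stack.append((node.right, False))
            if node.left:
                stack.append((node.left, False))
        else:
            # Second visit: both children are done, so their sums are available.
            left_sum = sums.pop(id(node.left)) if node.left else 0
            right_sum = sums.pop(id(node.right)) if node.right else 0
            total = node.val + left_sum + right_sum
            count[total] += 1
            sums[id(node)] = total

    max_freq = max(count.values())
    return [s for s, freq in count.items() if freq == max_freq]
```

Time and space remain O(n). Because the explicit stack replaces the call stack, a chain of several thousand nodes no longer raises `RecursionError`.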

Implementation

from typing import Optional, List
from collections import defaultdict

# Definition for a binary tree node.
# LeetCode provides this class; it is written out here so the file also runs standalone.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def findFrequentTreeSum(self, root: Optional[TreeNode]) -> List[int]:
        count = defaultdict(int)

        def dfs(node: Optional[TreeNode]) -> int:
            if not node:
                return 0

            left_sum = dfs(node.left)
            right_sum = dfs(node.right)

            subtree_sum = node.val + left_sum + right_sum
            count[subtree_sum] += 1

            return subtree_sum

        dfs(root)

        if not count:
            return []  # empty tree: there are no subtree sums

        max_freq = max(count.values())

        return [
            subtree_sum
            for subtree_sum, freq in count.items()
            if freq == max_freq
        ]

Code Explanation

The frequency map stores how often each subtree sum appears:

count = defaultdict(int)

The DFS returns an integer:

def dfs(node: Optional[TreeNode]) -> int:

For an empty child, the sum is 0:

if not node:
    return 0

Compute child sums first:

left_sum = dfs(node.left)
right_sum = dfs(node.right)

Then compute the current subtree sum:

subtree_sum = node.val + left_sum + right_sum

Record it:

count[subtree_sum] += 1

Return it to the parent:

return subtree_sum

After traversal, find the highest frequency:

max_freq = max(count.values())

Then return every subtree sum with that frequency.

Testing

def sorted_result(root):
    return sorted(Solution().findFrequentTreeSum(root))

def test_find_frequent_tree_sum():
    # Tree:
    #     5
    #    / \
    #   2  -3
    root = TreeNode(5)
    root.left = TreeNode(2)
    root.right = TreeNode(-3)
    assert sorted_result(root) == [-3, 2, 4]

    # Tree:
    #     5
    #    / \
    #   2  -5
    root = TreeNode(5)
    root.left = TreeNode(2)
    root.right = TreeNode(-5)
    assert sorted_result(root) == [2]

    # Single node
    root = TreeNode(7)
    assert sorted_result(root) == [7]

    # Tree:
    #      0
    #     / \
    #    0   0
    root = TreeNode(0)
    root.left = TreeNode(0)
    root.right = TreeNode(0)
    assert sorted_result(root) == [0]

    print("all tests passed")


test_find_frequent_tree_sum()

Test meaning:

| Test | Why |
|---|---|
| [5, 2, -3] tree | All subtree sums tie |
| [5, 2, -5] tree | One sum appears twice |
| Single node | Minimum tree case |
| All zero values | Repeated identical subtree sums |
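The hand-written cases above are small. A randomized cross-check against a brute-force reference can catch mistakes they miss, such as dropping a tie. In this sketch, `brute_force` re-adds every subtree from scratch in O(n^2) time. `random_tree` builds random shapes with values in [-3, 3], so repeated sums are common. `dfs_answer` repeats the article's postorder DFS inline so the block runs standalone. All three names are hypothetical. In a real test file you would call `Solution().findFrequentTreeSum` instead of `dfs_answer`.

```python
import random
from collections import Counter
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's TreeNode (assumed for this sketch).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def all_nodes(node: Optional[TreeNode]) -> List[TreeNode]:
    if node is None:
        return []
    return [node] + all_nodes(node.left) + all_nodes(node.right)


def brute_force(root: Optional[TreeNode]) -> List[int]:
    """O(n^2) reference: recompute every subtree sum independently."""
    count = Counter(sum(n.val for n in all_nodes(node)) for node in all_nodes(root))
    if not count:
        return []
    max_freq = max(count.values())
    return sorted(s for s, freq in count.items() if freq == max_freq)


def dfs_answer(root: Optional[TreeNode]) -> List[int]:
    """Postorder DFS with a frequency map, as in the article's solution."""
    count: Counter = Counter()

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        total = node.val + dfs(node.left) + dfs(node.right)
        count[total] += 1
        return total

    dfs(root)
    if not count:
        return []
    max_freq = max(count.values())
    return sorted(s for s, freq in count.items() if freq == max_freq)


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    """Random-shaped tree with `size` nodes and small values, so ties are frequent."""
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(
        rng.randint(-3, 3),
        random_tree(left_size, rng),
        random_tree(size - 1 - left_size, rng),
    )


rng = random.Random(508)  # fixed seed so any failure is reproducible
for _ in range(500):
    root = random_tree(rng.randint(1, 30), rng)
    assert dfs_answer(root) == brute_force(root)
print("random cross-check passed")
```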