A clear explanation of counting uni-value subtrees using post-order DFS.
Problem Restatement
We are given the root of a binary tree.
Return the number of uni-value subtrees.
A uni-value subtree means every node in that subtree has the same value. An empty tree has 0 uni-value subtrees.
The constraints say the tree has between 0 and 1000 nodes, and each node value is between -1000 and 1000.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | Number of uni-value subtrees |
| Empty tree | Returns 0 |
| Node value range | -1000 <= Node.val <= 1000 |
Example function shape:

```python
def countUnivalSubtrees(root: Optional[TreeNode]) -> int:
    ...
```

Examples
Example 1:

```python
root = [5, 1, 5, 5, 5, None, 5]
```

Tree:

```text
    5
   / \
  1   5
 / \   \
5   5   5
```

The uni-value subtrees are:

- leaf 5
- leaf 5
- right child subtree 5 -> 5
- leaf 5 under the right child

So the answer is:

```text
4
```

Example 2:

```python
root = []
```

There are no nodes, so there are no subtrees to count.

```text
0
```

Example 3:

```python
root = [5, 5, 5, 5, 5, None, 5]
```

Every subtree is uni-value because every node has value 5.
There are 6 nodes, so the answer is:

```text
6
```

First Thought
A direct solution is to check every subtree independently.
For each node, we can traverse the whole subtree rooted at that node and verify whether all values equal the root value.
This is correct, but it repeats work.
For example, when checking the root, we scan many descendants. Then when checking a child, we scan some of the same descendants again.
In the worst case, this can become O(n^2).
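To make the brute-force idea concrete, here is a minimal sketch, assuming the usual LeetCode-style `TreeNode` class. The names `all_equal` and `count_unival_brute_force` are hypothetical helpers introduced only for this illustration. Every node triggers a full scan of its own subtree, which is exactly where the repeated work comes from.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def all_equal(node: Optional[TreeNode], target: int) -> bool:
    """Hypothetical helper: True if every node under `node` has value `target`."""
    if node is None:
        return True
    return (node.val == target
            and all_equal(node.left, target)
            and all_equal(node.right, target))


def count_unival_brute_force(root: Optional[TreeNode]) -> int:
    """Hypothetical brute force: rescan every subtree, O(n^2) in the worst case."""
    if root is None:
        return 0
    here = 1 if all_equal(root, root.val) else 0
    return (here
            + count_unival_brute_force(root.left)
            + count_unival_brute_force(root.right))


# Example 1 from above: [5, 1, 5, 5, 5, None, 5]
example = TreeNode(5,
                   TreeNode(1, TreeNode(5), TreeNode(5)),
                   TreeNode(5, None, TreeNode(5)))
print(count_unival_brute_force(example))  # 4
```

On a skewed chain, `all_equal` at the top walks n nodes, the next call walks n - 1, and so on, which adds up to the quadratic bound.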
Key Insight
A subtree rooted at `node` is uni-value if all of these are true:
- The left subtree is uni-value.
- The right subtree is uni-value.
- If the left child exists, `left.val == node.val`.
- If the right child exists, `right.val == node.val`.
This means the parent needs information from its children first.
So we should use post-order DFS:

```text
left subtree -> right subtree -> current node
```

The DFS returns a boolean:

```text
True means this subtree is uni-value
False means this subtree is not uni-value
```

Whenever DFS finds a uni-value subtree, it increments the answer.
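As a small illustration of post-order, the hypothetical generator `post_order` below yields the left subtree, then the right subtree, then the node itself. It uses a tree with distinct values (not one of the examples above) so the order is easy to read: every node appears only after both of its children, which is exactly what the uni-value check needs.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def post_order(node: Optional[TreeNode]) -> Iterator[TreeNode]:
    """Hypothetical helper: yield nodes as left subtree, right subtree, node."""
    if node is None:
        return
    yield from post_order(node.left)
    yield from post_order(node.right)
    yield node


#       1
#      / \
#     2   3
#    / \
#   4   5
tree = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print([n.val for n in post_order(tree)])  # [4, 5, 2, 3, 1]
```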
Algorithm
Initialize:

```python
answer = 0
```

Define a DFS function:

```python
dfs(node) -> bool
```

For each node:
- If `node` is `None`, return `True`.
- Recursively check the left subtree.
- Recursively check the right subtree.
- If either child subtree is not uni-value, return `False`.
- If the left child exists and has a different value, return `False`.
- If the right child exists and has a different value, return `False`.
- Otherwise, the current subtree is uni-value.
- Increment `answer`.
- Return `True`.

After DFS finishes, return `answer`.
Correctness
The DFS processes children before their parent.
For a leaf node, both children are empty. Empty children do not violate the uni-value condition, so every leaf is counted as a uni-value subtree.
For an internal node, the subtree rooted at that node can be uni-value only if both child subtrees are uni-value and every existing child has the same value as the current node.
The algorithm checks exactly these conditions.
If any condition fails, the subtree contains at least two different values, so returning False is correct.
If all conditions pass, both children are either empty or roots of uni-value subtrees with the same value as the current node. Therefore, every node in the current subtree has the current node’s value, so the subtree is uni-value and should be counted.
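Both recursive calls run before any early return, so a mixed subtree on one side never stops uni-value subtrees on the other side from being counted. Since DFS visits every node once and applies this logic to every subtree root, the final count is exactly the number of uni-value subtrees.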
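One way to gain extra confidence in this argument is to compare the DFS against the brute-force check on many small random trees. The sketch below is a hypothetical test harness, not part of the original solution: `count_dfs` restates the DFS as a plain function, `count_brute` rescans every subtree, and `random_tree` builds trees whose values are drawn from {0, 1} so that uni-value subtrees show up often.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_dfs(root: Optional[TreeNode]) -> int:
    """The post-order DFS described above, as a standalone function."""
    answer = 0

    def dfs(node: Optional[TreeNode]) -> bool:
        nonlocal answer
        if node is None:
            return True
        left_ok = dfs(node.left)    # both children are visited first
        right_ok = dfs(node.right)
        if not left_ok or not right_ok:
            return False
        if node.left is not None and node.left.val != node.val:
            return False
        if node.right is not None and node.right.val != node.val:
            return False
        answer += 1
        return True

    dfs(root)
    return answer


def all_equal(node: Optional[TreeNode], target: int) -> bool:
    if node is None:
        return True
    return (node.val == target
            and all_equal(node.left, target)
            and all_equal(node.right, target))


def count_brute(root: Optional[TreeNode]) -> int:
    """Reference answer: check every subtree independently."""
    if root is None:
        return 0
    here = 1 if all_equal(root, root.val) else 0
    return here + count_brute(root.left) + count_brute(root.right)


def random_tree(rng: random.Random, size: int) -> Optional[TreeNode]:
    """Hypothetical generator: a random tree with exactly `size` nodes."""
    if size == 0:
        return None
    left_size = rng.randrange(size)  # 0 .. size - 1 nodes go to the left
    node = TreeNode(rng.randint(0, 1))
    node.left = random_tree(rng, left_size)
    node.right = random_tree(rng, size - 1 - left_size)
    return node


def cross_check(trials: int = 500, seed: int = 0) -> int:
    """Assert that both counts agree on `trials` random trees."""
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng, rng.randint(0, 12))
        assert count_dfs(root) == count_brute(root)
    return trials


print(cross_check(), "random trees agree")
```

A fixed seed keeps any failure reproducible; a disagreement would point at a concrete small tree to debug.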
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack depth is the tree height |
Here, n is the number of nodes and h is the height of the tree.
In the worst case, h = n for a skewed tree.
In a balanced tree, h = O(log n).
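Because the recursion depth equals the tree height, a fully skewed tree with 1000 nodes can push the recursive version up to or past CPython's default recursion limit (typically 1000, see `sys.getrecursionlimit()`). If that is a concern, the same post-order logic can be driven by an explicit stack. The sketch below is one assumed way to do it, not the article's solution; `count_unival_iterative` is a hypothetical name. Each node is pushed twice: once to schedule its children and once to combine their results.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_unival_iterative(root: Optional[TreeNode]) -> int:
    """Hypothetical iterative variant: post-order with an explicit stack."""
    if root is None:
        return 0
    answer = 0
    is_unival: Dict[TreeNode, bool] = {}  # results waiting for their parent
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Come back to this node after both children are finished.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        ok = True
        for child in (node.left, node.right):
            if child is None:
                continue
            child_ok = is_unival.pop(child)  # each result is read exactly once
            if not child_ok or child.val != node.val:
                ok = False
        is_unival[node] = ok
        if ok:
            answer += 1
    return answer


# Example 1 from above, plus a left-skewed chain of 1000 equal values.
example = TreeNode(5,
                   TreeNode(1, TreeNode(5), TreeNode(5)),
                   TreeNode(5, None, TreeNode(5)))
chain = None
for _ in range(1000):
    chain = TreeNode(7, chain)
print(count_unival_iterative(example))  # 4
print(count_unival_iterative(chain))    # 1000
```

The explicit stack and the `is_unival` dictionary take the place of the call stack. Popping each child's result as soon as its parent reads it keeps the dictionary from holding results nobody needs any more.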
Implementation
```python
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

# Defers evaluation of the annotations, so the class can be defined
# even before TreeNode exists (e.g. when pasted above a test harness).
from __future__ import annotations

from typing import Optional


class Solution:
    def countUnivalSubtrees(self, root: Optional[TreeNode]) -> int:
        answer = 0

        def dfs(node: Optional[TreeNode]) -> bool:
            nonlocal answer

            if node is None:
                return True

            left_is_unival = dfs(node.left)
            right_is_unival = dfs(node.right)

            if not left_is_unival or not right_is_unival:
                return False

            if node.left is not None and node.left.val != node.val:
                return False

            if node.right is not None and node.right.val != node.val:
                return False

            answer += 1
            return True

        dfs(root)
        return answer
```

Code Explanation
The variable answer stores the number of uni-value subtrees found so far.
```python
answer = 0
```

The helper returns whether the subtree rooted at `node` is uni-value.

```python
def dfs(node: Optional[TreeNode]) -> bool:
```

An empty child is treated as valid because it does not introduce a different value.

```python
if node is None:
    return True
```

We recurse before checking the current node.

```python
left_is_unival = dfs(node.left)
right_is_unival = dfs(node.right)
```

If either child subtree already contains mixed values, the current subtree cannot be uni-value.

```python
if not left_is_unival or not right_is_unival:
    return False
```

Then we check whether existing child roots match the current node value.

```python
if node.left is not None and node.left.val != node.val:
    return False

if node.right is not None and node.right.val != node.val:
    return False
```

If all checks pass, the current subtree is uni-value.

```python
answer += 1
return True
```

Finally, after visiting the whole tree, return the count.

```python
dfs(root)
return answer
```

Testing
```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values):
    if not values:
        return None

    root = TreeNode(values[0])
    queue = deque([root])
    i = 1

    while queue and i < len(values):
        node = queue.popleft()

        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1

        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1

    return root


def run_tests():
    s = Solution()

    assert s.countUnivalSubtrees(
        build_tree([5, 1, 5, 5, 5, None, 5])
    ) == 4

    assert s.countUnivalSubtrees(
        build_tree([])
    ) == 0

    assert s.countUnivalSubtrees(
        build_tree([5, 5, 5, 5, 5, None, 5])
    ) == 6

    assert s.countUnivalSubtrees(
        build_tree([1])
    ) == 1

    assert s.countUnivalSubtrees(
        build_tree([1, 1, 1, 1, 2])
    ) == 3

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| Mixed tree | Checks standard example |
| Empty tree | Confirms 0 for no nodes |
| All same values | Every subtree should count |
| Single node | A leaf is always uni-value |
| One mismatching child | Confirms bad child value prevents parent counting |