A clear explanation of checking whether every node in a binary tree has the same value.
Problem Restatement
We are given the root of a binary tree.
A binary tree is univalued if every node in the tree has the same value.
Return true if every node has the same value. Otherwise, return false.
The official constraints say the tree has between 1 and 100 nodes, and each node value is in the range [0, 99].
Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root of a binary tree |
| Output | true if all nodes have the same value, otherwise false |
| Condition | Every node value must equal the root value |
Example function shape:
```python
def isUnivalTree(root: Optional[TreeNode]) -> bool:
    ...
```

Examples
Example 1:
```python
root = [1, 1, 1, 1, 1, None, 1]
```

Output:

```python
True
```

Every node has value 1.
Example 2:
```python
root = [2, 2, 2, 5, 2]
```

Output:

```python
False
```

Most nodes have value 2, but one node has value 5.
So the tree is not univalued.
First Thought: Traverse Every Node
To know whether all nodes have the same value, we need to inspect the tree.
The natural reference value is:
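```python
root.val
```

Then every other node must match this value.
We can use depth-first search (DFS), because a binary tree is defined recursively: each child is the root of a smaller binary tree.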
- Check the current node.
- Check the left subtree.
- Check the right subtree.
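Taken literally, "traverse every node" could mean collecting every value first and then checking that only one distinct value appears. The sketch below does exactly that; `TreeNode` mirrors the definition used later, while `all_values` and `is_unival_naive` are hypothetical names chosen for this illustration.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def all_values(node: Optional[TreeNode]) -> List[int]:
    # Preorder DFS: current node, then left subtree, then right subtree.
    if node is None:
        return []
    return [node.val] + all_values(node.left) + all_values(node.right)


def is_unival_naive(root: Optional[TreeNode]) -> bool:
    # Univalued means exactly one distinct value appears in the tree.
    return len(set(all_values(root))) == 1


# Example 2: [2, 2, 2, 5, 2]
root = TreeNode(2, TreeNode(2, TreeNode(5), TreeNode(2)), TreeNode(2))
print(is_unival_naive(root))  # False
```

This version always visits all n nodes and stores all n values, even after a mismatch has appeared. Comparing each node against root.val during the traversal avoids both costs.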
Key Insight
A tree is univalued if:
- The current node has the target value.
- The left subtree is univalued with the same target value.
- The right subtree is univalued with the same target value.
A missing child does not violate the condition, so None should return True.
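The three conditions translate almost line for line into a recursive function that carries the target value as a parameter. This is a sketch: `is_unival` is a hypothetical helper name, and passing target explicitly is one design choice (the implementation later captures it in a closure instead).

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_unival(node: Optional[TreeNode], target: int) -> bool:
    # A missing child contains no node that could break the condition.
    if node is None:
        return True
    # Condition 1: the current node has the target value.
    # Conditions 2 and 3: both subtrees are univalued with the same target.
    return (
        node.val == target
        and is_unival(node.left, target)
        and is_unival(node.right, target)
    )


root = TreeNode(1, TreeNode(1), TreeNode(1, None, TreeNode(1)))
print(is_unival(root, root.val))  # True
```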
Algorithm
- Store the root value as target.
- Run DFS from the root.
- For each node:
  - If the node is None, return True.
  - If node.val != target, return False.
  - Recursively check the left and right children.
- Return the result of DFS.
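The same steps can also run without recursion by keeping pending nodes on an explicit stack, which avoids Python's recursion limit on very deep trees. This is an alternative sketch, not the recursive implementation presented in the next section; `is_unival_iterative` is a hypothetical name, and returning True for an empty tree is an assumption (the constraints rule that case out).

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_unival_iterative(root: Optional[TreeNode]) -> bool:
    if root is None:
        # Assumption: treat an empty tree as univalued.
        return True
    target = root.val
    stack: List[Optional[TreeNode]] = [root]
    while stack:
        node = stack.pop()
        # A missing child cannot violate the condition, so skip it.
        if node is None:
            continue
        # Any mismatch proves the tree is not univalued.
        if node.val != target:
            return False
        # Push right first so the left child is popped (checked) first.
        stack.append(node.right)
        stack.append(node.left)
    return True


# Example 2: [2, 2, 2, 5, 2]
root = TreeNode(2, TreeNode(2, TreeNode(5), TreeNode(2)), TreeNode(2))
print(is_unival_iterative(root))  # False
```

The order in which children are pushed only changes which mismatch is found first, not the final answer.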
Correctness
The algorithm compares each real node it reaches with the root value, and it reaches every node unless it has already found a mismatch.
If any node has a different value, the tree cannot be univalued, so returning False is correct.
If a subtree is empty, it contains no node that can violate the condition, so returning True for None is correct.
For a non-empty subtree, the algorithm returns True only when the current node matches the target and both child subtrees also satisfy the same condition.
Therefore, the algorithm returns True exactly when every node in the tree has the same value as the root. That is exactly the definition of a univalued binary tree.
Complexity
Let n be the number of nodes.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited at most once |
| Space | O(h) | The recursion stack depends on tree height |
In the worst case, h = n for a skewed tree.
For a balanced tree, h = O(log n).
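To see how h depends on the tree's shape, the sketch below measures height (the number of nodes on the longest root-to-leaf path) for a skewed chain and for a complete tree. The recursion depth of the DFS is at most h + 1, counting one extra frame for a None child. The helpers `height`, `chain`, and `complete` are hypothetical, written only for this illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    # Number of nodes on the longest root-to-leaf path; 0 for an empty tree.
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def chain(n: int) -> Optional[TreeNode]:
    # Skewed tree: every node has only a left child.
    root = None
    for _ in range(n):
        root = TreeNode(0, root)
    return root


def complete(n: int, i: int = 0) -> Optional[TreeNode]:
    # Complete tree of n nodes using heap-style indices (children 2i+1, 2i+2).
    if i >= n:
        return None
    return TreeNode(0, complete(n, 2 * i + 1), complete(n, 2 * i + 2))


print(height(chain(100)))     # 100
print(height(complete(100)))  # 7
```

With at most 100 nodes, even the fully skewed case stays far below CPython's default recursion limit of 1000.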
Implementation
```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def isUnivalTree(self, root: Optional[TreeNode]) -> bool:
        # Every node must match the root's value.
        target = root.val

        def dfs(node: Optional[TreeNode]) -> bool:
            # An empty subtree cannot violate the condition.
            if node is None:
                return True
            # A single mismatch means the tree is not univalued.
            if node.val != target:
                return False
            # Both subtrees must also be univalued.
            return dfs(node.left) and dfs(node.right)

        return dfs(root)
```

Code Explanation
We store the value every node must match:
```python
target = root.val
```

The constraints guarantee at least one node, so root is not None.

The DFS base case handles missing children:

```python
if node is None:
    return True
```

Then we reject mismatched values:

```python
if node.val != target:
    return False
```

If the current node is valid, both subtrees must also be valid:

```python
return dfs(node.left) and dfs(node.right)
```

Finally,

```python
return dfs(root)
```

checks the whole tree.
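Because Python's `and` short-circuits, dfs(node.right) is never called once dfs(node.left) returns False. The sketch below counts calls to make this visible; the `count_calls` wrapper and its `calls` counter are hypothetical additions for illustration, not part of the solution.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_calls(root: TreeNode) -> Tuple[bool, int]:
    target = root.val
    calls = 0

    def dfs(node: Optional[TreeNode]) -> bool:
        nonlocal calls
        calls += 1  # count every call, including calls on None children
        if node is None:
            return True
        if node.val != target:
            return False
        return dfs(node.left) and dfs(node.right)

    return dfs(root), calls


# Mismatch in the left child: the right subtree is never explored.
root = TreeNode(1, TreeNode(2), TreeNode(1, TreeNode(1), TreeNode(1)))
print(count_calls(root))  # (False, 2)
```

When every node matches, no call is skipped: a tree with n nodes produces 2n + 1 calls, one per node and one per None child.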
Testing
```python
from collections import deque

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values):
    # Build a tree from a LeetCode-style level-order list (None = missing node).
    if not values:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root

def run_tests():
    s = Solution()

    root = build_tree([1, 1, 1, 1, 1, None, 1])
    assert s.isUnivalTree(root) is True

    root = build_tree([2, 2, 2, 5, 2])
    assert s.isUnivalTree(root) is False

    root = build_tree([7])
    assert s.isUnivalTree(root) is True

    root = build_tree([0, 0, 0, 0])
    assert s.isUnivalTree(root) is True

    root = build_tree([1, 1, 1, None, None, 1, 2])
    assert s.isUnivalTree(root) is False

    print("all tests passed")

run_tests()
```

Test meaning:
| Test | Why |
|---|---|
| [1,1,1,1,1,None,1] | All nodes match |
| [2,2,2,5,2] | One node differs |
| [7] | Single-node tree |
| [0,0,0,0] | Checks value 0 |
| [1,1,1,None,None,1,2] | Mismatch deep in the tree |