A clear explanation of counting nodes in a complete binary tree faster than visiting every node.
Problem Restatement
We are given the root of a complete binary tree.
We need to return the number of nodes in the tree.
A complete binary tree has this shape:
| Rule | Meaning |
|---|---|
| Every level except possibly the last | Completely filled |
| Last level | Filled from left to right |
| Missing nodes | Can only appear at the far right of the last level |
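To make these rules concrete, here is a minimal Python sketch of a completeness check. It is not part of the counting solution. `is_complete` is a hypothetical helper introduced only for illustration, and it assumes the same `TreeNode` shape used in this article. It scans the tree level by level and rejects any real node that appears after the first missing slot.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_complete(root: Optional[TreeNode]) -> bool:
    """Hypothetical helper: True if the tree satisfies the completeness rules."""
    queue = deque([root])
    seen_gap = False  # becomes True at the first missing slot in level order
    while queue:
        node = queue.popleft()
        if node is None:
            seen_gap = True
            continue
        if seen_gap:
            # A real node after a hole: the missing slot is not at the far right
            # of the last level, so the tree is not complete.
            return False
        queue.append(node.left)
        queue.append(node.right)
    return True


# Root 1 with children 2 and 3, plus node 4 as the left child of 2.
print(is_complete(TreeNode(1, TreeNode(2, TreeNode(4)), TreeNode(3))))  # True
# Root 1 with no left child but a right child 3: a hole before node 3.
print(is_complete(TreeNode(1, None, TreeNode(3))))  # False
```

An empty tree counts as complete here, since the scan never finds a node after a gap.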
The official problem asks for an algorithm that runs in less than O(n) time, where n is the number of nodes.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a complete binary tree |
| Output | Number of nodes |
| Tree property | Complete binary tree |
| Required goal | Faster than visiting every node |
Typical node definition:

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
```

Example function shape:

```python
def countNodes(root: TreeNode) -> int:
    ...
```

Examples
Example 1:

```
root = [1,2,3,4,5,6]
```

Tree:

```
      1
     / \
    2   3
   / \  /
  4   5 6
```

There are 6 nodes.

Answer:

```
6
```

Example 2:

```
root = []
```

The tree is empty.

Answer:

```
0
```

Example 3:

```
root = [1]
```

There is one node.

Answer:

```
1
```

First Thought: Count Every Node
For a general binary tree, the simple solution is a recursive DFS:

```python
class Solution:
    def countNodes(self, root: TreeNode) -> int:
        if not root:
            return 0
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)
```

This is correct for any binary tree.

But it visits every node, so it takes:

```
O(n)
```

The problem asks for better than O(n), so we need to use the complete tree property.
Key Insight
A complete binary tree often contains perfect subtrees.
A perfect binary tree has every level fully filled.
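For a perfect tree with height h (counting levels, so a single node has height 1), the number of nodes is:

```
2^h - 1
```

because level k holds 2^(k - 1) nodes, and 1 + 2 + ... + 2^(h - 1) = 2^h - 1.

So if we can detect that a subtree is perfect, we can count all of its nodes immediately without visiting them one by one.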
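As a quick sanity check of the formula, the sketch below builds perfect trees of several heights and counts them node by node. `build_perfect` and `count_all` are illustrative names, not part of the problem's API. Height here means the number of levels, so a single node has height 1 and the empty tree has height 0.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_perfect(height: int) -> Optional[TreeNode]:
    """Build a perfect tree with `height` levels (height 0 is the empty tree)."""
    if height == 0:
        return None
    return TreeNode(0, build_perfect(height - 1), build_perfect(height - 1))


def count_all(node: Optional[TreeNode]) -> int:
    """Plain O(n) count that visits every node."""
    if node is None:
        return 0
    return 1 + count_all(node.left) + count_all(node.right)


for h in range(6):
    counted = count_all(build_perfect(h))
    assert counted == 2 ** h - 1
    print(h, counted)  # prints 0 0, 1 1, 2 3, 3 7, 4 15, 5 31 on successive lines
```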
Detecting a Perfect Subtree
For a complete tree:
- follow the left pointers to get the left height
- follow the right pointers to get the right height
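If these heights are equal, the subtree is perfect. The last level fills from left to right, so if the rightmost path reaches as deep as the leftmost path, the rightmost slot of the last level is occupied, and so is every slot before it.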
Example:

```
      1
     / \
    2   3
   / \ / \
  4  5 6  7
```

Left height:

```
1 -> 2 -> 4
```

Right height:

```
1 -> 3 -> 7
```

Both heights are 3.

So the tree is perfect and has:

```
2^3 - 1 = 7
```

nodes.
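The two walks are simple loops. This sketch checks the 7-node example above and a 6-node complete tree. The names `left_height`, `right_height`, and `looks_perfect` are illustrative. The equal-heights test is only trustworthy because the input is complete. The last case shows that on an arbitrary binary tree, equal edge heights do not imply a perfect tree.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def left_height(node):
    """Number of nodes on the path that always goes left."""
    height = 0
    while node:
        height += 1
        node = node.left
    return height


def right_height(node):
    """Number of nodes on the path that always goes right."""
    height = 0
    while node:
        height += 1
        node = node.right
    return height


def looks_perfect(node):
    # Valid as a perfect-tree test only when `node` roots a complete tree.
    return left_height(node) == right_height(node)


perfect = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)),
                   TreeNode(3, TreeNode(6), TreeNode(7)))
partial = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))

print(left_height(perfect), right_height(perfect), looks_perfect(perfect))  # 3 3 True
print(left_height(partial), right_height(partial), looks_perfect(partial))  # 3 2 False

# Not complete (node 2 has no children but node 3 does): both edge heights
# are 2, yet the tree has 4 nodes, not 2^2 - 1 = 3.
odd = TreeNode(1, TreeNode(2), TreeNode(3, TreeNode(6)))
print(looks_perfect(odd))  # True, which is misleading for a non-complete tree
```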
Algorithm
- If `root` is `None`, return `0`.
- Compute the left height by walking down `.left`.
- Compute the right height by walking down `.right`.
- If both heights are equal:
  - return `2^height - 1`
- Otherwise:
  - recursively count the left subtree
  - recursively count the right subtree
  - add `1` for the root
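A minimal sketch of these steps, with one addition for illustration: each call appends a line to a `trace` list so the order of decisions is visible. The names `count_traced` and `edge_height` are hypothetical, and `getattr(node, side)` is used only so one helper can walk either edge. The trace it produces for the 6-node example follows the same path as the walkthrough below.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def edge_height(node: Optional[TreeNode], side: str) -> int:
    """Count nodes on the path that always follows `side` ("left" or "right")."""
    height = 0
    while node is not None:
        height += 1
        node = getattr(node, side)
    return height


def count_traced(node: Optional[TreeNode], trace: List[str]) -> int:
    if node is None:
        trace.append("empty -> 0")
        return 0
    lh, rh = edge_height(node, "left"), edge_height(node, "right")
    if lh == rh:
        total = (1 << lh) - 1  # perfect subtree: 2^h - 1 nodes
        trace.append(f"node {node.val}: heights {lh}/{rh}, perfect -> {total}")
        return total
    trace.append(f"node {node.val}: heights {lh}/{rh}, recurse")
    # Root plus both subtrees; the left call runs (and traces) first.
    return 1 + count_traced(node.left, trace) + count_traced(node.right, trace)


# The tree [1,2,3,4,5,6]
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
trace: List[str] = []
print(count_traced(root, trace))  # 6
for line in trace:
    print(line)
```

Running it prints 6 followed by five trace lines: node 1 recurses, node 2 is perfect (3 nodes), node 3 recurses, node 6 is perfect (1 node), and the empty right child of node 3 contributes 0.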
Walkthrough
Use:

```
root = [1,2,3,4,5,6]
```

Tree:

```
      1
     / \
    2   3
   / \  /
  4   5 6
```

At root 1:

Left height:

```
1 -> 2 -> 4
height = 3
```

Right height:

```
1 -> 3
height = 2
```

They differ, so the tree is not perfect.

Count the left subtree rooted at 2:

```
    2
   / \
  4   5
```

Left height and right height are both 2, so this subtree is perfect.

Count:

```
2^2 - 1 = 3
```

Count the right subtree rooted at 3:

```
    3
   /
  6
```

Left height is 2 and right height is 1, so they differ.

Count the left child 6. It is a single node, which is a perfect tree of height 1:

```
2^1 - 1 = 1
```

The right child is empty:

```
0
```

So the subtree rooted at 3 has:

```
1 + 1 + 0 = 2
```

Total:

```
1 + 3 + 2 = 6
```

Correctness
If the root is empty, the tree has zero nodes, so returning 0 is correct.
For a non-empty complete subtree, the algorithm computes the height of its leftmost path and rightmost path.
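If the heights are equal, then the subtree is perfect: completeness guarantees that every level above the last is full, and because the last level fills from left to right, a node at the end of the rightmost path means every slot on that level is filled. A perfect binary tree of height h contains exactly 2^h - 1 nodes, so returning that value is correct.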
If the heights differ, the subtree is not perfect. The total number of nodes is the root plus the nodes in the left subtree plus the nodes in the right subtree. The algorithm recursively computes those two counts.
Each recursive call receives the left or right subtree of a complete tree. Such a subtree is itself complete, so the same perfect-subtree test applies to it. Each call also works on a strictly smaller subtree, so the recursion terminates.
Therefore the algorithm counts every imperfect part recursively and counts every perfect part by formula, giving the exact total number of nodes.
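The argument can also be checked empirically. This sketch builds the complete tree with n nodes for every n from 0 to 299, using level-order indexing where the children of position i sit at 2i + 1 and 2i + 2, and compares the perfect-subtree count with a plain DFS count. The helper names `build_complete`, `edge_height`, `count_fast`, and `count_naive` are illustrative. A finite check like this supports the argument above but does not replace it.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n: int) -> Optional[TreeNode]:
    """Build the complete tree with values 1..n in level order."""
    if n == 0:
        return None
    nodes = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def edge_height(node: Optional[TreeNode], side: str) -> int:
    """Number of nodes on the path that always follows `side`."""
    height = 0
    while node is not None:
        height += 1
        node = getattr(node, side)
    return height


def count_fast(node: Optional[TreeNode]) -> int:
    """Perfect-subtree recursion."""
    if node is None:
        return 0
    lh, rh = edge_height(node, "left"), edge_height(node, "right")
    if lh == rh:
        return (1 << lh) - 1
    return 1 + count_fast(node.left) + count_fast(node.right)


def count_naive(node: Optional[TreeNode]) -> int:
    """Reference O(n) DFS count."""
    if node is None:
        return 0
    return 1 + count_naive(node.left) + count_naive(node.right)


for n in range(300):
    root = build_complete(n)
    assert count_fast(root) == count_naive(root) == n, n
print("perfect-subtree count matches DFS for n = 0..299")
```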
Complexity
| Metric | Value | Why |
|---|---|---|
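| Time | O(log^2 n) | Each call spends O(log n) walking its two edges. At least one child subtree of a complete tree is perfect and returns right after its height check, so the recursion continues down only one path: O(log n) levels with at most two calls each |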
| Space | O(log n) | Recursion depth is the tree height |
This satisfies the requirement to run in less than O(n) time.
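To see the sublinear behavior directly, the sketch below counts pointer hops made by the height walks (one hop per loop iteration) alongside the node count. The names `count_and_steps` and `build_complete` are hypothetical. As n grows tenfold per row, the printed hop counts should stay far below n.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_complete(n: int) -> Optional[TreeNode]:
    """Build the complete tree with n nodes in level order (heap indexing)."""
    if n == 0:
        return None
    nodes = [TreeNode(i + 1) for i in range(n)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def count_and_steps(node: Optional[TreeNode]) -> Tuple[int, int]:
    """Return (node_count, hops), where hops counts iterations of the height walks."""
    if node is None:
        return 0, 0
    lh, cur = 0, node
    while cur is not None:  # leftmost walk
        lh += 1
        cur = cur.left
    rh, cur = 0, node
    while cur is not None:  # rightmost walk
        rh += 1
        cur = cur.right
    hops = lh + rh
    if lh == rh:
        return (1 << lh) - 1, hops
    left_count, left_hops = count_and_steps(node.left)
    right_count, right_hops = count_and_steps(node.right)
    return 1 + left_count + right_count, hops + left_hops + right_hops


for n in (10, 100, 1_000, 10_000, 100_000):
    count, hops = count_and_steps(build_complete(n))
    print(f"n={n:>7}  counted={count:>7}  hops={hops}")
```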
Implementation
```python
class Solution:
    def countNodes(self, root: TreeNode) -> int:
        if not root:
            return 0

        left_height = self.get_left_height(root)
        right_height = self.get_right_height(root)

        if left_height == right_height:
            # Perfect subtree: 2^height - 1 nodes.
            return (1 << left_height) - 1

        return 1 + self.countNodes(root.left) + self.countNodes(root.right)

    def get_left_height(self, node: TreeNode) -> int:
        # Number of nodes on the path that always goes left.
        height = 0
        while node:
            height += 1
            node = node.left
        return height

    def get_right_height(self, node: TreeNode) -> int:
        # Number of nodes on the path that always goes right.
        height = 0
        while node:
            height += 1
            node = node.right
        return height
```

Code Explanation
Handle the empty tree:

```python
if not root:
    return 0
```

Compute the leftmost height:

```python
left_height = self.get_left_height(root)
```

Compute the rightmost height:

```python
right_height = self.get_right_height(root)
```

If they match, this subtree is perfect:

```python
if left_height == right_height:
```

Count it directly:

```python
return (1 << left_height) - 1
```

`1 << left_height` means `2^left_height`.

Otherwise, count recursively:

```python
return 1 + self.countNodes(root.left) + self.countNodes(root.right)
```

The helper functions walk one side of the tree:

```python
while node:
    height += 1
    node = node.left
```

and:

```python
while node:
    height += 1
    node = node.right
```

Alternative: Binary Search on the Last Level
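A complete tree with height h has all levels above the last one completely filled.

So the number of nodes above the last level is:

```
2^(h - 1) - 1
```

The last level has between 1 and 2^(h - 1) nodes, packed to the left. Whether the node at a given position exists can be checked with one root-to-leaf walk of O(log n) steps, so we can binary search for the number of nodes on the last level using O(log n) such checks.

This also runs in:

```
O(log^2 n)
```

The perfect-subtree recursion is usually easier to write and explain.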
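A sketch of this alternative, assuming the same `TreeNode` shape. Here `depth` is h - 1, the number of edges from the root to the last level, so the full levels hold 2^depth - 1 = 2^(h - 1) - 1 nodes and the last level has slots 0 through 2^depth - 1. `node_exists` halves the slot range at each level to decide whether to go left or right. Both function names are illustrative.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def node_exists(root: TreeNode, index: int, depth: int) -> bool:
    """Is slot `index` (0-based, left to right) on the last level occupied?

    `depth` is the number of edges from the root down to the last level.
    Every level above the last is full, so only the final step can reach None.
    """
    lo, hi = 0, (1 << depth) - 1
    node = root
    for _ in range(depth):
        mid = (lo + hi) // 2
        if index <= mid:
            node = node.left
            hi = mid
        else:
            node = node.right
            lo = mid + 1
    return node is not None


def count_nodes_bsearch(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    depth, node = 0, root
    while node.left is not None:  # the leftmost path reaches the last level
        depth += 1
        node = node.left
    # Slot 0 always exists (it ends the leftmost path); find the last occupied slot.
    lo, hi = 0, (1 << depth) - 1
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so the loop always makes progress
        if node_exists(root, mid, depth):
            lo = mid
        else:
            hi = mid - 1
    # 2^depth - 1 nodes in the full levels, plus lo + 1 nodes on the last level.
    return (1 << depth) - 1 + lo + 1


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
print(count_nodes_bsearch(root))  # 6
```

The search runs O(log n) probes, and each probe is one O(log n) walk, which gives the O(log^2 n) bound stated above.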
Testing
```python
def build_tree(values):
    # Level-order (heap) indexing: the children of position i are at 2i + 1
    # and 2i + 2. This is sufficient for complete trees, which have no gaps.
    if not values:
        return None

    nodes = [None if value is None else TreeNode(value) for value in values]

    for i, node in enumerate(nodes):
        if node is None:
            continue

        left = 2 * i + 1
        right = 2 * i + 2

        if left < len(nodes):
            node.left = nodes[left]

        if right < len(nodes):
            node.right = nodes[right]

    return nodes[0]


def run_tests():
    s = Solution()

    root = build_tree([1, 2, 3, 4, 5, 6])
    assert s.countNodes(root) == 6

    root = build_tree([])
    assert s.countNodes(root) == 0

    root = build_tree([1])
    assert s.countNodes(root) == 1

    root = build_tree([1, 2, 3, 4, 5, 6, 7])
    assert s.countNodes(root) == 7

    root = build_tree([1, 2, 3, 4])
    assert s.countNodes(root) == 4

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [1,2,3,4,5,6] | Standard complete tree |
| [] | Empty tree |
| [1] | Single node |
| [1,2,3,4,5,6,7] | Perfect tree with 7 nodes, exercises the formula path |
| [1,2,3,4] | Last level partially filled |