A clear explanation of grouping binary tree nodes by the round in which they become leaves using postorder DFS.
Problem Restatement
We are given the root of a binary tree.
We need to collect the tree’s nodes as if we repeatedly do this:
- Collect all current leaf nodes.
- Remove those leaf nodes.
- Repeat until the tree is empty.
Return a list of lists.
Each inner list contains the node values removed in the same round.
The order inside each round does not matter.
The official example is:
```
root = [1, 2, 3, 4, 5]
```
Output:
```
[[4, 5, 3], [2], [1]]
```
Other orders within the first group, such as `[[3, 4, 5], [2], [1]]`, are also accepted. The constraints say the tree has between 1 and 100 nodes, and node values are between -100 and 100.
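The `[1, 2, 3, 4, 5]` notation lists the tree in level order. The sketch below turns such a list into `TreeNode` objects for local experiments. `build_tree` is a hypothetical helper and not part of the required interface. It assumes the common convention where `None` marks a missing child and a missing child does not reserve slots for children of its own.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: list) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = no child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value, if present, is the left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that, if present, is the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.left.left.val, root.left.right.val)  # 4 5
```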
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | List of leaf-removal rounds |
| Leaf | A node with no children |
| Order inside one round | Does not matter |
| Main goal | Group nodes by when they become leaves |
Example function shape:
```python
def findLeaves(root: Optional[TreeNode]) -> list[list[int]]:
    ...
```
Examples
Example 1:
```
    1
   / \
  2   3
 / \
4   5
```
First round:
```
[4, 5, 3]
```
These are the original leaves.
After removing them, the tree becomes:
```
  1
 /
2
```
Second round:
```
[2]
```
After removing 2, the tree becomes:
```
1
```
Third round:
```
[1]
```
So the answer is:
```
[[4, 5, 3], [2], [1]]
```
Example 2:
```
root = [1]
```
The root is already a leaf.
Answer:
```
[[1]]
```
First Thought: Simulate Removal
A direct approach is to repeatedly scan the tree.
In each round:
- Find all current leaves.
- Add their values to the answer.
- Remove them from their parent.
- Repeat until the root is removed.
This matches the problem statement closely.
But it is inefficient because we may scan the same nodes many times.
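For a skewed tree with n nodes, each round removes only one node, so round k still has to scan the n - k + 1 nodes that remain. Repeated scanning then costs n + (n - 1) + ... + 1 = n(n + 1)/2 node visits, which is O(n^2).
We can do better with one DFS traversal.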
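For comparison, here is a minimal sketch of the direct simulation. `find_leaves_by_simulation` is a hypothetical name. It detaches each leaf by returning `None` to its parent, so it modifies the tree it is given.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_simulation(root: Optional[TreeNode]) -> list[list[int]]:
    """Repeatedly strip the current leaves; O(n^2) in the worst case."""
    answer = []

    def strip(node: Optional[TreeNode], round_vals: list[int]) -> Optional[TreeNode]:
        # Return the subtree with its current leaves removed,
        # recording the removed values in round_vals.
        if node is None:
            return None
        if node.left is None and node.right is None:
            round_vals.append(node.val)
            return None  # detach this leaf from its parent
        node.left = strip(node.left, round_vals)
        node.right = strip(node.right, round_vals)
        return node

    while root is not None:
        round_vals = []
        root = strip(root, round_vals)  # one full pass = one round
        answer.append(round_vals)
    return answer
```

Each pass of the `while` loop is one round and walks every remaining node, which is where the quadratic cost on a skewed tree comes from. A node that becomes a leaf partway through a pass is not collected in that pass, because its leaf check runs before its children are stripped.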
Key Insight
The round in which a node is removed depends only on its height, measured from the bottom of the tree.
Define height like this:
| Node type | Height |
|---|---|
| Missing (`None`) child | -1 |
| Leaf node | 0 |
| Node whose tallest child is a leaf | 1 |
| Node whose tallest child has height 1 | 2 |
So:
```
height(node) = 1 + max(height(left), height(right))
```
Leaves have height 0 because both children have height -1.
Nodes with the same height are removed in the same round.
For the tree:
```
    1
   / \
  2   3
 / \
4   5
```
The heights are:
| Node | Height | Removal round |
|---|---|---|
| 4 | 0 | first |
| 5 | 0 | first |
| 3 | 0 | first |
| 2 | 1 | second |
| 1 | 2 | third |
So we only need to compute each node’s height and append its value to answer[height].
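The recurrence can be checked against the table with a direct sketch. This `height` function is for illustration only: calling it separately for every node recomputes subtrees, whereas the algorithm below obtains all heights in a single postorder pass.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    """Height measured from the bottom: None -> -1, leaf -> 0."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))


# Example tree: 1 has children 2 and 3; 2 has children 4 and 5.
n4, n5, n3 = TreeNode(4), TreeNode(5), TreeNode(3)
n2 = TreeNode(2, n4, n5)
n1 = TreeNode(1, n2, n3)
print([(n.val, height(n)) for n in (n4, n5, n3, n2, n1)])
# [(4, 0), (5, 0), (3, 0), (2, 1), (1, 2)]
```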
Algorithm
Use postorder DFS.
For each node:
- Recursively compute the height of the left subtree.
- Recursively compute the height of the right subtree.
- Compute this node's height: `height = 1 + max(left_height, right_height)`.
- Ensure `answer` has a list for this height.
- Append `node.val` to `answer[height]`.
- Return `height`.
Postorder traversal is necessary because we need children’s heights before computing the current node’s height.
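To see the ordering, the sketch below runs the same DFS on Example 1 and records each node at the moment its height is computed. `trace_postorder` is a hypothetical helper used only for illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_postorder(root: Optional[TreeNode]) -> list[tuple[int, int]]:
    """Return (value, height) pairs in the order the DFS finishes each node."""
    order: list[tuple[int, int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        left_height = dfs(node.left)    # children first...
        right_height = dfs(node.right)
        height = 1 + max(left_height, right_height)
        order.append((node.val, height))  # ...then the node itself
        return height

    dfs(root)
    return order


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(trace_postorder(root))
# [(4, 0), (5, 0), (2, 1), (3, 0), (1, 2)]
```

Node 2 is recorded only after 4 and 5, and node 1 only after both of its subtrees.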
Correctness
A leaf has no children. The DFS returns -1 for missing children, so a leaf receives height:
```
1 + max(-1, -1) = 0
```
Therefore all original leaves are placed into `answer[0]`, which is the first removal round.
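Number the rounds from 0, so round h corresponds to `answer[h]`. For any non-leaf node, it can become a leaf only after all of its children have been removed. Assume, as the induction hypothesis, that each child is removed in the round equal to its height. If the node's tallest child has height h, that child is removed in round h, and every other child, having height at most h, is removed no later. The node therefore becomes a leaf once round h finishes, so it is removed in round h + 1.
The DFS computes exactly:
```
1 + max(left_height, right_height)
```
which matches that removal round.
By induction on height, from the leaves upward, every node is placed into the list for the round in which it is removed.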
Therefore, the algorithm returns exactly the required groups.
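The argument can also be spot-checked empirically. The sketch below compares the height-based grouping with a direct round-by-round simulation on random trees. Because the order inside a round does not matter, it sorts each group before comparing. The helper names and the way random trees are generated are assumptions made for this check; passing it is supporting evidence, not a proof.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def by_height(root: Optional[TreeNode]) -> list[list[int]]:
    """Group node values by height from the bottom (the DFS approach)."""
    answer: list[list[int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        h = 1 + max(dfs(node.left), dfs(node.right))
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
        return h

    dfs(root)
    return answer


def by_simulation(root: Optional[TreeNode]) -> list[list[int]]:
    """Strip the current leaves round by round (mutates the tree)."""
    answer: list[list[int]] = []

    def strip(node, removed):
        if node is None:
            return None
        if node.left is None and node.right is None:
            removed.append(node.val)
            return None
        node.left = strip(node.left, removed)
        node.right = strip(node.right, removed)
        return node

    while root is not None:
        removed: list[int] = []
        root = strip(root, removed)
        answer.append(removed)
    return answer


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    """Random binary tree with n nodes and values in [-100, 100]."""
    if n == 0:
        return None
    left_size = rng.randint(0, n - 1)
    return TreeNode(
        rng.randint(-100, 100),
        random_tree(left_size, rng),
        random_tree(n - 1 - left_size, rng),
    )


def normalized(groups: list[list[int]]) -> list[list[int]]:
    # Order inside a round is irrelevant; order of rounds is not.
    return [sorted(g) for g in groups]


rng = random.Random(0)
for _ in range(500):
    tree = random_tree(rng.randint(1, 100), rng)
    expected = normalized(by_height(tree))  # run first: simulation mutates the tree
    assert normalized(by_simulation(tree)) == expected
print("DFS grouping matches simulation")
```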
Complexity
Let n be the number of nodes.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(n) | Output stores all node values |
| Recursion stack | O(h) | h is the tree height |
For a balanced tree, the recursion stack is O(log n).
For a skewed tree, it can be O(n).
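If recursion depth is a concern, for example on trees much deeper than this problem's 100-node limit allows, the same postorder computation can be driven by an explicit stack. This is a variant sketch, not part of the original solution. `find_leaves_iterative` is a hypothetical name, and the sketch stores heights in a dictionary keyed by node object, relying on the default identity-based hashing of `TreeNode`.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> list[list[int]]:
    answer: list[list[int]] = []
    heights: dict = {}  # node -> height from the bottom
    # Each entry is (node, children_done); children_done=True means both
    # subtrees have already been processed.
    stack = [(root, False)] if root is not None else []
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after its children. Push right first so the
            # left subtree is processed first, matching the recursive order.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        # A missing child (None) is never a key, so .get falls back to -1.
        h = 1 + max(heights.get(node.left, -1), heights.get(node.right, -1))
        heights[node] = h
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
    return answer
```

The explicit stack lives on the heap, so a skewed tree only costs O(n) memory rather than O(n) call frames. The extra `heights` dictionary is the price for not being able to return values up a call stack.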
Implementation
```python
# Deferred annotations let TreeNode be defined after this class (as in the tests).
from __future__ import annotations

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> list[list[int]]:
        answer = []  # answer[h] holds the values of all nodes with height h

        def dfs(node: Optional[TreeNode]) -> int:
            if node is None:
                return -1

            # Postorder: both children's heights are needed first.
            left_height = dfs(node.left)
            right_height = dfs(node.right)
            height = 1 + max(left_height, right_height)

            if height == len(answer):
                answer.append([])
            answer[height].append(node.val)

            return height

        dfs(root)
        return answer
```
Code Explanation
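The `answer` list stores groups by height:
```python
answer = []
```
The DFS returns the height from the bottom.
For a missing child, we return -1:
```python
if node is None:
    return -1
```
This makes a leaf height 0.
We compute child heights first:
```python
left_height = dfs(node.left)
right_height = dfs(node.right)
```
Then compute the current height:
```python
height = 1 + max(left_height, right_height)
```
If this is the first node at that height, create a new group:
```python
if height == len(answer):
    answer.append([])
```
Checking `==` is enough: a node of height h has a child of height h - 1 that was visited earlier and stored in `answer[h - 1]`, so `answer` already has at least h groups when the node is reached.
Then append the node value:
```python
answer[height].append(node.val)
```
Finally, return the height to the parent:
```python
return height
```
Testing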
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def run_tests():
    s = Solution()

    # Standard tree: [1, 2, 3, 4, 5]
    root = TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3),
    )
    assert s.findLeaves(root) == [[4, 5, 3], [2], [1]]

    # Single node
    root = TreeNode(1)
    assert s.findLeaves(root) == [[1]]

    # Left-skewed tree
    root = TreeNode(1, TreeNode(2, TreeNode(3)))
    assert s.findLeaves(root) == [[3], [2], [1]]

    # Mixed shape
    root = TreeNode(
        1,
        TreeNode(2),
        TreeNode(3, None, TreeNode(4)),
    )
    assert s.findLeaves(root) == [[2, 4], [3], [1]]

    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| Standard tree | Checks multiple leaves in first round |
| Single node | Root is already a leaf |
| Left-skewed tree | One node removed per round |
| Mixed shape | Checks different subtree heights |
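The assertions above compare exact lists, so they also pin down the order this particular DFS produces inside each round. Since the problem accepts any order within a round, a more tolerant check can sort each group first. `normalize` is a hypothetical test helper, and `find_leaves` below is a compact standalone copy of the DFS so the block runs on its own.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves(root: Optional[TreeNode]) -> list[list[int]]:
    answer: list[list[int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        height = 1 + max(dfs(node.left), dfs(node.right))
        if height == len(answer):
            answer.append([])
        answer[height].append(node.val)
        return height

    dfs(root)
    return answer


def normalize(groups: list[list[int]]) -> list[list[int]]:
    # Sort within each round only; the order of the rounds still matters.
    return [sorted(group) for group in groups]


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
# The alternative order from the problem statement is accepted too.
assert normalize(find_leaves(root)) == normalize([[3, 4, 5], [2], [1]])
print("order-insensitive check passed")
```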