
LeetCode 366: Find Leaves of Binary Tree

A clear explanation of grouping binary tree nodes by the round in which they become leaves using postorder DFS.

Problem Restatement

We are given the root of a binary tree.

We need to collect the tree’s nodes as if we repeatedly do this:

  1. Collect all current leaf nodes.
  2. Remove those leaf nodes.
  3. Repeat until the tree is empty.

Return a list of lists.

Each inner list contains the node values removed in the same round.

The order inside each round does not matter.

The official example is:

root = [1, 2, 3, 4, 5]

Output:

[[4, 5, 3], [2], [1]]

Other orders within the first group, such as [[3, 4, 5], [2], [1]], are also accepted. The constraints say the tree has between 1 and 100 nodes, and node values are between -100 and 100.
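Because only the grouping is checked, a local test can sort each round before comparing. This is a minimal sketch. `same_rounds` is a hypothetical helper, not something LeetCode provides:

```python
def same_rounds(got: list[list[int]], expected: list[list[int]]) -> bool:
    """Return True if both answers have the same rounds, ignoring order inside each round."""
    if len(got) != len(expected):
        return False
    # Sorting each round removes differences in within-round order,
    # while the round index (list position) still has to match.
    return all(sorted(a) == sorted(b) for a, b in zip(got, expected))


print(same_rounds([[4, 5, 3], [2], [1]], [[3, 4, 5], [2], [1]]))  # True
print(same_rounds([[4, 5], [3, 2], [1]], [[3, 4, 5], [2], [1]]))  # False
```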

Input and Output

Item | Meaning
--- | ---
Input | Root of a binary tree
Output | List of leaf-removal rounds
Leaf | A node with no children
Order inside one round | Does not matter
Main goal | Group nodes by when they become leaves

Example function shape:

def findLeaves(root: Optional[TreeNode]) -> list[list[int]]:
    ...

Examples

Example 1:

      1
     / \
    2   3
   / \
  4   5

First round:

[4, 5, 3]

These are the original leaves.

After removing them, the tree becomes:

  1
 /
2

Second round:

[2]

After removing 2, the tree becomes:

1

Third round:

[1]

So the answer is:

[[4, 5, 3], [2], [1]]

Example 2:

root = [1]

The root is already a leaf.

Answer:

[[1]]
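The examples write trees in LeetCode's level-order form, where None marks a missing child. For local experiments, a small converter can be sketched as follows. `build_tree` is a hypothetical helper, and it assumes that serialization:

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: list[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from level-order values, where None means a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.left.left.val, root.left.right.val, root.right.val)  # 4 5 3
```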

First Thought: Simulate Removal

A direct approach is to repeatedly scan the tree.

In each round:

  1. Find all current leaves.
  2. Add their values to the answer.
  3. Remove them from their parent.
  4. Repeat until the root is removed.

This matches the problem statement closely.
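The simulation can be sketched as below, assuming the same TreeNode shape used in the Implementation section. `prune_leaves` and `find_leaves_by_simulation` are hypothetical names. Note that the sketch detaches nodes from the input tree as it goes:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_leaves(node: Optional[TreeNode], removed: list[int]) -> Optional[TreeNode]:
    """Remove the current leaves of this subtree and return its new root."""
    if node is None:
        return None
    if node.left is None and node.right is None:
        removed.append(node.val)
        return None  # the parent drops its link to this leaf
    # Leaf status is checked before recursing, so a node whose children
    # disappear during this pass is only removed in the next round.
    node.left = prune_leaves(node.left, removed)
    node.right = prune_leaves(node.right, removed)
    return node


def find_leaves_by_simulation(root: Optional[TreeNode]) -> list[list[int]]:
    rounds = []
    while root is not None:  # one full pass over the remaining tree per round
        removed: list[int] = []
        root = prune_leaves(root, removed)
        rounds.append(removed)
    return rounds


tree = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(find_leaves_by_simulation(tree))  # [[4, 5, 3], [2], [1]]
```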

But it is inefficient because we may scan the same nodes many times.

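For a skewed tree with n nodes, each round removes only one node. That means there are n rounds, and each round rescans every remaining node. The total work is n + (n - 1) + ... + 1 = n(n + 1)/2, so repeated scanning can cost: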

O(n^2)

We can do better with one DFS traversal.

Key Insight

The round in which a node is removed depends only on its height, measured from the bottom of the tree.

Define height like this:

Node type | Height
--- | ---
Missing (None) child | -1
Leaf node | 0
Node whose tallest child is a leaf | 1
Node whose tallest child has height 1 | 2

So:

height(node) = 1 + max(height(left), height(right))

Leaves have height 0 because both children have height -1.
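The recurrence translates directly into a recursive function. This is a minimal sketch, and `height` is an illustrative name. Calling it separately for every node would redo work on shared subtrees, which is why the algorithm below folds the grouping into a single pass:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    """Height from the bottom: -1 for a missing child, 0 for a leaf."""
    if node is None:
        return -1
    return 1 + max(height(node.left), height(node.right))


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(height(root.left.left), height(root.left), height(root))  # 0 1 2
```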

Nodes with the same height are removed in the same round.

For the tree:

      1
     / \
    2   3
   / \
  4   5

The heights are:

Node | Height | Removal round
--- | --- | ---
4 | 0 | first
5 | 0 | first
3 | 0 | first
2 | 1 | second
1 | 2 | third

So we only need to compute each node’s height and append its value to answer[height].

Algorithm

Use postorder DFS.

For each node:

  1. Recursively compute the height of the left subtree.
  2. Recursively compute the height of the right subtree.
  3. Compute this node’s height:
     height = 1 + max(left_height, right_height)
  4. Ensure answer has a list for this height.
  5. Append node.val to answer[height].
  6. Return height.

Postorder traversal is necessary because we need children’s heights before computing the current node’s height.
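To make the postorder dependency concrete, the sketch below records the order in which the DFS finishes nodes on the example tree, together with each node's height. `postorder_heights` is a hypothetical name used only for illustration:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def postorder_heights(root: Optional[TreeNode]) -> list[tuple[int, int]]:
    """Return (value, height) pairs in the order the DFS finishes each node."""
    finished: list[tuple[int, int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        left_height = dfs(node.left)    # step 1: left subtree
        right_height = dfs(node.right)  # step 2: right subtree
        height = 1 + max(left_height, right_height)  # step 3 needs both results
        finished.append((node.val, height))  # recorded only after both children
        return height

    dfs(root)
    return finished


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(postorder_heights(root))  # [(4, 0), (5, 0), (2, 1), (3, 0), (1, 2)]
```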

Correctness

A leaf has no children. The DFS returns -1 for missing children, so a leaf receives height:

1 + max(-1, -1) = 0

Therefore all original leaves are placed into answer[0], which is the first removal round.

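A non-leaf node can become a leaf only after all of its children have been removed. Count rounds from 0, matching the indices of answer. Suppose the node's tallest child has height h. By induction, that child is removed in round h. Any other child has height at most h, so it is removed no later. The current node is therefore a leaf after round h and is removed in round h + 1.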

The DFS computes exactly:

1 + max(left_height, right_height)

which matches that removal round.

By induction from leaves upward, every node is placed into the list corresponding to the round in which it becomes a leaf.

Therefore, the algorithm returns exactly the required groups.

Complexity

Let n be the number of nodes.

Metric | Value | Why
--- | --- | ---
Time | O(n) | Each node is visited once
Space | O(n) | Output stores all node values
Recursion stack | O(h) | h is the tree height

For a balanced tree, the recursion stack is O(log n).

For a skewed tree, it can be O(n).
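Because recursion depth on a skewed tree grows with n, the same postorder can also be driven by an explicit stack, which sidesteps Python's recursion limit. This is a sketch of that variant, not part of the original solution, and `find_leaves_iterative` is a hypothetical name. Within each round it produces the same order as the recursive DFS because it also finishes left children before right ones:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> list[list[int]]:
    """Group nodes by height with an explicit-stack postorder instead of recursion."""
    answer: list[list[int]] = []
    height: dict[TreeNode, int] = {}  # plain objects hash by identity
    stack: list[tuple[TreeNode, bool]] = [(root, False)] if root is not None else []

    def child_height(child: Optional[TreeNode]) -> int:
        return -1 if child is None else height[child]

    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Come back to this node once both children are finished.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))  # popped first: left before right
            continue
        h = 1 + max(child_height(node.left), child_height(node.right))
        height[node] = h
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
    return answer


# A 5,000-node left-skewed chain: the recursive DFS would exceed CPython's
# default recursion limit (1000), while the explicit stack handles it.
deep = None
for v in range(5000):
    deep = TreeNode(v, deep)
rounds = find_leaves_iterative(deep)
print(len(rounds), rounds[0], rounds[-1])  # 5000 [0] [4999]
```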

Implementation

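from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    # "TreeNode" is quoted so the annotation still works when TreeNode
    # is defined after this class, as in the Testing section.
    def findLeaves(self, root: Optional["TreeNode"]) -> list[list[int]]:
        answer = []  # answer[h] holds the values of all nodes with height h

        def dfs(node: Optional["TreeNode"]) -> int:
            # Returns the node's height: -1 for a missing child, 0 for a leaf.
            if node is None:
                return -1

            left_height = dfs(node.left)
            right_height = dfs(node.right)

            height = 1 + max(left_height, right_height)

            # Postorder guarantees height <= len(answer), so at most one new group is needed.
            if height == len(answer):
                answer.append([])

            answer[height].append(node.val)

            return height

        dfs(root)
        return answer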

Code Explanation

The answer list stores groups by height:

answer = []

The DFS returns the height from the bottom.

For a missing child, we return -1:

if node is None:
    return -1

This makes a leaf height 0.

We compute child heights first:

left_height = dfs(node.left)
right_height = dfs(node.right)

Then compute the current height:

height = 1 + max(left_height, right_height)

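If this is the first node at that height, create a new group. Postorder guarantees that height never exceeds len(answer): a node of height h > 0 has a child of height h - 1 that was processed earlier, so groups 0 through h - 1 already exist: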

if height == len(answer):
    answer.append([])

Then append the node value:

answer[height].append(node.val)

Finally, return the height to the parent:

return height

Testing

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    root = TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3),
    )
    # The exact within-round order holds for this left-to-right DFS; other orders are also valid.
    assert s.findLeaves(root) == [[4, 5, 3], [2], [1]]

    root = TreeNode(1)
    assert s.findLeaves(root) == [[1]]

    root = TreeNode(1, TreeNode(2, TreeNode(3)))
    assert s.findLeaves(root) == [[3], [2], [1]]

    root = TreeNode(
        1,
        TreeNode(2),
        TreeNode(3, None, TreeNode(4)),
    )
    assert s.findLeaves(root) == [[2, 4], [3], [1]]

    print("all tests passed")

run_tests()

Test meaning:

Test | Why
--- | ---
Standard tree | Checks multiple leaves in the first round
Single node | Root is already a leaf
Left-skewed tree | One node removed per round
Mixed shape | Checks different subtree heights
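The hand-written cases above can be complemented by a randomized cross-check against the slower simulation. This is a sketch that assumes random trees of 1 to 100 nodes with values in [-100, 100], mirroring the stated constraints. The helper names `by_height`, `by_simulation`, `random_tree`, and `cross_check` are hypothetical, and the seed is arbitrary. Rounds are compared after sorting, since order inside a round does not matter:

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def by_height(root: Optional[TreeNode]) -> list[list[int]]:
    """Height-based grouping, the same idea as Solution.findLeaves."""
    answer: list[list[int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        h = 1 + max(dfs(node.left), dfs(node.right))
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
        return h

    dfs(root)
    return answer


def by_simulation(root: Optional[TreeNode]) -> list[list[int]]:
    """Literal round-by-round leaf removal (mutates the tree)."""
    def prune(node: Optional[TreeNode], removed: list[int]) -> Optional[TreeNode]:
        if node is None:
            return None
        if node.left is None and node.right is None:
            removed.append(node.val)
            return None
        node.left = prune(node.left, removed)
        node.right = prune(node.right, removed)
        return node

    rounds: list[list[int]] = []
    while root is not None:
        removed: list[int] = []
        root = prune(root, removed)
        rounds.append(removed)
    return rounds


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    """Random shape with n nodes and values in [-100, 100]."""
    if n == 0:
        return None
    left_size = rng.randint(0, n - 1)  # split the remaining n - 1 nodes randomly
    return TreeNode(
        rng.randint(-100, 100),
        random_tree(left_size, rng),
        random_tree(n - 1 - left_size, rng),
    )


def cross_check(trials: int, seed: int) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng.randint(1, 100), rng)
        expected = by_height(root)  # run before the simulation mutates root
        actual = by_simulation(root)
        if [sorted(r) for r in expected] != [sorted(r) for r in actual]:
            return False
    return True


print(cross_check(trials=500, seed=366))  # True
```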