# LeetCode 366: Find Leaves of Binary Tree

## Problem Restatement

We are given the root of a binary tree.

We need to collect the tree's nodes as if we repeatedly do this:

1. Collect all current leaf nodes.
2. Remove those leaf nodes.
3. Repeat until the tree is empty.

Return a list of lists.

Each inner list contains the node values removed in the same round.

The order inside each round does not matter.

The official example, with the tree given in LeetCode's level-order serialization, is:

```python
root = [1, 2, 3, 4, 5]
```

Output:

```python
[[4, 5, 3], [2], [1]]
```

Other orders within the first group, such as `[[3, 4, 5], [2], [1]]`, are also accepted. The constraints say the tree has between `1` and `100` nodes, and node values are between `-100` and `100`.
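LeetCode writes trees as level-order lists in which `None` marks a missing child. To try examples locally, a helper along these lines can rebuild the tree. `build_tree` is a hypothetical name, not part of the problem's API, and this sketch assumes the usual LeetCode convention that `None` entries get no children of their own.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Rebuild a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.left.left.val, root.left.right.val, root.right.val)  # 4 5 3
```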

## Input and Output

| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | List of leaf-removal rounds |
| Leaf | A node with no children |
| Order inside one round | Does not matter |
| Main goal | Group nodes by when they become leaves |

Example function shape:

```python
def findLeaves(root: Optional[TreeNode]) -> list[list[int]]:
    ...
```

## Examples

Example 1:

```text
      1
     / \
    2   3
   / \
  4   5
```

First round:

```python
[4, 5, 3]
```

These are the original leaves.

After removing them, the tree becomes:

```text
  1
 /
2
```

Second round:

```python
[2]
```

After removing `2`, the tree becomes:

```text
1
```

Third round:

```python
[1]
```

So the answer is:

```python
[[4, 5, 3], [2], [1]]
```

Example 2:

```python
root = [1]
```

The root is already a leaf.

Answer:

```python
[[1]]
```

## First Thought: Simulate Removal

A direct approach is to repeatedly scan the tree.

In each round:

1. Find all current leaves.
2. Add their values to the answer.
3. Remove them from their parent.
4. Repeat until the root is removed.

This matches the problem statement closely.

But it is inefficient, because every round rescans the whole remaining tree.

For a skewed tree with `n` nodes, each round removes only one node, so the rounds scan `n`, `n - 1`, ..., `1` nodes. That is `n(n + 1)/2` visits in total, which is `O(n^2)`.
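For comparison, here is a minimal sketch of that simulation. `find_leaves_by_simulation` is a hypothetical name; it mutates the tree as it detaches leaves, and it also counts how many nodes the rounds touch so the quadratic cost on a skewed tree is visible.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_by_simulation(root: Optional[TreeNode]) -> Tuple[List[List[int]], int]:
    """Peel off leaves round by round; also count node visits (mutates the tree)."""
    rounds: List[List[int]] = []
    visits = 0

    def strip(node: Optional[TreeNode], out: List[int]) -> Optional[TreeNode]:
        # Returns this subtree with its current leaves detached.
        nonlocal visits
        if node is None:
            return None
        visits += 1
        if node.left is None and node.right is None:
            out.append(node.val)
            return None  # the parent's pointer to this leaf becomes None
        node.left = strip(node.left, out)
        node.right = strip(node.right, out)
        return node

    while root is not None:
        out: List[int] = []
        root = strip(root, out)
        rounds.append(out)
    return rounds, visits


# A left-skewed chain with root 100 and leaf 1 loses one node per round.
chain = None
for v in range(1, 101):
    chain = TreeNode(v, chain)
rounds, visits = find_leaves_by_simulation(chain)
print(len(rounds), visits)  # 100 5050
```

The `5050` visits for `100` nodes is exactly `n(n + 1)/2`.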

We can do better with one DFS traversal.

## Key Insight

A node's removal round depends only on its height, measured as the number of edges on the longest path down from the node to a leaf.

Define height like this:

| Node type | Height |
|---|---:|
| `None` child | `-1` |
| Leaf node | `0` |
| Node whose tallest child is a leaf | `1` |
| Node whose tallest child has height `1` | `2` |

So:

```text
height(node) = 1 + max(height(left), height(right))
```

Leaves have height `0` because both children have height `-1`.

Nodes with the same height are removed in the same round: height `k` means removal in round `k + 1`, which is `answer[k]` with 0-based indexing.

For the tree:

```text
      1
     / \
    2   3
   / \
  4   5
```

The heights are:

| Node | Height | Removal round |
|---:|---:|---:|
| `4` | `0` | first |
| `5` | `0` | first |
| `3` | `0` | first |
| `2` | `1` | second |
| `1` | `2` | third |

So we only need to compute each node's height and append its value to `answer[height]`.
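The heights in the table can be reproduced with a small helper. `node_heights` is hypothetical, and keying the result by value only works because this example's values are distinct (the problem itself allows duplicate values).

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def node_heights(root: Optional[TreeNode]) -> Dict[int, int]:
    """Map each node's value to its height (assumes distinct values)."""
    heights: Dict[int, int] = {}

    def height(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1  # missing child
        h = 1 + max(height(node.left), height(node.right))
        heights[node.val] = h
        return h

    height(root)
    return heights


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(node_heights(root))  # {4: 0, 5: 0, 2: 1, 3: 0, 1: 2}
```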

## Algorithm

Use postorder DFS.

For each node:

1. Recursively compute the height of the left subtree.
2. Recursively compute the height of the right subtree.
3. Compute this node's height:

```python
height = 1 + max(left_height, right_height)
```

4. Ensure `answer` has a list for this height.
5. Append `node.val` to `answer[height]`.
6. Return `height`.

Postorder traversal is necessary because we need children's heights before computing the current node's height.

## Correctness

A leaf has no children. The DFS returns `-1` for missing children, so a leaf receives height:

```text
1 + max(-1, -1) = 0
```

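Therefore all original leaves are placed into `answer[0]`, which is the first removal round. From here on, number the rounds from `0`, so round `k` corresponds to `answer[k]`.

Now take a non-leaf node and assume, as the induction hypothesis, that each of its children is removed in the round equal to its height. The node becomes a leaf only once all of its children are gone, and since every round removes all current leaves at once, it is removed exactly one round after its last child. If its tallest child has height `h`, that child is removed last, in round `h`, so the node is removed in round `h + 1`. A missing child contributes `-1`, which never changes the maximum for a non-leaf node.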

The DFS computes exactly:

```python
1 + max(left_height, right_height)
```

which matches that removal round.

By induction from leaves upward, every node is placed into the list corresponding to the round in which it becomes a leaf.

Therefore, the algorithm returns exactly the required groups.
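The induction argument can also be spot-checked, though a test is not a proof. This sketch compares the height-grouping DFS against a literal leaf-peeling simulation on random trees sized within the stated constraints (`1` to `100` nodes, values in `-100..100`). The helpers `by_height`, `by_peeling`, and `random_tree` are hypothetical names used only for this check.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def by_height(root: Optional[TreeNode]) -> List[List[int]]:
    """Group values by height with one postorder DFS."""
    answer: List[List[int]] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        h = 1 + max(dfs(node.left), dfs(node.right))
        if h == len(answer):
            answer.append([])
        answer[h].append(node.val)
        return h

    dfs(root)
    return answer


def by_peeling(root: Optional[TreeNode]) -> List[List[int]]:
    """Literal simulation: detach all current leaves each round (mutates the tree)."""
    def strip(node: Optional[TreeNode], out: List[int]) -> Optional[TreeNode]:
        if node is None:
            return None
        if node.left is None and node.right is None:
            out.append(node.val)
            return None
        node.left = strip(node.left, out)
        node.right = strip(node.right, out)
        return node

    rounds: List[List[int]] = []
    while root is not None:
        out: List[int] = []
        root = strip(root, out)
        rounds.append(out)
    return rounds


def random_tree(n: int, rng: random.Random) -> TreeNode:
    """Grow a random tree by attaching each new node to a random free child slot."""
    root = TreeNode(rng.randint(-100, 100))
    nodes = [root]
    for _ in range(n - 1):
        parent = rng.choice(nodes)
        while parent.left is not None and parent.right is not None:
            parent = rng.choice(nodes)
        child = TreeNode(rng.randint(-100, 100))
        if parent.left is None and (parent.right is not None or rng.random() < 0.5):
            parent.left = child
        else:
            parent.right = child
        nodes.append(child)
    return root


rng = random.Random(366)
for _ in range(200):
    root = random_tree(rng.randint(1, 100), rng)
    expected = by_height(root)  # run first: by_peeling dismantles the tree
    actual = by_peeling(root)
    # Order inside a round does not matter, so compare each round sorted.
    assert [sorted(g) for g in expected] == [sorted(g) for g in actual]
print("height grouping matched the simulation on 200 random trees")
```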

## Complexity

Let `n` be the number of nodes.

| Metric | Value | Why |
|---|---:|---|
| Time | `O(n)` | Each node is visited once |
| Space | `O(n)` | Output stores all node values |
| Recursion stack | `O(h)` | `h` is the tree height |

For a balanced tree, the recursion stack is `O(log n)`.

For a skewed tree, it can be `O(n)`.
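The constraints cap the tree at `100` nodes, so recursion depth is not a problem here. For much deeper trees, Python's default recursion limit (usually `1000`) would be, and the same postorder can be driven by an explicit stack. This is a minimal sketch, assuming heights are stored in a dictionary keyed by `id(node)` because node values may repeat; `find_leaves_iterative` is a hypothetical name.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_iterative(root: Optional[TreeNode]) -> List[List[int]]:
    answer: List[List[int]] = []
    height: Dict[int, int] = {}  # id(node) -> height of that node
    # Each entry is (node, children_done); a node is pushed once per phase.
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if not children_done:
            # Revisit this node after both subtrees; push left last so it runs first.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
        else:
            left_h = height[id(node.left)] if node.left is not None else -1
            right_h = height[id(node.right)] if node.right is not None else -1
            h = 1 + max(left_h, right_h)
            height[id(node)] = h
            if h == len(answer):
                answer.append([])
            answer[h].append(node.val)

    return answer


chain = None
for v in range(3000):
    chain = TreeNode(v, chain)  # left-skewed: 2999 -> 2998 -> ... -> 0
print(len(find_leaves_iterative(chain)))  # 3000
```

It finishes nodes in the same left-then-right postorder as the recursive version, so it produces the same groups in the same internal order.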

## Implementation

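```python
from typing import Optional

# Definition for a binary tree node (LeetCode provides this; define it
# before Solution when running locally).
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findLeaves(self, root: Optional[TreeNode]) -> list[list[int]]:
        # answer[k] holds the values of all nodes with height k.
        answer = []

        def dfs(node: Optional[TreeNode]) -> int:
            # A missing child has height -1, so a leaf gets height 0.
            if node is None:
                return -1

            # Postorder: both children's heights are needed first.
            left_height = dfs(node.left)
            right_height = dfs(node.right)

            height = 1 + max(left_height, right_height)

            # Heights never skip a level in postorder, so one append suffices.
            if height == len(answer):
                answer.append([])

            answer[height].append(node.val)

            return height

        dfs(root)
        return answer
```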

## Code Explanation

The answer list stores groups by height:

```python
answer = []
```

The DFS returns the height from the bottom.

For a missing child, we return `-1`:

```python
if node is None:
    return -1
```

This gives every leaf height `0`, since both of its children return `-1`.

We compute child heights first:

```python
left_height = dfs(node.left)
right_height = dfs(node.right)
```

Then compute the current height:

```python
height = 1 + max(left_height, right_height)
```

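If this is the first node at that height, create a new group. Checking `height == len(answer)` is enough: in postorder, a node of height `k > 0` has a child of height `k - 1` that was already processed, so groups `0` through `k - 1` already exist and `height` can never exceed `len(answer)`: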

```python
if height == len(answer):
    answer.append([])
```

Then append the node value:

```python
answer[height].append(node.val)
```

Finally, return the height to the parent:

```python
return height
```
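The single `height == len(answer)` check relies on groups being created in increasing height order. An instrumented copy of the DFS can confirm this on the example tree; `find_leaves_traced` and its `trace` list are hypothetical additions for inspection only.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_leaves_traced(root: Optional[TreeNode]) -> Tuple[List[List[int]], List[Tuple[int, int]]]:
    answer: List[List[int]] = []
    trace: List[Tuple[int, int]] = []  # (value, height) in the order nodes finish

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return -1
        height = 1 + max(dfs(node.left), dfs(node.right))
        # Groups 0..height-1 must already exist; fail loudly if a height were skipped.
        assert height <= len(answer), "a height was skipped"
        if height == len(answer):
            answer.append([])
        answer[height].append(node.val)
        trace.append((node.val, height))
        return height

    dfs(root)
    return answer, trace


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
answer, trace = find_leaves_traced(root)
print(trace)   # [(4, 0), (5, 0), (2, 1), (3, 0), (1, 2)]
print(answer)  # [[4, 5, 3], [2], [1]]
```

The trace shows node `3` (height `0`) finishing after node `2` (height `1`), so heights do not arrive in sorted order. Only the first appearance of each height is ordered, and that is all the append check needs.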

## Testing

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    root = TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3),
    )
    assert s.findLeaves(root) == [[4, 5, 3], [2], [1]]

    root = TreeNode(1)
    assert s.findLeaves(root) == [[1]]

    root = TreeNode(1, TreeNode(2, TreeNode(3)))
    assert s.findLeaves(root) == [[3], [2], [1]]

    root = TreeNode(
        1,
        TreeNode(2),
        TreeNode(3, None, TreeNode(4)),
    )
    assert s.findLeaves(root) == [[2, 4], [3], [1]]

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Why |
|---|---|
| Standard tree | Checks multiple leaves in first round |
| Single node | Root is already a leaf |
| Left-skewed tree | One node removed per round |
| Mixed shape | Checks different subtree heights |
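The first assertion above pins one specific order inside the first round, the one postorder DFS happens to produce. Since any order within a round is accepted, a checker that sorts each round before comparing is more forgiving of alternative implementations. `same_rounds` is a hypothetical helper, not part of the original tests.

```python
from typing import List


def same_rounds(actual: List[List[int]], expected: List[List[int]]) -> bool:
    """Compare round by round, ignoring order inside each round."""
    if len(actual) != len(expected):
        return False
    return all(sorted(a) == sorted(e) for a, e in zip(actual, expected))


print(same_rounds([[4, 5, 3], [2], [1]], [[3, 4, 5], [2], [1]]))  # True
print(same_rounds([[4, 5], [3, 2], [1]], [[3, 4, 5], [2], [1]]))  # False
```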

