A clear explanation of Minimum Height Trees using leaf trimming to find the center of a tree.
Problem Restatement
We are given a tree with n nodes labeled from 0 to n - 1.
The tree is given as an undirected edge list edges, where:
edges[i] = [a, b] means node a and node b are connected.
We may choose any node as the root. The height of the rooted tree is the number of edges on the longest downward path from the root to a leaf.
Return all root labels that produce a tree with minimum possible height. The answer can be returned in any order. The input is guaranteed to be a tree, with edges.length == n - 1 and 1 <= n <= 2 * 10^4.
Input and Output
| Item | Meaning |
|---|---|
| Input | n and an undirected edge list edges |
| Output | All root labels that give minimum height |
| Nodes | Labeled from 0 to n - 1 |
| Graph type | Connected tree with no cycles |
| Height | Number of edges on the longest path from the root to any leaf |
Function shape:

    def findMinHeightTrees(n: int, edges: list[list[int]]) -> list[int]:
        ...

Examples
Example 1:

    n = 4
    edges = [[1, 0], [1, 2], [1, 3]]

The tree is centered at node 1.
If we root the tree at 1, all other nodes are one edge away.
Output:

    [1]

Example 2:

    n = 6
    edges = [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]

The two best roots are 3 and 4.
Output:

    [3, 4]

First Thought: Try Every Root
A direct solution is:
- Pick each node as root.
- Run BFS or DFS from that node.
- Measure the maximum distance to any other node.
- Keep the nodes with the smallest height.
This is easy to understand.
But it is too slow.
For each root, BFS costs O(n). Trying all n roots costs:

    O(n^2)

Since n can be 20000, that is up to about 4 * 10^8 node visits, so this approach can time out.
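For reference, the try-every-root idea can be sketched as follows. It is only meant as a baseline for small inputs, and the name `brute_force_roots` is illustrative rather than part of the problem statement.

```python
from collections import deque


def brute_force_roots(n: int, edges: list[list[int]]) -> list[int]:
    """Try every node as root and keep those with the smallest height. O(n^2)."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    def height_from(root: int) -> int:
        # BFS distances from root; the height is the largest distance found.
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        farthest = 0
        while queue:
            node = queue.popleft()
            farthest = max(farthest, dist[node])
            for neighbor in graph[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        return farthest

    heights = [height_from(root) for root in range(n)]
    best = min(heights)
    return [root for root in range(n) if heights[root] == best]


print(brute_force_roots(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))  # [3, 4]
```

A baseline like this is still useful later as a cross-check for a faster solution on small trees.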
Key Insight
The best root must be at the center of the tree.
Think of a long path:

    0 - 1 - 2 - 3 - 4

Rooting at an end gives height 4.
Rooting at the middle gives height 2.
So the minimum-height root is the middle node.
For a general tree, the same idea still works. The minimum-height roots are the center nodes of the tree.
A tree can have either:
| Number of Centers | Meaning |
|---|---|
| 1 | One exact center |
| 2 | Two middle nodes on the longest path |
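The link between centers and the longest path can be checked directly with a different technique from leaf trimming: find one longest path with two BFS passes, then take its middle node or nodes. The sketch below assumes the input is a valid tree, and the name `centers_via_longest_path` is hypothetical.

```python
from collections import deque


def centers_via_longest_path(n: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    def bfs(start: int) -> tuple[int, list[int]]:
        # Returns the last node dequeued (a farthest node) and the BFS parents.
        parent = [-1] * n
        seen = [False] * n
        seen[start] = True
        queue = deque([start])
        last = start
        while queue:
            last = queue.popleft()
            for neighbor in graph[last]:
                if not seen[neighbor]:
                    seen[neighbor] = True
                    parent[neighbor] = last
                    queue.append(neighbor)
        return last, parent

    end_a, _ = bfs(0)           # farthest node from an arbitrary start
    end_b, parent = bfs(end_a)  # farthest node from end_a: other end of a longest path

    # Walk the parent links from end_b back to end_a to recover the path.
    path = []
    node = end_b
    while node != -1:
        path.append(node)
        node = parent[node]

    mid = len(path) // 2
    if len(path) % 2 == 1:
        return [path[mid]]                      # odd node count: one center
    return sorted([path[mid - 1], path[mid]])   # even node count: two centers


print(centers_via_longest_path(5, [[0, 1], [1, 2], [2, 3], [3, 4]]))  # [2]
print(centers_via_longest_path(4, [[0, 1], [1, 2], [2, 3]]))          # [1, 2]
```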
Leaf Trimming Idea
Leaves are nodes with degree 1.
A leaf can never be the center of a tree with more than two nodes. It is always on the outside.
So we repeatedly remove all current leaves.
This peels the tree from the outside toward the center.
Example:

    0 - 1 - 2 - 3 - 4

Initial leaves:

    0, 4

Remove them:

    1 - 2 - 3

New leaves:

    1, 3

Remove them:

    2

The remaining node is the center.
So the answer is:

    [2]

For a path with an even number of nodes, two centers remain:

    0 - 1 - 2 - 3

Remove leaves 0 and 3.
Remaining:

    1 - 2

The answer is:

    [1, 2]

Algorithm
Handle the single-node case first:

    if n == 1:
        return [0]

Then:
- Build an adjacency list.
- Compute the degree of every node.
- Put all leaves into a queue.
- Keep a count of remaining nodes.
- While more than two nodes remain:
  - Remove the current layer of leaves.
  - Decrease the degree of their neighbors.
  - Any neighbor whose degree becomes 1 becomes a new leaf.
- Return the nodes left in the queue.
Those remaining nodes are the tree centers.
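To see the layers these steps produce, here is a small tracing sketch. It follows the same steps but keeps each layer in a plain list instead of a queue, and the name `peel_layers` is illustrative.

```python
def peel_layers(n: int, edges: list[list[int]]) -> list[list[int]]:
    """Return the leaf layers in removal order; the last layer holds the centers."""
    if n == 1:
        return [[0]]

    graph = [[] for _ in range(n)]
    degree = [0] * n
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1

    layer = [node for node in range(n) if degree[node] == 1]
    layers = []
    remaining = n
    while remaining > 2:
        layers.append(layer)
        remaining -= len(layer)
        next_layer = []
        for leaf in layer:
            for neighbor in graph[leaf]:
                degree[neighbor] -= 1
                if degree[neighbor] == 1:
                    next_layer.append(neighbor)
        layer = next_layer

    layers.append(layer)  # the one or two nodes that were never removed
    return layers


print(peel_layers(5, [[0, 1], [1, 2], [2, 3], [3, 4]]))  # [[0, 4], [1, 3], [2]]
print(peel_layers(4, [[0, 1], [1, 2], [2, 3]]))          # [[0, 3], [1, 2]]
```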
Correctness
Every removed node is a leaf at the time it is removed.
A leaf cannot be a best root while the tree has more than two nodes. Moving the root from a leaf to its only neighbor brings every other node one edge closer. The only node that gets farther away is the leaf itself, which is now at distance 1. A leaf root in a tree with more than two nodes has height at least 2, so the neighbor's height is exactly one less and the neighbor is strictly better.
Removing all leaves also preserves the center of the tree. With more than two nodes, the farthest node from any surviving node is always a current leaf, so peeling one layer reduces the height of every surviving root by exactly 1. The ranking of the surviving roots is unchanged, and every longest path loses one edge at each end without moving its middle.
The algorithm repeats this process until only one or two nodes remain.
A tree center is either one node or two adjacent nodes. These remaining nodes are exactly the middle of every longest path, so rooting the tree at any of them minimizes the maximum distance to a leaf.
Therefore, the returned nodes are exactly the roots of all minimum height trees.
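The layer-removal step of the argument can be checked numerically on small trees. The sketch below compares the height of every surviving root before and after one round of leaf removal. Both helper names are hypothetical, and the check uses one BFS per root, so it is only suitable for small inputs.

```python
from collections import deque


def all_heights(nodes: set[int], edges: list[list[int]]) -> dict[int, int]:
    """Height of the tree for every possible root, via one BFS per root."""
    graph = {v: [] for v in nodes}
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    result = {}
    for root in nodes:
        dist = {root: 0}
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for neighbor in graph[node]:
                if neighbor not in dist:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        result[root] = max(dist.values())
    return result


def peel_drops_height_by_one(n: int, edges: list[list[int]]) -> bool:
    """True if removing all leaves lowers every surviving root's height by 1."""
    nodes = set(range(n))
    degree = {v: 0 for v in nodes}
    for a, b in edges:
        degree[a] += 1
        degree[b] += 1
    survivors = {v for v in nodes if degree[v] > 1}
    inner_edges = [[a, b] for a, b in edges if a in survivors and b in survivors]
    before = all_heights(nodes, edges)
    after = all_heights(survivors, inner_edges)
    return all(after[v] == before[v] - 1 for v in survivors)


print(peel_drops_height_by_one(6, [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]]))  # True
```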
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is removed once and each edge is processed at most twice |
| Space | O(n) | Adjacency list, degree array, and queue |
Implementation

    from collections import deque


    class Solution:
        def findMinHeightTrees(
            self,
            n: int,
            edges: list[list[int]],
        ) -> list[int]:
            # A single node has degree 0 and would never enter the leaf queue.
            if n == 1:
                return [0]

            # Adjacency list plus a degree count per node.
            graph = [[] for _ in range(n)]
            degree = [0] * n
            for a, b in edges:
                graph[a].append(b)
                graph[b].append(a)
                degree[a] += 1
                degree[b] += 1

            # First layer: every node with exactly one neighbor.
            leaves = deque()
            for node in range(n):
                if degree[node] == 1:
                    leaves.append(node)

            remaining = n
            while remaining > 2:
                # Peel exactly the leaves that were queued before this round.
                layer_size = len(leaves)
                remaining -= layer_size
                for _ in range(layer_size):
                    leaf = leaves.popleft()
                    for neighbor in graph[leaf]:
                        degree[neighbor] -= 1
                        if degree[neighbor] == 1:
                            leaves.append(neighbor)

            # The one or two nodes still queued are the centers.
            return list(leaves)

Code Explanation
The single-node case needs special handling.

    if n == 1:
        return [0]

A single node has degree 0, so it would never enter the leaf queue. It has height 0, so node 0 is the only answer.
Build the graph:

    graph = [[] for _ in range(n)]
    degree = [0] * n

For each edge, add both directions because the tree is undirected.

    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1

Now find the first layer of leaves.

    for node in range(n):
        if degree[node] == 1:
            leaves.append(node)

A node with degree 1 is on the outside of the tree.
We track how many nodes remain.

    remaining = n

Then repeatedly remove complete leaf layers.

    while remaining > 2:

We process exactly the current leaves, not leaves created during the same loop.

    layer_size = len(leaves)
    remaining -= layer_size

For each removed leaf, reduce the degree of its neighbors.

    for neighbor in graph[leaf]:
        degree[neighbor] -= 1

If a neighbor’s degree becomes 1, it becomes a leaf for the next layer.

    if degree[neighbor] == 1:
        leaves.append(neighbor)

When at most two nodes remain, they are the center nodes.

    return list(leaves)

Testing

    def normalize(ans):
        return sorted(ans)


    def run_tests():
        s = Solution()

        assert normalize(s.findMinHeightTrees(
            4,
            [[1, 0], [1, 2], [1, 3]],
        )) == [1]

        assert normalize(s.findMinHeightTrees(
            6,
            [[3, 0], [3, 1], [3, 2], [3, 4], [5, 4]],
        )) == [3, 4]

        assert normalize(s.findMinHeightTrees(
            1,
            [],
        )) == [0]

        assert normalize(s.findMinHeightTrees(
            2,
            [[0, 1]],
        )) == [0, 1]

        assert normalize(s.findMinHeightTrees(
            5,
            [[0, 1], [1, 2], [2, 3], [3, 4]],
        )) == [2]

        assert normalize(s.findMinHeightTrees(
            4,
            [[0, 1], [1, 2], [2, 3]],
        )) == [1, 2]

        print("all tests passed")


    run_tests()

Test meaning:
| Test | Why |
|---|---|
| Star-shaped tree | One obvious center |
| Two-center example | Confirms two roots can be returned |
| Single node | Special case |
| Two nodes | Both nodes are valid roots |
| Chain with an odd number of nodes | One center |
| Chain with an even number of nodes | Two centers |
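Beyond the hand-picked cases, a randomized cross-check against the brute-force approach can catch mistakes on irregular trees. The sketch below is one possible way to do this. It carries a standalone copy of the trimming routine so that it runs on its own, and the names `trim_centers`, `brute_force`, `random_tree`, and `cross_check` are all illustrative.

```python
import random
from collections import deque


def trim_centers(n: int, edges: list[list[int]]) -> list[int]:
    # Standalone copy of the leaf-trimming routine, so this snippet runs on its own.
    if n == 1:
        return [0]
    graph = [[] for _ in range(n)]
    degree = [0] * n
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
        degree[a] += 1
        degree[b] += 1
    leaves = deque(node for node in range(n) if degree[node] == 1)
    remaining = n
    while remaining > 2:
        layer_size = len(leaves)
        remaining -= layer_size
        for _ in range(layer_size):
            leaf = leaves.popleft()
            for neighbor in graph[leaf]:
                degree[neighbor] -= 1
                if degree[neighbor] == 1:
                    leaves.append(neighbor)
    return list(leaves)


def brute_force(n: int, edges: list[list[int]]) -> list[int]:
    # O(n^2) reference: BFS from every root, keep the roots with minimum height.
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    heights = []
    for root in range(n):
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            node = queue.popleft()
            for neighbor in graph[node]:
                if dist[neighbor] == -1:
                    dist[neighbor] = dist[node] + 1
                    queue.append(neighbor)
        heights.append(max(dist))
    best = min(heights)
    return [root for root in range(n) if heights[root] == best]


def random_tree(n: int, rng: random.Random) -> list[list[int]]:
    # Attach each node i >= 1 to a random earlier node; this always yields a tree.
    return [[rng.randrange(i), i] for i in range(1, n)]


def cross_check(trials: int = 200, max_n: int = 30, seed: int = 0) -> int:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, max_n)
        edges = random_tree(n, rng)
        assert sorted(trim_centers(n, edges)) == brute_force(n, edges), (n, edges)
    return trials


cross_check()
```

Seeding the generator keeps any failing tree reproducible, and the assertion message reports the offending `n` and `edges`.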