A clear explanation of finding the smallest subtree that contains all deepest nodes using bottom-up DFS.
Problem Restatement
We are given the root of a binary tree.
The depth of a node is its shortest distance from the root.
A deepest node is a node with the largest depth in the whole tree.
We need to return the smallest subtree that contains all deepest nodes.
Since a subtree is represented by its root node, we return the root of that smallest subtree.
Another way to say this: return the lowest common ancestor of all deepest nodes.
Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root of a binary tree |
| Output | The root node of the smallest subtree containing all deepest nodes |
| Deepest node | A node with maximum depth in the whole tree |
| Subtree | A node and all of its descendants |
Function shape:
```python
class Solution:
    def subtreeWithAllDeepest(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        ...
```
Examples
Example 1:
```python
root = [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]
```
The tree is:
```
        3
      /   \
     5     1
    / \   / \
   6   2 0   8
      / \
     7   4
```
The deepest nodes are 7 and 4. Their lowest common ancestor is node 2.
So the returned subtree is rooted at node 2.
Example 2:
```python
root = [1]
```
The only node is also the deepest node.
So the answer is the root node 1.
Example 3:
```python
root = [0, 1, 3, None, 2]
```
The deepest node is 2.
If there is only one deepest node, the smallest subtree containing all deepest nodes is that node itself.
So the answer is node 2.
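The examples write trees as LeetCode-style level-order lists, where `None` marks a missing child and the children of missing nodes are not listed. To try the examples locally, a small helper can rebuild the tree from such a list. This is a minimal sketch under that serialization assumption; `build_tree` is a hypothetical name, not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: rebuild a tree from a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value, if present, is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that, if present, is the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(root.left.right.val)  # 2
```

Nodes are consumed from the queue in the same order their children appear in the list, which is why a plain FIFO queue is enough.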
First Thought: Find Deepest Nodes, Then Find LCA
One possible approach has two stages.
First, traverse the tree to find the maximum depth and collect all deepest nodes.
Second, find the lowest common ancestor of those deepest nodes.
This works, but it requires separate logic for collecting nodes and then finding their common ancestor.
We can do both jobs in one DFS.
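For comparison, here is a minimal sketch of the two-stage idea: one traversal records every node's depth, and a second recursive pass returns the lowest common ancestor of the nodes at maximum depth. The function name `subtree_two_pass` is hypothetical, and `TreeNode` is redefined so the snippet runs on its own.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_two_pass(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # Stage 1: record the depth of every node (the root has depth 0).
    depth: Dict[TreeNode, int] = {}

    def record(node, d):
        if node is None:
            return
        depth[node] = d
        record(node.left, d + 1)
        record(node.right, d + 1)

    record(root, 0)
    if not depth:
        return None
    max_depth = max(depth.values())

    # Stage 2: return the LCA of all nodes at max_depth inside this subtree,
    # or None if the subtree holds no deepest node.
    def lca(node):
        if node is None or depth[node] == max_depth:
            return node
        left = lca(node.left)
        right = lca(node.right)
        if left is not None and right is not None:
            return node  # deepest nodes on both sides meet here
        return left if left is not None else right

    return lca(root)
```

The depth dictionary costs O(n) extra space, which is part of what the one-pass DFS avoids by returning the depth alongside the candidate answer.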
Key Insight
For each node, we want to know two things about its subtree:
- How deep the deepest node is inside this subtree
- The root of the smallest subtree that contains all of this subtree's deepest nodes
This suggests a bottom-up DFS.
At each node:
- If the left subtree is deeper, then all deepest nodes are in the left subtree.
- If the right subtree is deeper, then all deepest nodes are in the right subtree.
- If both subtrees have the same depth, then deepest nodes exist on both sides, or both sides are empty. In that case, the current node is the smallest subtree root for this subtree.
So the recursive function should return a pair:
```python
(answer_node, depth)
```
Algorithm
Define a helper function:
```python
dfs(node)
```
It returns:
| Value | Meaning |
|---|---|
| answer_node | Root of the smallest subtree containing all deepest nodes in node's subtree |
| depth | Height of node's subtree, counted in nodes (0 for an empty subtree, 1 for a leaf) |
Base case:
```
dfs(None) = (None, 0)
```
For a real node:
- Compute results from the left subtree.
- Compute results from the right subtree.
- Compare their depths.
- If the left depth is larger, return the left answer.
- If the right depth is larger, return the right answer.
- If the depths are equal, return the current node.
- In every case, return 1 plus the larger of the two child depths as this node's depth.
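To see these rules in action, the following sketch instruments the same DFS on Example 1 and records, for each node in post-order, the tuple (node value, answer value, depth). `traced_dfs` and its `log` list are hypothetical additions for illustration only; the decision logic is unchanged.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def traced_dfs(root: Optional[TreeNode]):
    """Run the bottom-up DFS and log (node.val, answer.val, depth) in post-order."""
    log: List[Tuple[int, int, int]] = []

    def dfs(node):
        if node is None:
            return None, 0
        left_node, left_depth = dfs(node.left)
        right_node, right_depth = dfs(node.right)
        if left_depth > right_depth:
            result = (left_node, left_depth + 1)
        elif right_depth > left_depth:
            result = (right_node, right_depth + 1)
        else:
            result = (node, left_depth + 1)
        log.append((node.val, result[0].val, result[1]))
        return result

    answer, _ = dfs(root)
    return answer, log


# Example 1: [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]
root = TreeNode(3,
                TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), TreeNode(4))),
                TreeNode(1, TreeNode(0), TreeNode(8)))
answer, log = traced_dfs(root)
for row in log:
    print(row)
# (6, 6, 1)
# (7, 7, 1)
# (4, 4, 1)
# (2, 2, 2)   equal child depths: node 2 becomes the answer
# (5, 2, 3)   right side deeper: node 5 passes up node 2
# (0, 0, 1)
# (8, 8, 1)
# (1, 1, 2)
# (3, 2, 4)   left side deeper: the root passes up node 2
```

The answer "sticks" to node 2 once it forms, because every ancestor above it sees a strictly deeper side pointing back toward it.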
Correctness
For a None node, there are no deepest nodes, so returning (None, 0) is correct.
For a non-empty node, assume the DFS results for its left and right children are correct.
If the left subtree has greater depth than the right subtree, then every deepest node under the current node must be inside the left subtree. The smallest subtree containing those deepest nodes is exactly the left child’s returned answer.
If the right subtree has greater depth, the same reasoning applies symmetrically.
If the two depths are equal and positive, the current subtree has deepest nodes in both its left and right subtrees; if both depths are 0, the current node is a leaf and is itself the only deepest node. Either way, the smallest subtree containing all of them must include the current node, because no lower node has descendants on both sides. Therefore, the current node is the correct answer.
By induction, the helper returns the correct answer for every subtree. Calling it on the root returns the smallest subtree containing all deepest nodes in the entire tree.
Complexity
Let n be the number of nodes.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack uses tree height h |
In the worst case, h = n for a skewed tree.
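Because h can reach n on a skewed tree, a recursive Python solution may exceed CPython's default recursion limit of 1000 on very deep inputs. One possible workaround, sketched here as an assumption rather than the article's method, replaces recursion with an explicit post-order stack and stores each node's `(answer_node, depth)` pair in a dictionary. `subtree_with_all_deepest_iterative` is a hypothetical name; note the dictionary makes auxiliary space O(n) rather than O(h).

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_with_all_deepest_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Same bottom-up rule as dfs(), driven by an explicit stack instead of recursion."""
    if root is None:
        return None
    # For each finished node: (answer_node, height of its subtree).
    result: Dict[TreeNode, Tuple[TreeNode, int]] = {}
    # Entries are (node, children_done); False means "expand children first".
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node after both children are finished (post-order).
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue
        # A missing child contributes (None, 0), just like dfs(None).
        left_node, left_depth = result[node.left] if node.left is not None else (None, 0)
        right_node, right_depth = result[node.right] if node.right is not None else (None, 0)
        if left_depth > right_depth:
            result[node] = (left_node, left_depth + 1)
        elif right_depth > left_depth:
            result[node] = (right_node, right_depth + 1)
        else:
            result[node] = (node, left_depth + 1)
    return result[root][0]


# A left-leaning chain of 5000 nodes: recursing this deep would exceed the
# default limit, but the explicit stack handles it.
root = TreeNode(0)
node = root
for val in range(1, 5000):
    node.left = TreeNode(val)
    node = node.left
print(subtree_with_all_deepest_iterative(root).val)  # 4999
```

An alternative is raising the limit with `sys.setrecursionlimit`, but that only moves the ceiling, while the explicit stack removes it.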
Implementation
```python
from __future__ import annotations  # lets the hints name TreeNode before it is defined

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def subtreeWithAllDeepest(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        def dfs(node: Optional[TreeNode]) -> tuple[Optional[TreeNode], int]:
            # Returns (answer node for this subtree, height of this subtree).
            if node is None:
                return None, 0

            left_node, left_depth = dfs(node.left)
            right_node, right_depth = dfs(node.right)

            if left_depth > right_depth:
                # All deepest nodes are on the left.
                return left_node, left_depth + 1
            if right_depth > left_depth:
                # All deepest nodes are on the right.
                return right_node, right_depth + 1
            # Equal heights: deepest nodes on both sides, or this node is a leaf.
            return node, left_depth + 1

        answer, _ = dfs(root)
        return answer
```
Code Explanation
The helper returns both the candidate answer and the depth:
```python
def dfs(node):
```
For an empty subtree:
```python
if node is None:
    return None, 0
```
There is no answer node, and its depth contribution is 0.
Then we solve both children:
```python
left_node, left_depth = dfs(node.left)
right_node, right_depth = dfs(node.right)
```
If the left side is deeper:
```python
if left_depth > right_depth:
    return left_node, left_depth + 1
```
then all deepest nodes in this subtree live on the left side.
If the right side is deeper:
```python
if right_depth > left_depth:
    return right_node, right_depth + 1
```
then all deepest nodes live on the right side.
If both sides have equal depth:
```python
return node, left_depth + 1
```
the current node is the meeting point for all deepest nodes in this subtree.
Finally:
```python
answer, _ = dfs(root)
return answer
```
returns only the answer node.
Testing
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def test_subtree_with_all_deepest():
    s = Solution()

    # Example 1: deepest nodes 7 and 4 meet at node 2.
    root = TreeNode(3)
    root.left = TreeNode(5)
    root.right = TreeNode(1)
    root.left.left = TreeNode(6)
    root.left.right = TreeNode(2)
    root.right.left = TreeNode(0)
    root.right.right = TreeNode(8)
    root.left.right.left = TreeNode(7)
    root.left.right.right = TreeNode(4)
    assert s.subtreeWithAllDeepest(root).val == 2

    # Example 2: a single node is its own answer.
    single = TreeNode(1)
    assert s.subtreeWithAllDeepest(single).val == 1

    # Example 3: one deepest node is its own answer.
    root = TreeNode(0)
    root.left = TreeNode(1)
    root.right = TreeNode(3)
    root.left.right = TreeNode(2)
    assert s.subtreeWithAllDeepest(root).val == 2

    # Deepest nodes on both sides of the root.
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    assert s.subtreeWithAllDeepest(root).val == 1

    print("all tests passed")


test_subtree_with_all_deepest()
```
Test meaning:
| Test | Why |
|---|---|
| Standard tree with deepest nodes 7 and 4 | Answer is their lowest common ancestor |
| Single node | The root is the only deepest node |
| One deepest node | Answer is the deepest node itself |
| Deepest nodes on both sides | Answer is the root |
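The hand-written cases cover the main shapes. A randomized check can add confidence by comparing the DFS with a brute force that applies the definition directly: among all nodes whose subtree contains every deepest node, pick the one whose subtree is smallest. Everything here (`deepest_subtree_dfs`, `deepest_subtree_brute`, `random_tree`, the seed and tree sizes) is a hypothetical test harness, not part of the original solution.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def deepest_subtree_dfs(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """The bottom-up DFS from the Implementation section, as a plain function."""
    def dfs(node):
        if node is None:
            return None, 0
        left_node, left_depth = dfs(node.left)
        right_node, right_depth = dfs(node.right)
        if left_depth > right_depth:
            return left_node, left_depth + 1
        if right_depth > left_depth:
            return right_node, right_depth + 1
        return node, left_depth + 1
    return dfs(root)[0]


def deepest_subtree_brute(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Definition-based check: smallest subtree that contains every deepest node."""
    nodes = []  # (node, depth) pairs for the whole tree

    def walk(node, d):
        if node is None:
            return
        nodes.append((node, d))
        walk(node.left, d + 1)
        walk(node.right, d + 1)

    walk(root, 0)
    if not nodes:
        return None
    max_depth = max(d for _, d in nodes)
    deepest = {n for n, d in nodes if d == max_depth}

    def subtree_nodes(node):
        seen, stack = set(), [node]
        while stack:
            cur = stack.pop()
            if cur is not None:
                seen.add(cur)
                stack.extend((cur.left, cur.right))
        return seen

    best, best_size = None, 0
    for node, _ in nodes:
        members = subtree_nodes(node)
        if deepest <= members and (best is None or len(members) < best_size):
            best, best_size = node, len(members)
    return best


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    """Grow a random tree by attaching each new node to a random free child slot."""
    if size == 0:
        return None
    root = TreeNode(0)
    slots = [(root, "left"), (root, "right")]
    for val in range(1, size):
        parent, side = slots.pop(rng.randrange(len(slots)))
        child = TreeNode(val)
        setattr(parent, side, child)
        slots.extend(((child, "left"), (child, "right")))
    return root


rng = random.Random(0)
for _ in range(500):
    tree = random_tree(rng.randrange(0, 30), rng)
    assert deepest_subtree_dfs(tree) is deepest_subtree_brute(tree)
print("random cross-check passed")
```

The nodes whose subtree contains all deepest nodes form a chain from the root down to the answer, so the smallest one is unique and the two results can be compared by identity with `is`.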