A clear explanation of Maximum Depth of N-ary Tree using recursive depth-first search.
Problem Restatement
We are given the root of an N-ary tree.
An N-ary tree is a tree where each node may have any number of children.
We need to return the maximum depth of the tree.
The maximum depth is the number of nodes along the longest path from the root node down to the farthest leaf node. If the tree is empty, the depth is 0. The constraints say the total number of nodes is in the range [0, 10^4], and the tree depth is at most 1000.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root node of an N-ary tree |
| Output | Maximum depth as an integer |
| Empty tree | Return 0 |
| Leaf node | Depth is 1 |
Example function shape:

    def maxDepth(root: 'Node') -> int:
        ...

Examples
Example 1:

    root = [1, null, 3, 2, 4, null, 5, 6]

This represents a tree where the longest root-to-leaf path has three nodes.
One such path is:

    1 -> 3 -> 5

So the answer is:

    3

Example 2:

    root = []

There is no tree.
So the answer is:

    0

Example 3:

    root = [1]

There is only one node.
The longest path from the root to a leaf contains one node.
So the answer is:

    1

First Thought: Traverse Every Path
To find the maximum depth, we need to know how far the deepest leaf is from the root.
A tree path begins at the root and follows child pointers downward.
So a natural solution is to explore the tree and compute the depth of each subtree.
For a node, its maximum depth is:

    1 + maximum depth among its children

The 1 counts the current node.
If the node has no children, the maximum depth among children is 0, so the result is 1.
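The recurrence above can be written almost verbatim in Python. This is a minimal sketch rather than the final implementation; the `Node` class and the function name `max_depth` are illustrative assumptions.

```python
class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def max_depth(node):
    """Return the number of nodes on the longest root-to-leaf path."""
    if node is None:
        return 0
    # default=0 covers a leaf: max() over no children yields 0, so the result is 1.
    return 1 + max((max_depth(child) for child in node.children), default=0)


leaf = Node(5)
tree = Node(1, [Node(3, [leaf, Node(6)]), Node(2), Node(4)])
print(max_depth(leaf))  # 1
print(max_depth(tree))  # 3
```

The `default=0` argument to `max` plays the role of "the maximum depth among children is 0" for a node without children.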
Key Insight
The tree is recursive.
Each child of the root is itself the root of a smaller N-ary tree.
So the problem can be solved with depth-first search.
For each node:
- Ask every child for its maximum depth.
- Take the largest child depth.
- Add 1 for the current node.
The base case is an empty tree:

    root is None

Its depth is 0.
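To see each node reporting its depth to its parent, the sketch below records the value every subtree returns on the Example 1 tree. The helper `subtree_depths` and the `Node` class are hypothetical names used only for illustration, and node values are assumed to be unique so they can serve as dictionary keys.

```python
class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def subtree_depths(root):
    """Map each node's value to the depth of the subtree rooted at that node."""
    depths = {}

    def visit(node):
        if node is None:
            return 0  # base case: an empty tree has depth 0
        best_child = 0
        for child in node.children:
            # Ask every child for its depth and keep the largest answer.
            best_child = max(best_child, visit(child))
        depths[node.val] = 1 + best_child  # add 1 for the current node
        return depths[node.val]

    visit(root)
    return depths


root = Node(1, [Node(3, [Node(5), Node(6)]), Node(2), Node(4)])
print(subtree_depths(root))
```

The leaves 5, 6, 2 and 4 each report 1, node 3 reports 2, and the root reports 3.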
Algorithm
- If root is None, return 0.
- Initialize max_child_depth = 0.
- For each child in root.children:
  - Recursively compute the child depth.
  - Update max_child_depth.
- Return 1 + max_child_depth.
Correctness
If the tree is empty, there is no root-to-leaf path, so the maximum depth is 0. The algorithm returns 0 in this case.
Now consider a non-empty tree with root root.
Every path from root to a leaf must either:
- Stop at root, if root is a leaf.
- Continue into exactly one child subtree.
For each child, the recursive call returns the maximum depth of that child subtree. The deepest path from root must go through the child with the largest subtree depth.
So the maximum depth from root is:

    1 + max_child_depth

The 1 counts root itself.
The algorithm computes exactly this value. Therefore, it returns the maximum number of nodes along any path from the root to a farthest leaf.
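One way to check this argument empirically is to enumerate every root-to-leaf path and compare the longest one with the recursive result. The sketch below does that on the Example 1 tree; `root_to_leaf_paths`, `max_depth`, and the `Node` class are hypothetical names, and the brute-force enumeration is only a cross-check, not a suggested solution.

```python
class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def root_to_leaf_paths(root):
    """Yield every root-to-leaf path as a list of node values."""
    if root is None:
        return
    if not root.children:
        yield [root.val]  # the path stops here because root is a leaf
        return
    for child in root.children:
        # Every longer path continues into exactly one child subtree.
        for path in root_to_leaf_paths(child):
            yield [root.val] + path


def max_depth(root):
    """Recursive depth: 1 for the current node plus the deepest child subtree."""
    if root is None:
        return 0
    return 1 + max((max_depth(child) for child in root.children), default=0)


root = Node(1, [Node(3, [Node(5), Node(6)]), Node(2), Node(4)])
paths = list(root_to_leaf_paths(root))
longest = max((len(path) for path in paths), default=0)
print(paths)
print(longest == max_depth(root))
```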
Complexity
Let n be the number of nodes in the tree.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack stores one call per tree level |
Here h is the height of the tree. The problem states the depth is at most 1000.
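The two bounds can be observed by instrumenting the recursion. The sketch below counts calls (one per node) and tracks the deepest recursion level reached; `measure` and the `Node` class are hypothetical helpers for illustration. As a side note, CPython's default recursion limit is typically 1000 (it can be checked with `sys.getrecursionlimit()`), so a tree near the stated depth bound may sit close to that limit.

```python
class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def measure(root):
    """Return (depth, calls, peak_level) for a recursive DFS over the tree."""
    if root is None:
        return 0, 0, 0
    stats = {"calls": 0, "peak": 0}

    def dfs(node, level):
        stats["calls"] += 1                        # one call per visited node
        stats["peak"] = max(stats["peak"], level)  # deepest active call so far
        best = 0
        for child in node.children:
            best = max(best, dfs(child, level + 1))
        return 1 + best

    depth = dfs(root, 1)
    return depth, stats["calls"], stats["peak"]


# A chain of 50 nodes: h == n, the worst case for stack space.
chain = Node(0)
tail = chain
for i in range(1, 50):
    nxt = Node(i)
    tail.children.append(nxt)
    tail = nxt

# A star: one root with 50 leaf children, so h == 2 regardless of width.
star = Node(0, [Node(i) for i in range(1, 51)])

print(measure(chain))  # (50, 50, 50)
print(measure(star))   # (2, 51, 2)
```

In both trees the call count equals the node count, while the peak recursion level follows the height rather than the size.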
Implementation

    """
    # Definition for a Node.
    class Node:
        def __init__(self, val: int = None, children: list['Node'] = None):
            self.val = val
            self.children = children if children is not None else []
    """

    class Solution:
        def maxDepth(self, root: 'Node') -> int:
            if root is None:
                return 0

            max_child_depth = 0

            for child in root.children:
                max_child_depth = max(max_child_depth, self.maxDepth(child))

            return 1 + max_child_depth

Code Explanation
The empty tree case is handled first:

    if root is None:
        return 0

Then we track the deepest child subtree:

    max_child_depth = 0

For every child, we recursively compute its depth:

    self.maxDepth(child)

and keep the maximum:

    max_child_depth = max(max_child_depth, self.maxDepth(child))

Finally, we add one for the current node:

    return 1 + max_child_depth

Iterative DFS Version
We can also avoid recursion by using an explicit stack.
Each stack entry stores:

    (node, depth)

Implementation:

    class Solution:
        def maxDepth(self, root: 'Node') -> int:
            if root is None:
                return 0

            ans = 0
            stack = [(root, 1)]

            while stack:
                node, depth = stack.pop()
                ans = max(ans, depth)

                for child in node.children:
                    stack.append((child, depth + 1))

            return ans

This version performs the same traversal but stores pending nodes manually.
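A practical reason to prefer the explicit stack is that it does not depend on the interpreter's call stack. The sketch below runs a standalone copy of the same loop on a chain of 5000 nodes, which deliberately exceeds the problem's stated depth bound purely for illustration; `max_depth_iterative` and the `Node` class are hypothetical names.

```python
class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def max_depth_iterative(root):
    """Depth-first traversal with an explicit stack of (node, depth) pairs."""
    if root is None:
        return 0
    deepest = 0
    stack = [(root, 1)]
    while stack:
        node, depth = stack.pop()
        deepest = max(deepest, depth)
        for child in node.children:
            stack.append((child, depth + 1))
    return deepest


# Build a single chain of 5000 nodes without using recursion.
root = Node(0)
tail = root
for i in range(1, 5000):
    nxt = Node(i)
    tail.children.append(nxt)
    tail = nxt

print(max_depth_iterative(root))  # 5000
```

One trade-off to keep in mind: because all children of a node are pushed at once, the explicit stack can hold more than h entries on wide trees (for a root with n - 1 leaf children it holds n - 1 pairs).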
Testing

    class Node:
        def __init__(self, val=None, children=None):
            self.val = val
            self.children = children if children is not None else []

    def run_tests():
        s = Solution()

        assert s.maxDepth(None) == 0

        root = Node(1)
        assert s.maxDepth(root) == 1

        root = Node(1, [
            Node(3, [Node(5), Node(6)]),
            Node(2),
            Node(4),
        ])
        assert s.maxDepth(root) == 3

        root = Node(1, [
            Node(2),
            Node(3, [
                Node(6),
                Node(7, [Node(11, [Node(14)])]),
            ]),
            Node(4, [Node(8, [Node(12)])]),
            Node(5, [
                Node(9, [Node(13)]),
                Node(10),
            ]),
        ])
        assert s.maxDepth(root) == 5

        print("all tests passed")

    run_tests()

| Test | Why |
|---|---|
| None | Empty tree |
| Single root | Minimum non-empty tree |
| Depth 3 tree | Matches the first sample shape |
| Depth 5 tree | Matches the larger sample shape |
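The sample inputs are written as flat lists such as `[1, null, 3, 2, 4, null, 5, 6]`, while the tests build trees by hand. The sketch below shows one way to convert such a list into nodes. It assumes the level-order format in which each `null` closes one node's group of children; `build_tree` is a hypothetical helper, not part of the problem's API, and Python's `None` stands in for `null`.

```python
from collections import deque


class Node:
    """Minimal N-ary tree node (assumed shape: a value plus a list of children)."""

    def __init__(self, val=None, children=None):
        self.val = val
        self.children = children if children is not None else []


def build_tree(values):
    """Build a tree from a level-order list where None ends each child group."""
    if not values:
        return None
    root = Node(values[0])
    queue = deque([root])
    i = 2  # index 0 is the root, index 1 is the None that closes the root level
    while i < len(values) and queue:
        parent = queue.popleft()
        # Consume values up to the next None: these are all children of `parent`.
        while i < len(values) and values[i] is not None:
            child = Node(values[i])
            parent.children.append(child)
            queue.append(child)
            i += 1
        i += 1  # skip the None separator
    return root


def max_depth(root):
    """Recursive depth, repeated here so the sketch is self-contained."""
    if root is None:
        return 0
    return 1 + max((max_depth(child) for child in root.children), default=0)


root = build_tree([1, None, 3, 2, 4, None, 5, 6])
print([child.val for child in root.children])  # [3, 2, 4]
print(max_depth(root))  # 3
```

With a helper like this, a test case can be written directly from the serialized form instead of nesting `Node(...)` calls.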