A clear explanation of finding the maximum depth of a binary tree using recursive depth-first search.
Problem Restatement
We are given the root of a binary tree.
We need to return its maximum depth.
The maximum depth is the number of nodes along the longest path from the root node down to the farthest leaf node. The official statement uses this same definition.
For this tree:
```
    3
   / \
  9  20
    /  \
   15   7
```
The longest root-to-leaf path is:
```
3 -> 20 -> 15
```
or:
```
3 -> 20 -> 7
```
Both paths contain 3 nodes.
So the answer is:
```
3
```
Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root node of a binary tree |
| Output | An integer depth |
| Empty tree | Depth is 0 |
| Single-node tree | Depth is 1 |
| Main condition | Count nodes, not edges |
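To make the "count nodes, not edges" condition concrete, here is a minimal sketch that computes both conventions side by side. The function names `depth_in_nodes` and `height_in_edges` are hypothetical, and giving an empty tree an edge height of -1 is an assumed convention used only for this comparison; this problem asks for the node count.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth_in_nodes(root: Optional[TreeNode]) -> int:
    # Counts nodes on the longest root-to-leaf path (what this problem asks for).
    if root is None:
        return 0
    return 1 + max(depth_in_nodes(root.left), depth_in_nodes(root.right))


def height_in_edges(root: Optional[TreeNode]) -> int:
    # Counts edges instead. Assumed convention: an empty tree has height -1,
    # so a single node has height 0.
    if root is None:
        return -1
    return 1 + max(height_in_edges(root.left), height_in_edges(root.right))


single = TreeNode(1)
pair = TreeNode(1, None, TreeNode(2))
print(depth_in_nodes(single), height_in_edges(single))  # 1 0
print(depth_in_nodes(pair), height_in_edges(pair))      # 2 1
```

For any non-empty tree the two values differ by exactly one, which is why the distinction matters for the single-node case.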
LeetCode gives the TreeNode class:
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
```
The function shape is:
```python
class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        ...
```
Examples
Consider this tree:
```
    3
   / \
  9  20
    /  \
   15   7
```
The paths from root to leaf are:
```
3 -> 9
3 -> 20 -> 15
3 -> 20 -> 7
```
The first path has depth 2.
The second and third paths have depth 3.
The maximum depth is:
```
3
```
For this tree:
```
1
 \
  2
```
The longest path is:
```
1 -> 2
```
So the answer is:
```
2
```
For an empty tree:
```python
root = None
```
The answer is:
```
0
```
First Thought: Ask Each Subtree for Its Depth
The maximum depth of a tree depends on the maximum depth of its left and right subtrees.
For any node:
```
depth(node) = 1 + max(depth(node.left), depth(node.right))
```
The 1 counts the current node.
If the node is missing, its depth is 0.
This gives a natural recursive solution.
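Applied by hand to the example tree, working from the leaves up, the recurrence unfolds as follows. Here depth(v) is shorthand for the depth of the subtree rooted at the node with value v, and every missing child contributes 0:

$$
\begin{aligned}
\text{depth}(9) &= 1 + \max(0, 0) = 1 \\
\text{depth}(15) &= 1 + \max(0, 0) = 1 \\
\text{depth}(7) &= 1 + \max(0, 0) = 1 \\
\text{depth}(20) &= 1 + \max(\text{depth}(15), \text{depth}(7)) = 1 + \max(1, 1) = 2 \\
\text{depth}(3) &= 1 + \max(\text{depth}(9), \text{depth}(20)) = 1 + \max(1, 2) = 3
\end{aligned}
$$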
Key Insight
A binary tree has recursive structure.
Each subtree is also a binary tree.
So the same function can compute the answer for:
```
root
root.left
root.right
root.left.left
...
```
The base case is simple:
```python
if root is None:
    return 0
```
After that, compute both subtree depths:
```python
left_depth = maxDepth(root.left)
right_depth = maxDepth(root.right)
```
Then return the larger one plus 1:
```python
return 1 + max(left_depth, right_depth)
```
Algorithm
If root is None, return 0.
Otherwise:
- Recursively compute the maximum depth of the left subtree.
- Recursively compute the maximum depth of the right subtree.
- Take the larger of the two.
- Add 1 for the current root node.
- Return the result.
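The steps above can be sketched as a variant that also records each node's depth at the moment its call returns. The function name `max_depth_traced` and the `trace` list are hypothetical instrumentation added for illustration; they show that both subtrees finish before their parent, so the recursion is a post-order depth-first traversal.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_traced(root: Optional[TreeNode], trace: List[Tuple[int, int]]) -> int:
    # Base case: an empty subtree has depth 0.
    if root is None:
        return 0
    # Recursively compute the depths of the left and right subtrees.
    left_depth = max_depth_traced(root.left, trace)
    right_depth = max_depth_traced(root.right, trace)
    # Take the larger side and add 1 for the current node.
    depth = 1 + max(left_depth, right_depth)
    # Hypothetical instrumentation: record (node value, depth) as the call returns.
    trace.append((root.val, depth))
    return depth


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
trace: List[Tuple[int, int]] = []
print(max_depth_traced(root, trace))  # 3
print(trace)  # [(9, 1), (15, 1), (7, 1), (20, 2), (3, 3)]
```

Node 20 only gets its depth after both 15 and 7 have returned, and the root is always the last entry.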
Correctness
For an empty tree, the algorithm returns 0. This matches the definition because there are no nodes.
For a non-empty tree, a path from the current node down to a leaf either ends at the current node (if it is itself a leaf) or continues through the left child or the right child.
The longest path from the current node is therefore:
```
current node + longer of the two subtree paths
```
where a missing subtree contributes 0 nodes. Each subtree is strictly smaller than the whole tree, so by induction the recursive calls correctly compute the maximum depths of the left and right subtrees. Taking the maximum chooses the deeper side. Adding 1 counts the current node.
This rule is applied at every node until the recursion reaches missing children. Therefore, the algorithm computes the number of nodes on the longest root-to-leaf path, which is exactly the maximum depth.
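One way to sanity-check this argument is to compare the recursion against a direct reading of the definition: list every root-to-leaf path explicitly and take the longest. The helpers `all_root_to_leaf_paths` and `max_depth_by_definition` are hypothetical brute-force checkers written for this illustration. They build every path in memory, so they are only meant for testing small trees, not as a replacement for the recursive solution.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth(root: Optional[TreeNode]) -> int:
    # The recursive solution from this article, as a plain function.
    if root is None:
        return 0
    return 1 + max(max_depth(root.left), max_depth(root.right))


def all_root_to_leaf_paths(root: Optional[TreeNode]) -> List[List[int]]:
    # Brute force: collect every root-to-leaf path as a list of node values.
    if root is None:
        return []
    if root.left is None and root.right is None:
        return [[root.val]]
    paths = []
    for child in (root.left, root.right):
        for path in all_root_to_leaf_paths(child):
            paths.append([root.val] + path)
    return paths


def max_depth_by_definition(root: Optional[TreeNode]) -> int:
    # Number of nodes on the longest root-to-leaf path; 0 for an empty tree.
    return max((len(p) for p in all_root_to_leaf_paths(root)), default=0)


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(all_root_to_leaf_paths(root))  # [[3, 9], [3, 20, 15], [3, 20, 7]]
print(max_depth(root), max_depth_by_definition(root))  # 3 3
```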
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack stores one path at a time |
Here, n is the number of nodes and h is the height of the tree.
For a balanced tree, the space is O(log n).
For a skewed tree, the space is O(n).
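To see both rows of the table on concrete trees, here is a minimal sketch that counts how many real nodes the recursion visits and the largest number of real-node calls on the call stack at the same time. The name `max_depth_with_stats` and its `stats` dictionary are hypothetical instrumentation, not part of the solution; calls on missing children are deliberately not counted.

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_with_stats(root: Optional[TreeNode]) -> Tuple[int, int, int]:
    # Returns (depth, nodes visited, peak number of active non-empty calls).
    stats: Dict[str, int] = {"visits": 0, "active": 0, "peak": 0}

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        stats["visits"] += 1
        stats["active"] += 1  # this call is now on the stack
        stats["peak"] = max(stats["peak"], stats["active"])
        depth = 1 + max(dfs(node.left), dfs(node.right))
        stats["active"] -= 1  # this call is about to return
        return depth

    depth = dfs(root)
    return depth, stats["visits"], stats["peak"]


# Balanced: a complete tree with 7 nodes and height 3.
balanced = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)),
                    TreeNode(3, TreeNode(6), TreeNode(7)))
# Skewed: a chain of 7 nodes where every node has only a left child.
skewed = None
for value in range(7):
    skewed = TreeNode(value, skewed)

print(max_depth_with_stats(balanced))  # (3, 7, 3)
print(max_depth_with_stats(skewed))    # (7, 7, 7)
```

Both trees have 7 nodes, so both runs visit 7 nodes, but the peak stack usage follows the height: 3 for the balanced tree and 7 for the chain.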
Implementation
```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0
        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)
        return 1 + max(left_depth, right_depth)
```
Code Explanation
The base case handles an empty subtree:
```python
if root is None:
    return 0
```
A missing node contributes no depth.
Then we compute the depth of the left subtree:
```python
left_depth = self.maxDepth(root.left)
```
And the depth of the right subtree:
```python
right_depth = self.maxDepth(root.right)
```
Only the deeper side matters for maximum depth:
```python
max(left_depth, right_depth)
```
Then we add 1 for the current node:
```python
return 1 + max(left_depth, right_depth)
```
Testing
```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0
        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)
        return 1 + max(left_depth, right_depth)


def run_tests():
    s = Solution()

    root1 = TreeNode(
        3,
        TreeNode(9),
        TreeNode(20, TreeNode(15), TreeNode(7)),
    )
    assert s.maxDepth(root1) == 3

    root2 = TreeNode(1, None, TreeNode(2))
    assert s.maxDepth(root2) == 2

    root3 = None
    assert s.maxDepth(root3) == 0

    root4 = TreeNode(1)
    assert s.maxDepth(root4) == 1

    root5 = TreeNode(
        1,
        TreeNode(2, TreeNode(3, TreeNode(4))),
        None,
    )
    assert s.maxDepth(root5) == 4

    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| [3,9,20,null,null,15,7] | Standard multi-level tree |
| [1,null,2] | Confirms right-skewed depth |
| Empty tree | Confirms base case |
| Single node | Confirms depth counts nodes |
| Left-skewed tree | Confirms longest one-sided path |
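The first two rows of the table use LeetCode's level-order array notation, where `null` marks a missing child. Below is a hypothetical helper `build_tree` that turns such a list (with Python's `None` in place of `null`) into `TreeNode` objects, so those cases can be written exactly as they appear in the table. It is a sketch assuming the usual convention: the list is read breadth-first, each non-null node takes the next two entries as its left and right children, and trailing nulls may be omitted.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Hypothetical helper: build a tree from a level-order list with None gaps.
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry (if any) is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The entry after that (if any) is this node's right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def max_depth(root: Optional[TreeNode]) -> int:
    # Same recursion as Solution.maxDepth, written as a plain function.
    if root is None:
        return 0
    return 1 + max(max_depth(root.left), max_depth(root.right))


print(max_depth(build_tree([3, 9, 20, None, None, 15, 7])))  # 3
print(max_depth(build_tree([1, None, 2])))                   # 2
```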