
LeetCode 104: Maximum Depth of Binary Tree

A clear explanation of finding the maximum depth of a binary tree using recursive depth-first search.

Problem Restatement

We are given the root of a binary tree.

We need to return its maximum depth.

The maximum depth is the number of nodes along the longest path from the root node down to the farthest leaf node. The official statement uses this same definition.

For this tree:

        3
      /   \
     9     20
          /  \
         15   7

The longest root-to-leaf path is:

3 -> 20 -> 15

or:

3 -> 20 -> 7

Both paths contain 3 nodes.

So the answer is:

3
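To connect the definition to code, the sketch below enumerates every root-to-leaf path literally and takes the longest one. It is a brute-force illustration of the definition, not the article's solution, and the helper name `root_to_leaf_paths` is hypothetical.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def root_to_leaf_paths(root: Optional[TreeNode]) -> List[List[int]]:
    """Hypothetical helper: return every root-to-leaf path as a list of values."""
    if root is None:
        return []
    # A leaf ends a path that consists of just itself.
    if root.left is None and root.right is None:
        return [[root.val]]
    paths = []
    for child in (root.left, root.right):
        # Prefix each path found in a child subtree with the current node.
        for path in root_to_leaf_paths(child):
            paths.append([root.val] + path)
    return paths


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
paths = root_to_leaf_paths(root)
print(paths)                       # [[3, 9], [3, 20, 15], [3, 20, 7]]
print(max(len(p) for p in paths))  # 3
```

Building every path is wasteful, since only the longest length matters; the recursive solution below keeps just that number.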

Input and Output

Item | Meaning
--- | ---
Input | root, the root node of a binary tree
Output | An integer depth
Empty tree | Depth is 0
Single-node tree | Depth is 1
Main condition | Count nodes, not edges
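The "count nodes, not edges" condition matters because some textbooks define height in edges instead. As a sketch, the two hypothetical functions below differ only in the value returned for a missing node. The -1 for an empty tree in the edge-counting version is one common convention, assumed here for illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth_in_nodes(root: Optional[TreeNode]) -> int:
    # What this problem asks for: empty tree -> 0, single node -> 1.
    if root is None:
        return 0
    return 1 + max(depth_in_nodes(root.left), depth_in_nodes(root.right))


def height_in_edges(root: Optional[TreeNode]) -> int:
    # Edge-counting convention (assumed): empty tree -> -1, single node -> 0.
    if root is None:
        return -1
    return 1 + max(height_in_edges(root.left), height_in_edges(root.right))


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(depth_in_nodes(root))   # 3
print(height_in_edges(root))  # 2
```

For any non-empty tree the two answers differ by exactly 1, so returning the edge count here would be off by one.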

LeetCode gives the TreeNode class:

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

The function shape is:

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        ...

Examples

Consider this tree:

        3
      /   \
     9     20
          /  \
         15   7

The paths from root to leaf are:

3 -> 9
3 -> 20 -> 15
3 -> 20 -> 7

The first path has depth 2.

The second and third paths have depth 3.

The maximum depth is:

3

For this tree:

1
 \
  2

The longest path is:

1 -> 2

So the answer is:

2

For an empty tree:

root = None

The answer is:

0

First Thought: Ask Each Subtree for Its Depth

The maximum depth of a tree depends on the maximum depth of its left and right subtrees.

For any node:

depth(node) = 1 + max(depth(node.left), depth(node.right))

The 1 counts the current node.

If the node is missing, its depth is 0.

This gives a natural recursive solution.
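Applying the recurrence bottom-up to the example tree, with missing children contributing 0, gives:

```latex
\begin{aligned}
\operatorname{depth}(15) &= 1 + \max(0, 0) = 1 \\
\operatorname{depth}(7)  &= 1 + \max(0, 0) = 1 \\
\operatorname{depth}(9)  &= 1 + \max(0, 0) = 1 \\
\operatorname{depth}(20) &= 1 + \max(\operatorname{depth}(15), \operatorname{depth}(7)) = 1 + \max(1, 1) = 2 \\
\operatorname{depth}(3)  &= 1 + \max(\operatorname{depth}(9), \operatorname{depth}(20)) = 1 + \max(1, 2) = 3
\end{aligned}
```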

Key Insight

A binary tree has a recursive structure.

Each subtree is also a binary tree.

So the same function can compute the answer for:

root
root.left
root.right
root.left.left
...

The base case is simple:

if root is None:
    return 0

After that, compute both subtree depths:

left_depth = maxDepth(root.left)
right_depth = maxDepth(root.right)

Then return the larger one plus 1:

return 1 + max(left_depth, right_depth)
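To see the order in which these steps happen, here is a sketch of the same recurrence instrumented with a log. The function name `max_depth_traced` and its `log` parameter are hypothetical. Each node's depth is recorded only after both of its subtree calls return, so the log comes out in post-order.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_traced(root: Optional[TreeNode], log: List[Tuple[int, int]]) -> int:
    """Same recurrence as maxDepth, but appends (node value, depth) to log
    after both subtree calls have returned (post-order)."""
    if root is None:
        return 0
    left_depth = max_depth_traced(root.left, log)
    right_depth = max_depth_traced(root.right, log)
    depth = 1 + max(left_depth, right_depth)
    log.append((root.val, depth))
    return depth


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
log: List[Tuple[int, int]] = []
print(max_depth_traced(root, log))  # 3
print(log)  # [(9, 1), (15, 1), (7, 1), (20, 2), (3, 3)]
```

The root is the last node whose depth is finalized, because it needs both subtree answers first.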

Algorithm

If root is None, return 0.

Otherwise:

  1. Recursively compute the maximum depth of the left subtree.
  2. Recursively compute the maximum depth of the right subtree.
  3. Take the larger of the two.
  4. Add 1 for the current root node.
  5. Return the result.
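The same answer can also be computed without recursion. The sketch below swaps in an explicit stack of (node, depth) pairs. This is an alternative technique, not the steps listed above: instead of combining subtree answers on the way back up, it carries each node's depth downward and tracks the largest value seen. The function name `max_depth_iterative` is hypothetical.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_iterative(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    best = 0
    # Each entry pairs a node with the number of nodes on the path
    # from the root down to and including that node.
    stack: List[Tuple[TreeNode, int]] = [(root, 1)]
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        if node.left is not None:
            stack.append((node.left, depth + 1))
        if node.right is not None:
            stack.append((node.right, depth + 1))
    return best


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth_iterative(root))  # 3
```

Time stays O(n). The explicit stack avoids Python's recursion limit, though in the worst case it can still hold O(n) entries.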

Correctness

For an empty tree, the algorithm returns 0. This matches the definition because there are no nodes.

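For a non-empty tree, if the current node is a leaf, both recursive calls return 0 and the algorithm returns 1, which is correct. Otherwise, every path from the current node down to a leaf must continue through either the left child or the right child.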

The longest path from the current node is therefore:

current node + longer of the two subtree paths

The recursive calls work on strictly smaller trees, so by induction they correctly compute the maximum depth of the left and right subtrees. Taking the maximum chooses the deeper side. Adding 1 counts the current node.

This rule is applied at every node until the recursion reaches missing children. Therefore, the algorithm computes the number of nodes on the longest root-to-leaf path, which is exactly the maximum depth.
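The proof can be complemented by a quick empirical check. The sketch below compares the recursive answer with a direct implementation of the definition (enumerate root-to-leaf paths, keep the longest) on many random trees. The helpers `random_tree` and `longest_path_by_definition` are hypothetical, and the seed and tree sizes are arbitrary choices.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth(root: Optional[TreeNode]) -> int:
    # The recurrence from the article.
    if root is None:
        return 0
    return 1 + max(max_depth(root.left), max_depth(root.right))


def longest_path_by_definition(root: Optional[TreeNode]) -> int:
    """Walk every root-to-leaf path explicitly; return the longest length in nodes."""
    if root is None:
        return 0
    best = 0
    stack = [(root, 1)]  # (node, nodes on the path so far)
    while stack:
        node, length = stack.pop()
        if node.left is None and node.right is None:
            best = max(best, length)  # a complete root-to-leaf path
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, length + 1))
    return best


def random_tree(n_nodes: int, rng: random.Random) -> Optional[TreeNode]:
    """Build a random binary tree with exactly n_nodes nodes."""
    if n_nodes == 0:
        return None
    left_size = rng.randint(0, n_nodes - 1)
    return TreeNode(
        rng.randint(-100, 100),
        random_tree(left_size, rng),
        random_tree(n_nodes - 1 - left_size, rng),
    )


rng = random.Random(0)
for _ in range(500):
    tree = random_tree(rng.randint(0, 30), rng)
    assert max_depth(tree) == longest_path_by_definition(tree)
print("recursive answer matches the definition on 500 random trees")
```

A passing check does not replace the proof, but it is a cheap way to catch off-by-one mistakes such as counting edges instead of nodes.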

Complexity

Metric | Value | Why
--- | --- | ---
Time | O(n) | Each node is visited once
Space | O(h) | The recursion stack stores one path at a time

Here, n is the number of nodes and h is the height of the tree.

For a balanced tree, the space is O(log n).

For a skewed tree, the space is O(n).
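The skewed case has a practical consequence in Python: recursion depth equals tree height, and CPython's default recursion limit (`sys.getrecursionlimit()` typically returns 1000) can be exceeded by a long chain. One alternative, not the article's approach, is a breadth-first traversal that counts levels with `collections.deque`. Its extra space is O(w), where w is the maximum number of nodes on a single level. The function name `max_depth_bfs` is hypothetical.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth_bfs(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    depth = 0
    queue = deque([root])
    while queue:
        depth += 1  # the queue currently holds exactly one level
        # Drain this level, enqueueing the next level's nodes.
        for _ in range(len(queue)):
            node = queue.popleft()
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
    return depth


# A left-skewed chain of 5000 nodes, deeper than the usual default recursion limit.
chain = None
for v in range(5000):
    chain = TreeNode(v, chain)
print(max_depth_bfs(chain))  # 5000
```

On a skewed tree every level has one node, so the queue never holds more than one element.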

Implementation

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0

        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)

        return 1 + max(left_depth, right_depth)

Code Explanation

The base case handles an empty subtree:

if root is None:
    return 0

A missing node contributes no depth.

Then we compute the depth of the left subtree:

left_depth = self.maxDepth(root.left)

And the depth of the right subtree:

right_depth = self.maxDepth(root.right)

Only the deeper side matters for maximum depth:

max(left_depth, right_depth)

Then we add 1 for the current node:

return 1 + max(left_depth, right_depth)

Testing

from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        if root is None:
            return 0

        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)

        return 1 + max(left_depth, right_depth)

def run_tests():
    s = Solution()

    root1 = TreeNode(
        3,
        TreeNode(9),
        TreeNode(20, TreeNode(15), TreeNode(7)),
    )
    assert s.maxDepth(root1) == 3

    root2 = TreeNode(1, None, TreeNode(2))
    assert s.maxDepth(root2) == 2

    root3 = None
    assert s.maxDepth(root3) == 0

    root4 = TreeNode(1)
    assert s.maxDepth(root4) == 1

    root5 = TreeNode(
        1,
        TreeNode(2, TreeNode(3, TreeNode(4))),
        None,
    )
    assert s.maxDepth(root5) == 4

    print("all tests passed")

run_tests()

Test meaning:

Test | Why
--- | ---
[3,9,20,null,null,15,7] | Standard multi-level tree
[1,null,2] | Confirms right-skewed depth
Empty tree | Confirms base case
Single node | Confirms depth counts nodes
Left-skewed tree | Confirms longest one-sided path
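The first two tests are written in LeetCode's level-order notation, where null marks a missing child. A minimal sketch of a helper that builds a tree from that notation is shown below, assuming Python's None stands in for null. The name `build_tree` is hypothetical and is not part of the LeetCode API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from level-order values; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present) is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present) is this node's right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def max_depth(root: Optional[TreeNode]) -> int:
    if root is None:
        return 0
    return 1 + max(max_depth(root.left), max_depth(root.right))


print(max_depth(build_tree([3, 9, 20, None, None, 15, 7])))  # 3
print(max_depth(build_tree([1, None, 2])))                   # 2
print(max_depth(build_tree([])))                             # 0
```

With a helper like this, new test cases can be copied directly from the problem's examples instead of being nested by hand.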