
LeetCode 543: Diameter of Binary Tree

A clear explanation of finding the longest path between any two nodes in a binary tree using DFS height computation.

Problem Restatement

We are given the root of a binary tree.

We need to return the length of the tree’s diameter.

The diameter is the length of the longest path between any two nodes in the tree. This path may or may not pass through the root. The length is measured by the number of edges, not the number of nodes. For example, in root = [1,2,3,4,5], one longest path is [4,2,1,3], which has 3 edges.

Input and Output

Item         Meaning
Input        The root of a binary tree
Output       The diameter length
Path rule    The path can start and end at any two nodes
Root rule    The path may or may not pass through the root
Length rule  Count edges, not nodes

Function shape:

def diameterOfBinaryTree(root: Optional[TreeNode]) -> int:
    ...
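The examples write trees as level-order lists such as [1,2,3,4,5]. This follows LeetCode's convention, where None (null) marks a missing child and the children of a missing node are not listed. Below is a minimal decoder sketch that assumes that convention. build_tree is a hypothetical helper, not part of the problem's API. The TreeNode class mirrors the commented-out definition in the Implementation section.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: decode a level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value, if present, is this node's left child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that, if present, is the right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([1, 2, 3, 4, 5])
print(root.left.right.val)  # 5
```

With this decoder, build_tree([1, None, 2]) produces the two-node tree used in the second example: node 1 with only a right child.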

Examples

Consider this tree:

      1
     / \
    2   3
   / \
  4   5

One longest path is:

4 -> 2 -> 1 -> 3

This path has 3 edges:

4 -> 2
2 -> 1
1 -> 3

So the answer is:

3

Another longest path is:

5 -> 2 -> 1 -> 3

It also has length 3.

For a smaller tree:

1
 \
  2

The longest path is:

1 -> 2

So the answer is:

1
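Both answers can be checked straight from the definition. In a tree, the path between two nodes is unique, so the diameter is the largest shortest-path distance, in edges, over all pairs of nodes. The sketch below treats the tree as an undirected graph and runs a breadth-first search from every node. It takes O(n^2) time, so it is only a checker, not a solution. brute_force_diameter is a hypothetical name used for illustration.

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def brute_force_diameter(root: Optional[TreeNode]) -> int:
    """Largest distance (in edges) over all pairs of nodes, via BFS from each node."""
    if root is None:
        return 0
    # Undirected adjacency list keyed by node identity (parent <-> child edges).
    adj: Dict[TreeNode, List[TreeNode]] = {root: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adj[node].append(child)
                adj[child] = [node]
                stack.append(child)
    best = 0
    for start in adj:
        # BFS distances from start; in a tree these are the unique path lengths.
        dist = {start: 0}
        queue = deque([start])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if nxt not in dist:
                    dist[nxt] = dist[cur] + 1
                    queue.append(nxt)
        best = max(best, max(dist.values()))
    return best


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(brute_force_diameter(example))                     # 3
print(brute_force_diameter(TreeNode(1, None, TreeNode(2))))  # 1
```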

First Thought: Compute Diameter at Every Node

The longest path whose topmost node is a given node combines:

longest downward path in left subtree
+
longest downward path in right subtree

So one direct approach is:

  1. For every node, compute the height of its left subtree.
  2. Compute the height of its right subtree.
  3. Add them to get the diameter passing through that node.
  4. Take the maximum over all nodes.

This is correct, but computing height from scratch at every node recomputes the same subtree heights many times.

On a skewed tree, the height call at each node walks the entire chain below it, so the total work is O(n^2).
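The direct approach above can be sketched as follows. Here height counts nodes on the longest downward path: 0 for a missing child and 1 for a leaf. With that convention, the sum of the two child heights is an edge count. naive_diameter is a hypothetical name. The point is that height is called again for every ancestor of a subtree, which is where the O(n^2) cost comes from.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    """Nodes on the longest downward path from node (0 for an empty subtree)."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def naive_diameter(root: Optional[TreeNode]) -> int:
    """Hypothetical O(n^2) version: recompute both child heights at every node."""
    if root is None:
        return 0
    # Longest path whose topmost node is root (counted in edges).
    through_root = height(root.left) + height(root.right)
    # Longest path lying entirely inside one subtree, found recursively.
    return max(through_root, naive_diameter(root.left), naive_diameter(root.right))


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(naive_diameter(example))  # 3
```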

Key Insight

We can compute height and update diameter in the same DFS.

For each node, the DFS returns:

height of this node

where height means the number of nodes on the longest downward path starting at this node. An empty subtree has height 0, and a leaf has height 1.

While computing height, we also check whether the longest path through the current node is the best diameter so far.

If the left subtree has height left_height and the right subtree has height right_height, then the longest path through the current node has length:

left_height + right_height

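This sum counts edges because of the node-count convention. left_height is the number of nodes on the deepest downward path in the left subtree. That equals the number of edges from the current node down to the bottom of that path: the edge into the left child plus the edges below it. A missing child contributes 0. So a leaf returns height 1, and left_height + right_height is an edge count even though height itself counts nodes.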
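For comparison, here is a sketch that counts height in edges instead. This is a convention the article does not use. An empty subtree gets -1 and a leaf gets 0. Each side then needs an explicit +1 for the connecting edge, so the candidate becomes (left + 1) + (right + 1). The results match because edge height + 1 is exactly the node-count height. diameter_edge_heights is a hypothetical name.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter_edge_heights(root: Optional[TreeNode]) -> int:
    """Same DFS, but height counts edges (empty = -1, leaf = 0)."""
    answer = 0

    def height(node: Optional[TreeNode]) -> int:
        nonlocal answer
        if node is None:
            return -1
        left = height(node.left)
        right = height(node.right)
        # Add one connecting edge per side; a missing child gives -1 + 1 = 0.
        answer = max(answer, (left + 1) + (right + 1))
        return 1 + max(left, right)

    height(root)
    return answer


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(diameter_edge_heights(example))  # 3
```

The node-count convention used in the rest of the article avoids the -1 base case and the +2 correction, which keeps the update line shorter.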

Algorithm

Maintain a variable:

answer

This stores the largest diameter found so far.

Define a DFS function:

height(node)

For each node:

  1. If the node is None, return 0.
  2. Recursively compute the height of the left subtree.
  3. Recursively compute the height of the right subtree.
  4. Update answer with left_height + right_height.
  5. Return 1 + max(left_height, right_height).

At the end, return answer.
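To see the steps in action, the sketch below runs the same DFS and also records, for each node in post-order, its height and the candidate it offers. The trace list and the name diameter_with_trace are hypothetical additions for illustration only. The numbered comments refer to the steps above.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter_with_trace(root: Optional[TreeNode]) -> Tuple[int, List[Tuple[int, int, int]]]:
    """Return (answer, [(val, height, candidate), ...]) in post-order."""
    answer = 0
    trace: List[Tuple[int, int, int]] = []

    def height(node: Optional[TreeNode]) -> int:
        nonlocal answer
        if node is None:                          # step 1
            return 0
        left_height = height(node.left)           # step 2
        right_height = height(node.right)         # step 3
        candidate = left_height + right_height
        answer = max(answer, candidate)           # step 4
        h = 1 + max(left_height, right_height)    # step 5
        trace.append((node.val, h, candidate))
        return h

    height(root)
    return answer, trace


example = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
answer, trace = diameter_with_trace(example)
print(answer)  # 3
print(trace)   # [(4, 1, 0), (5, 1, 0), (2, 2, 2), (3, 1, 0), (1, 3, 3)]
```

Node 2 offers candidate 2 (the path 4 -> 2 -> 5), and the root offers 3, which becomes the answer.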

Correctness

Fix a node and consider the paths whose highest node is that node. Such a path goes down into the node's left subtree, into its right subtree, into both, or into neither (the single-node path). The longest downward leg into the left side has as many edges as the left child's height returned by the DFS, and likewise on the right. So the longest such path has exactly left_height + right_height edges.

The DFS computes the height of every subtree. Therefore, when it processes a node, it has the exact left and right heights needed to compute the best path passing through that node.

Every path in a tree has a unique highest node. This is the node where the path turns from one side to the other, or the top end of a path that continues down only one side. When the DFS processes that highest node, left_height + right_height is at least the length of the path.

So no path is longer than the largest candidate, and every candidate is the length of an actual path.

The algorithm keeps the maximum candidate in answer, so after all nodes are processed, answer is the length of the longest path between any two nodes.
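The argument can also be spot-checked empirically. The sketch compares the DFS against an independent brute force on random trees. The brute force computes every pairwise distance as depth[u] + depth[v] - 2 * depth[lca], where the lowest common ancestor is exactly the "highest node" of the path between u and v. The random tree generator, the seed, and all function names are hypothetical choices for this check. Agreement on random inputs is evidence, not a proof.

```python
import random
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def dfs_diameter(root: Optional[TreeNode]) -> int:
    answer = 0

    def height(node: Optional[TreeNode]) -> int:
        nonlocal answer
        if node is None:
            return 0
        left_height = height(node.left)
        right_height = height(node.right)
        answer = max(answer, left_height + right_height)
        return 1 + max(left_height, right_height)

    height(root)
    return answer


def lca_brute_force(root: Optional[TreeNode]) -> int:
    """Max over all pairs (u, v) of depth[u] + depth[v] - 2 * depth[lca(u, v)]."""
    if root is None:
        return 0
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    depth: Dict[TreeNode, int] = {root: 0}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                depth[child] = depth[node] + 1
                stack.append(child)
    nodes: List[TreeNode] = list(depth)
    best = 0
    for u in nodes:
        ancestors = set()  # u and every node above it
        a = u
        while a is not None:
            ancestors.add(a)
            a = parent[a]
        for v in nodes:
            # Climb from v until reaching an ancestor of u: that node is the LCA.
            w = v
            while w not in ancestors:
                w = parent[w]
            best = max(best, depth[u] + depth[v] - 2 * depth[w])
    return best


def random_tree(n: int, rng: random.Random) -> Optional[TreeNode]:
    """Attach n nodes one at a time to a randomly chosen free child slot."""
    if n == 0:
        return None
    root = TreeNode(0)
    slots = [(root, "left"), (root, "right")]
    for val in range(1, n):
        node, side = slots.pop(rng.randrange(len(slots)))
        child = TreeNode(val)
        setattr(node, side, child)
        slots.extend([(child, "left"), (child, "right")])
    return root


rng = random.Random(0)
for _ in range(300):
    tree = random_tree(rng.randint(0, 30), rng)
    assert dfs_diameter(tree) == lca_brute_force(tree)
print("dfs_diameter agrees with the brute force")
```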

Complexity

Let n be the number of nodes and h be the height of the tree.

Metric  Value  Why
Time    O(n)   Each node is visited once
Space   O(h)   Recursion stack height

In a balanced tree, h = O(log n).

In a skewed tree, h = O(n).
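The O(h) stack is a real concern in Python. CPython's default recursion limit is 1000 (see sys.getrecursionlimit()), so on a skewed tree deeper than roughly that, the recursive helper can raise RecursionError unless the limit is raised. Judges may configure this differently. The sketch below runs the same post-order height computation with an explicit stack. diameter_iterative is a hypothetical name, and the (node, children_done) scheme is one common way to emulate post-order, not the only one.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter_iterative(root: Optional[TreeNode]) -> int:
    """Height DFS in post-order, driven by an explicit stack instead of recursion."""
    if root is None:
        return 0
    answer = 0
    # Node-count heights of finished subtrees not yet consumed by their parent.
    heights: Dict[TreeNode, int] = {}
    # A node is pushed once with False to schedule its children, then again
    # with True beneath them, so it is finished only after both children.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # A missing child (None) is never a key, so pop falls back to 0.
            left_height = heights.pop(node.left, 0)
            right_height = heights.pop(node.right, 0)
            answer = max(answer, left_height + right_height)
            heights[node] = 1 + max(left_height, right_height)
    return answer


# A 10,000-node left-skewed chain, built without recursion.
chain = TreeNode(0)
node = chain
for v in range(1, 10_000):
    node.left = TreeNode(v)
    node = node.left
print(diameter_iterative(chain))  # 9999
```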

Implementation

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

from typing import Optional

class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        answer = 0

        def height(node: Optional[TreeNode]) -> int:
            nonlocal answer

            if node is None:
                return 0

            left_height = height(node.left)
            right_height = height(node.right)

            answer = max(answer, left_height + right_height)

            return 1 + max(left_height, right_height)

        height(root)
        return answer

Code Explanation

We start with:

answer = 0

A tree with one node has diameter 0: it has no edges, so every path in it has length 0.

The helper function returns the height of a subtree:

def height(node: Optional[TreeNode]) -> int:

For a missing node, height is 0:

if node is None:
    return 0

Then we compute the heights of both children:

left_height = height(node.left)
right_height = height(node.right)

The longest path passing through the current node has length:

left_height + right_height

So we update the global best:

answer = max(answer, left_height + right_height)

Finally, the current node’s height is 1 (for the node itself) plus the height of its taller child:

return 1 + max(left_height, right_height)

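Although this helper returns height in number of nodes, left_height + right_height counts edges. A child's node-count height equals the number of edges from the current node down to the deepest node on that side, including the edge into the child. So the sum is the length of the longest path whose highest node is the current node.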

Testing

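# When running locally, define TreeNode first and paste the Solution class
# after it: Solution's type hints reference TreeNode when the class is created.
class TreeNode: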
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    root = TreeNode(
        1,
        TreeNode(2, TreeNode(4), TreeNode(5)),
        TreeNode(3),
    )
    assert s.diameterOfBinaryTree(root) == 3

    root = TreeNode(1, None, TreeNode(2))
    assert s.diameterOfBinaryTree(root) == 1

    root = TreeNode(1)
    assert s.diameterOfBinaryTree(root) == 0

    root = TreeNode(
        1,
        TreeNode(
            2,
            TreeNode(3),
            None,
        ),
        None,
    )
    assert s.diameterOfBinaryTree(root) == 2

    print("all tests passed")

run_tests()
Test         Why
[1,2,3,4,5]  Standard example where diameter passes through root
Two nodes    Checks edge count 1
One node     Diameter is 0
Skewed tree  Checks longest path down one side
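None of the tests above has a diameter that avoids the root, even though the problem statement explicitly allows it. One extra case, assuming the Solution shown in the Implementation section, covers it. Node 2 has two chains of length 2 below it, so the path 4 -> 3 -> 2 -> 5 -> 6 has 4 edges. The best path through the root, 1 -> 2 -> 3 -> 4, has only 3. The tree shape is a hypothetical test input.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        answer = 0

        def height(node: Optional[TreeNode]) -> int:
            nonlocal answer
            if node is None:
                return 0
            left_height = height(node.left)
            right_height = height(node.right)
            answer = max(answer, left_height + right_height)
            return 1 + max(left_height, right_height)

        height(root)
        return answer


#         1
#        /
#       2
#      / \
#     3   5
#    /     \
#   4       6
root = TreeNode(
    1,
    TreeNode(2, TreeNode(3, TreeNode(4)), TreeNode(5, None, TreeNode(6))),
    None,
)
print(Solution().diameterOfBinaryTree(root))  # 4
```

The answer comes from node 2, not the root, which is why the DFS must track the best candidate at every node rather than only at the root.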