A clear explanation of finding the maximum path sum in a binary tree using bottom-up depth-first search.
Problem Restatement
We are given the root of a binary tree.
We need to find the maximum path sum over all non-empty paths in the tree, where the path sum is the sum of the node values along the path. A path is a sequence of nodes in which each pair of adjacent nodes is connected by an edge. A node may appear in the path at most once, and the path does not need to pass through the root.
For this tree:

      1
     / \
    2   3

The best path is:

    2 -> 1 -> 3

The sum is:

    2 + 1 + 3 = 6

So the answer is:

    6

Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root of a binary tree |
| Output | Maximum path sum |
| Path | Connected sequence of nodes |
| Root required | No |
| Empty path allowed | No |
| Node reuse | Not allowed |
LeetCode gives the TreeNode class:

    class TreeNode:
        def __init__(self, val=0, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

The function shape is:

    class Solution:
        def maxPathSum(self, root: Optional[TreeNode]) -> int:
            ...

Examples
Consider:

      1
     / \
    2   3

The best path uses all three nodes:

    2 -> 1 -> 3

Its sum is:

    6

Now consider:

      -10
      /  \
     9    20
         /  \
        15   7

The best path is:

    15 -> 20 -> 7

Its sum is:

    42

The path does not include the root -10.
So the answer is:

    42

First Thought: A Path Can Bend Once
A path in a binary tree can move from one child up through a node and down into another child.
For example:

    left child path -> node -> right child path

This is a valid path.
But once a path bends through a node, it cannot continue upward to the parent with both branches. That would create a fork, not a single path.
So for each node, we need to distinguish two values:
| Value | Meaning |
|---|---|
| Best path through this node | May use left branch and right branch |
| Best gain returned to parent | Must use only one side |
This distinction is the core of the problem.
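
To make the two values concrete, the sketch below records both of them for every node of the second example tree. The `Node` class and the `both_values` helper are illustrative names chosen here, not part of the LeetCode interface, and the dictionary keyed by node value assumes the values are unique.

```python
from typing import Dict, Optional, Tuple


class Node:
    """Minimal stand-in for LeetCode's TreeNode (hypothetical name)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def both_values(root: Optional[Node]) -> Dict[int, Tuple[int, int]]:
    """Map node value -> (best path bending through it, gain returned to parent)."""
    table: Dict[int, Tuple[int, int]] = {}

    def visit(node: Optional[Node]) -> int:
        if node is None:
            return 0
        left = max(0, visit(node.left))
        right = max(0, visit(node.right))
        through = node.val + left + right     # may use both branches
        upward = node.val + max(left, right)  # must use only one side
        table[node.val] = (through, upward)
        return upward

    visit(root)
    return table


tree = Node(-10, Node(9), Node(20, Node(15), Node(7)))
values = both_values(tree)
print(values[20])   # (42, 35)
print(values[-10])  # (34, 25)
```

At node 20 the bending path is worth 42, but only the chain 20 -> 15, worth 35, can be handed to the parent.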
Key Insight
When DFS visits a node, it asks each child:
“What is the best downward path gain you can contribute to me?”
If a child contributes a negative value, we should ignore it.
So we use:

    left_gain = max(0, dfs(node.left))
    right_gain = max(0, dfs(node.right))

Then the best path that bends through the current node is:

    node.val + left_gain + right_gain

This value can update the global answer.
But the value returned to the parent must be a single downward chain:

    node.val + max(left_gain, right_gain)

The parent can attach this chain to its own path.
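
A small experiment can show why the `max(0, ...)` clamp matters. The sketch below runs the same traversal with and without the clamp on a two-node tree; the `Node` class, the `best_path` function, and its `clamp` flag are hypothetical names used only for this comparison.

```python
from typing import Optional


class Node:
    """Minimal stand-in for LeetCode's TreeNode (hypothetical name)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def best_path(root: Optional[Node], clamp: bool) -> float:
    """Bottom-up DFS; when clamp is False, negative child gains are kept."""
    best = float("-inf")

    def gain(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        left = gain(node.left)
        right = gain(node.right)
        if clamp:
            left, right = max(0, left), max(0, right)
        best = max(best, node.val + left + right)
        return node.val + max(left, right)

    gain(root)
    return best


tree = Node(2, Node(-1))
print(best_path(tree, clamp=True))   # 2
print(best_path(tree, clamp=False))  # 1
```

Without the clamp, the path through the root is forced to include the child worth -1, so the reported answer drops from 2 to 1.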
Algorithm
Use DFS.
Maintain a variable shared by all recursive calls:

    best

where best stores the maximum path sum found so far.
For each node:
- Recursively compute the best downward gain from the left child.
- Recursively compute the best downward gain from the right child.
- Replace negative gains with 0.
- Compute the best path passing through the current node.
- Update best.
- Return the best one-sided gain to the parent.
The final answer is best.
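
The same steps can also be sketched without a shared variable, by letting each call return a pair: the one-sided gain and the best path sum seen inside that subtree. This is only an alternative formulation for illustration; `Node`, `max_path_sum`, and `solve` are names assumed here.

```python
import math
from typing import Optional, Tuple


class Node:
    """Minimal stand-in for LeetCode's TreeNode (hypothetical name)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root: Optional[Node]) -> float:
    def solve(node: Optional[Node]) -> Tuple[int, float]:
        """Return (one-sided gain, best path sum inside this subtree)."""
        if node is None:
            # No gain to offer, and no path exists in an empty subtree.
            return 0, -math.inf
        left_gain, left_best = solve(node.left)
        right_gain, right_best = solve(node.right)
        left_gain = max(0, left_gain)    # replace negative gains with 0
        right_gain = max(0, right_gain)
        through = node.val + left_gain + right_gain
        best = max(left_best, right_best, through)
        return node.val + max(left_gain, right_gain), best

    return solve(root)[1]


print(max_path_sum(Node(-10, Node(9), Node(20, Node(15), Node(7)))))  # 42
```

Returning a pair keeps the function free of side effects, at the cost of a slightly busier return value.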
Correctness
For each node, the algorithm computes the maximum downward path gain that starts at that node and continues into at most one child subtree.
This return value must be one-sided because a parent can connect to the current node through only one path. If the current node returned both left and right branches, the combined path would fork at the current node (toward the parent, the left child, and the right child), which would violate the path definition.
The algorithm also considers the best complete path whose highest node is the current node:

    node.val + left_gain + right_gain

This path may use both children because it ends inside the left subtree and inside the right subtree, passing through the current node exactly once.
Every valid path in a binary tree has some highest node, meaning the node closest to the root among the nodes in that path. At that highest node, the path can include at most one downward chain from the left subtree and at most one downward chain from the right subtree. The algorithm evaluates exactly that form for every node.
Negative child gains are discarded because adding a negative chain can only reduce the path sum.
Therefore, every possible valid path is considered, and best stores the maximum sum among them. The algorithm returns the correct maximum path sum.
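
One way to gain extra confidence in this argument is to compare the bottom-up DFS against a brute-force search on small random trees. The sketch below is such a check, not a proof; `Node`, `max_path_sum_dfs`, `brute_force`, and `random_tree` are hypothetical helpers, and the value range and tree sizes are arbitrary choices.

```python
import random
from typing import Dict, List, Optional


class Node:
    """Minimal stand-in for LeetCode's TreeNode (hypothetical name)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_dfs(root: Node) -> float:
    """The bottom-up DFS described above."""
    best = float("-inf")

    def dfs(node: Optional[Node]) -> int:
        nonlocal best
        if node is None:
            return 0
        left_gain = max(0, dfs(node.left))
        right_gain = max(0, dfs(node.right))
        best = max(best, node.val + left_gain + right_gain)
        return node.val + max(left_gain, right_gain)

    dfs(root)
    return best


def brute_force(root: Node) -> float:
    """Enumerate every simple path by treating the tree as an undirected graph."""
    nodes: List[Node] = []
    neighbors: Dict[int, List[Node]] = {}

    def collect(node: Optional[Node], parent: Optional[Node]) -> None:
        if node is None:
            return
        nodes.append(node)
        neighbors[id(node)] = [n for n in (parent, node.left, node.right)
                               if n is not None]
        collect(node.left, node)
        collect(node.right, node)

    collect(root, None)
    best = float("-inf")

    def walk(node: Node, came_from: Optional[Node], total: int) -> None:
        nonlocal best
        total += node.val
        best = max(best, total)
        # A tree has no cycles, so never stepping straight back is enough
        # to guarantee that no node is visited twice.
        for nxt in neighbors[id(node)]:
            if nxt is not came_from:
                walk(nxt, node, total)

    for start in nodes:
        walk(start, None, 0)
    return best


def random_tree(size: int, rng: random.Random) -> Optional[Node]:
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return Node(rng.randint(-10, 10),
                random_tree(left_size, rng),
                random_tree(size - 1 - left_size, rng))


rng = random.Random(0)
for _ in range(200):
    tree = random_tree(rng.randint(1, 8), rng)
    assert max_path_sum_dfs(tree) == brute_force(tree)
print("bottom-up DFS matches brute force on 200 random trees")
```

The brute force is far slower than the DFS, so it is only practical as a test oracle on small inputs.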
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack follows one root-to-leaf path |
Here, n is the number of nodes and h is the height of the tree.
For a balanced tree, h = O(log n).
For a skewed tree, h = O(n).
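
The skewed case has a practical consequence in Python: the recursion depth equals the tree height, and CPython's default recursion limit is typically 1000 frames, so a long chain can raise RecursionError. A minimal sketch of an iterative post-order variant follows, assuming a non-empty tree; `Node` and `max_path_sum_iterative` are hypothetical names, and the dictionary keyed by `id(node)` is one of several possible ways to hold child results.

```python
from typing import Dict, List, Optional, Tuple


class Node:
    """Minimal stand-in for LeetCode's TreeNode (hypothetical name)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum_iterative(root: Node) -> int:
    """Same bottom-up computation, with an explicit stack instead of recursion."""
    best = root.val
    gain: Dict[int, int] = {}  # id(node) -> one-sided gain of a finished child
    stack: List[Tuple[Node, bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Revisit this node only after both subtrees have been processed.
            stack.append((node, True))
            for child in (node.left, node.right):
                if child is not None:
                    stack.append((child, False))
            continue

        # A missing child has no entry in the dictionary, so it contributes 0.
        left_gain = max(0, gain.pop(id(node.left), 0))
        right_gain = max(0, gain.pop(id(node.right), 0))
        best = max(best, node.val + left_gain + right_gain)
        gain[id(node)] = node.val + max(left_gain, right_gain)

    return best


# Build a left-skewed chain that is much deeper than the usual recursion limit.
chain = None
for _ in range(10_000):
    chain = Node(1, left=chain)

print(max_path_sum_iterative(chain))  # 10000
```

The work per node is unchanged, so the running time stays O(n); the bookkeeping simply moves from the call stack to ordinary heap data structures.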
Implementation

    from typing import Optional

    # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right

    class Solution:
        def maxPathSum(self, root: Optional[TreeNode]) -> int:
            # Best path sum found so far; starts below any possible sum.
            best = float("-inf")

            def dfs(node: Optional[TreeNode]) -> int:
                nonlocal best

                if node is None:
                    return 0

                # Negative gains are replaced with 0, meaning "skip this side".
                left_gain = max(0, dfs(node.left))
                right_gain = max(0, dfs(node.right))

                # Best path whose highest node is the current node.
                path_sum = node.val + left_gain + right_gain
                best = max(best, path_sum)

                # Only a single downward chain can be extended by the parent.
                return node.val + max(left_gain, right_gain)

            dfs(root)
            return best

Code Explanation
Initialize the best answer to negative infinity:

    best = float("-inf")

This matters because all node values may be negative.
The DFS function returns the best one-sided gain:

    def dfs(node: Optional[TreeNode]) -> int:

An empty child contributes no gain:

    if node is None:
        return 0

Compute gains from children and discard negative gains:

    left_gain = max(0, dfs(node.left))
    right_gain = max(0, dfs(node.right))

Compute the best path that passes through the current node:

    path_sum = node.val + left_gain + right_gain

Update the global best answer:

    best = max(best, path_sum)

Return only one side to the parent:

    return node.val + max(left_gain, right_gain)

Finally:

    return best

Testing

    from typing import Optional

    class TreeNode:
        def __init__(self, val=0, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    class Solution:
        def maxPathSum(self, root: Optional[TreeNode]) -> int:
            best = float("-inf")

            def dfs(node: Optional[TreeNode]) -> int:
                nonlocal best

                if node is None:
                    return 0

                left_gain = max(0, dfs(node.left))
                right_gain = max(0, dfs(node.right))

                path_sum = node.val + left_gain + right_gain
                best = max(best, path_sum)

                return node.val + max(left_gain, right_gain)

            dfs(root)
            return best

    def run_tests():
        s = Solution()

        root1 = TreeNode(1, TreeNode(2), TreeNode(3))
        assert s.maxPathSum(root1) == 6

        root2 = TreeNode(
            -10,
            TreeNode(9),
            TreeNode(20, TreeNode(15), TreeNode(7)),
        )
        assert s.maxPathSum(root2) == 42

        root3 = TreeNode(-3)
        assert s.maxPathSum(root3) == -3

        root4 = TreeNode(2, TreeNode(-1), None)
        assert s.maxPathSum(root4) == 2

        root5 = TreeNode(
            5,
            TreeNode(4),
            TreeNode(8, TreeNode(11), TreeNode(-2)),
        )
        assert s.maxPathSum(root5) == 28

        print("all tests passed")

    run_tests()

Test meaning:
| Test | Why |
|---|---|
| [1,2,3] | Best path bends through root |
| [-10,9,20,null,null,15,7] | Best path avoids root |
| Single negative node | Confirms non-empty path rule |
| Negative child | Confirms negative gains are ignored |
| Mixed tree | Confirms best path can pass through an internal node |
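
The first two rows describe trees in list form. A small helper can turn such a list into TreeNode objects, which makes further test cases quicker to write. The sketch below assumes the LeetCode-style level-order format, where a missing child is written as null (None in Python) and a missing node contributes no child entries; `build_tree` is a hypothetical helper, not something the platform provides.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list, using None for missing children."""
    if not values or values[0] is None:
        return None

    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1

    while queue and i < len(values):
        parent = queue.popleft()

        # The next two entries are this parent's left and right children.
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1

        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1

    return root


root = build_tree([-10, 9, 20, None, None, 15, 7])
print(root.val, root.left.val, root.right.val)  # -10 9 20
print(root.right.left.val, root.right.right.val)  # 15 7
```

The resulting root can be passed directly to `Solution().maxPathSum` from the test script above.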