
LeetCode 437: Path Sum III

Count downward paths in a binary tree whose values sum to targetSum using DFS and prefix sums.

Problem Restatement

We are given the root of a binary tree and an integer targetSum.

We need to count how many paths have node values that add up to targetSum.

A valid path:

| Rule | Meaning |
| --- | --- |
| Can start anywhere | It does not have to start at the root |
| Can end anywhere | It does not have to end at a leaf |
| Must go downward | It can only move from parent to child |

So this path is valid:

parent -> child -> grandchild

But this path is not valid:

left child -> parent -> right child

because it moves upward.

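The official problem asks for the number of downward paths whose sum equals targetSum. The tree can be empty, and node values may be large or negative. Negative values mean we cannot stop a walk early once the running sum passes targetSum. Large values mean running sums can overflow a 32-bit integer in languages with fixed-width integers (Python integers do not overflow).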

Input and Output

| Item | Meaning |
| --- | --- |
| Input | root of a binary tree and integer targetSum |
| Output | Number of valid downward paths |
| Path start | Any node |
| Path end | The start node itself or any node below it |
| Empty tree | Return 0 |

The node class is usually:

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
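The examples below write trees in LeetCode's level-order list notation, where None marks a missing child. Here is a minimal sketch of a converter, assuming that notation. build_tree is a hypothetical helper name and is not part of the problem:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a LeetCode-style level-order list."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value (if present and not None) is the left child
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The value after that (if present and not None) is the right child
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


root = build_tree([10, 5, -3, 3, 2, None, 11, 3, -2, None, 1])
print(root.val, root.left.val, root.right.val)  # 10 5 -3
```

Only real children are queued, which is why the None after -3 skips a slot instead of shifting later values.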

Examples

Example 1:

root = [10, 5, -3, 3, 2, None, 11, 3, -2, None, 1]
targetSum = 8

The answer is:

3

The three paths are:

5 -> 3
5 -> 2 -> 1
-3 -> 11

Example 2:

root = [5, 4, 8, 11, None, 13, 4, 7, 2, None, None, 5, 1]
targetSum = 22

The answer is:

3

The three paths are:

5 -> 4 -> 11 -> 2
5 -> 8 -> 4 -> 5
4 -> 11 -> 7

Example 3:

root = None
targetSum = 0

The answer is:

0

There are no nodes, so there are no paths.

First Thought: Start DFS From Every Node

A direct approach is:

  1. Pick every node as a possible path start.
  2. From that node, try every downward path.
  3. Count paths whose sum equals targetSum.

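This works, but it re-explores the same subtrees many times: every node is walked once from itself and once from each of its ancestors.

For a balanced tree that is about O(n log n) work. For a skewed tree (effectively a linked list), the time complexity becomes:

O(n^2)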
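The three steps above can be sketched directly. This is a minimal version for comparison, and path_sum_brute_force and count_from are hypothetical names. It keeps walking after a match, because negative or zero values can bring the running sum back to targetSum:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_sum_brute_force(root: Optional[TreeNode], target_sum: int) -> int:
    """O(n^2) in the worst case: start a fresh downward walk from every node."""

    def count_from(node: Optional[TreeNode], remaining: int) -> int:
        # Count downward paths that start at the original start node and pass through `node`
        if node is None:
            return 0
        remaining -= node.val
        hits = 1 if remaining == 0 else 0
        # No early exit: later nodes may still bring the sum back to the target
        return hits + count_from(node.left, remaining) + count_from(node.right, remaining)

    if root is None:
        return 0
    # Paths that start at root, plus paths that start anywhere inside each subtree
    return (count_from(root, target_sum)
            + path_sum_brute_force(root.left, target_sum)
            + path_sum_brute_force(root.right, target_sum))


print(path_sum_brute_force(TreeNode(0, TreeNode(0), TreeNode(0)), 0))  # 5
```

Each recursive call on a subtree restarts count_from from scratch. That restart is the repeated work the prefix-sum approach removes.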

We need to count all possible downward paths while visiting each node only once.

Key Insight

This problem is the tree version of the prefix-sum technique for counting subarrays with a given sum (LeetCode 560, Subarray Sum Equals K).
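The array version of the technique is short enough to show in full. Here is a minimal sketch, assuming the standard formulation that counts contiguous subarrays summing to k. subarray_sum is a hypothetical name. An array is a single path, so unlike the tree version it never has to remove prefixes:

```python
from typing import Dict, List


def subarray_sum(nums: List[int], k: int) -> int:
    """Count contiguous subarrays of nums whose sum is k (prefix sums + hash map)."""
    prefix_count: Dict[int, int] = {0: 1}  # the empty prefix before index 0
    current_sum = 0
    total = 0
    for x in nums:
        current_sum += x
        # Each earlier prefix equal to current_sum - k closes a subarray summing to k
        total += prefix_count.get(current_sum - k, 0)
        prefix_count[current_sum] = prefix_count.get(current_sum, 0) + 1
    return total


print(subarray_sum([10, 5, 3, 3], 8))  # 1 (the segment 5, 3)
```

The input [10, 5, 3, 3] is the leftmost root-to-leaf path of Example 1. The one match it finds is the path 5 -> 3.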

Suppose we are walking from the root to the current node and the running sum is:

current_sum

We want to know whether there is an earlier point on the same root-to-current path where the prefix sum was:

current_sum - targetSum

Why?

If:

current_sum - old_prefix_sum = targetSum

then the nodes after the one where that old prefix ended, down to and including the current node, form a path with sum targetSum.

Rearrange:

old_prefix_sum = current_sum - targetSum
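As a worked check on Example 1, take the root-to-node path 10 -> 5 -> 3 with targetSum = 8. Write P_i for the running sum after the i-th node, with P_0 = 0 as the empty prefix:

```latex
\begin{aligned}
P_0 &= 0, \quad P_1 = 10, \quad P_2 = 10 + 5 = 15, \quad P_3 = 15 + 3 = 18 \\
P_3 - \mathit{targetSum} &= 18 - 8 = 10 = P_1 \\
\text{so}\quad P_3 - P_1 &= 5 + 3 = 8 = \mathit{targetSum}
\end{aligned}
```

The lookup at the node 3 matches P_1, the prefix that ends at 10. That match identifies the path 5 -> 3, which starts just after 10.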

So during DFS, we keep a hash map:

prefix_count[prefix_sum] = frequency

This map stores prefix sums only along the current root-to-node path.

Algorithm

Initialize:

prefix_count = {0: 1}

The prefix sum 0 handles paths that start at the root.

Run DFS with:

dfs(node, current_sum)

For each node:

  1. Add the node value to current_sum.
  2. Count how many earlier prefixes equal current_sum - targetSum.
  3. Add the current prefix sum to the hash map.
  4. Recurse into left and right children.
  5. Remove the current prefix sum from the hash map before returning.

The last step is important. Prefix sums from one branch must not affect another branch.
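To see why both the seed {0: 1} and the backtracking step matter, here is a sketch of the same DFS with two hypothetical switches, seed_empty_prefix and backtrack, that turn each step off. In the tree 1 -> (2, 5) with targetSum = 3, only the path 1 -> 2 qualifies:

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths(root: Optional[TreeNode], target: int,
                seed_empty_prefix: bool = True, backtrack: bool = True) -> int:
    """Prefix-sum DFS with hypothetical switches that disable the two key steps."""
    prefix_count = defaultdict(int)
    if seed_empty_prefix:
        prefix_count[0] = 1  # lets paths start at the root

    def dfs(node: Optional[TreeNode], current_sum: int) -> int:
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        if backtrack:
            prefix_count[current_sum] -= 1  # hide this prefix from sibling branches
        return total

    return dfs(root, 0)


root = TreeNode(1, TreeNode(2), TreeNode(5))
print(count_paths(root, 3))                           # 1 (correct)
print(count_paths(root, 3, seed_empty_prefix=False))  # 0 (misses the path starting at the root)
print(count_paths(root, 3, backtrack=False))          # 2 (the left child's prefix 3 leaks into the right branch)
```

Without backtracking, the right child (running sum 6) finds the stale prefix 3 from its sibling. It then counts a "path" that would have to go up from 2 and back down to 5.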

Correctness

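At any point during DFS, the keys of prefix_count with a positive count are exactly the empty prefix 0 and the prefix sums that end at each node on the current path from the root to the parent of the current node, counted with multiplicity.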

When we visit a node, we update current_sum to include that node.

A downward path ending at the current node has sum targetSum exactly when there exists an earlier prefix sum p such that:

current_sum - p == targetSum

This is equivalent to:

p == current_sum - targetSum

So prefix_count[current_sum - targetSum] gives exactly the number of valid paths ending at the current node.

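After counting paths ending at the current node, we add current_sum to prefix_count before processing children. A path from a node u down to a descendant v has sum equal to v's prefix minus the prefix that ends at u's parent. So this new entry lets descendants count paths that start at a child of the current node. Paths that start at the current node itself, or higher up, already use prefixes added by its ancestors or the initial 0.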

After processing both children, we decrement the current prefix sum. This restores the hash map to its previous state before returning to the parent, so sibling branches do not share invalid prefix sums.

Because the DFS visits every node and counts exactly the valid paths ending at that node, the final total is the number of all valid downward paths.

Complexity

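| Metric | Value | Why |
| --- | --- | --- |
| Time | O(n) | Each node is visited once, with O(1) average-time hash-map work per visit |
| Space | O(h) | The recursion stack and the live prefix sums follow the current root-to-node path |

The O(h) bound counts only prefix sums with a positive count. The defaultdict implementation below never deletes keys, and reading a missing key inserts it with count 0. Its map can therefore hold O(n) keys in total. Reading with .get and deleting a key when its count drops to 0 keeps the map at O(h).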

Here, n is the number of nodes.

h is the height of the tree.

In the worst case, h = n.
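Because h can equal n, recursion depth is a practical concern in Python. CPython's default recursion limit is typically 1000 (see sys.getrecursionlimit), so a recursive dfs can raise RecursionError on a deeply skewed tree. Below is a minimal sketch of the same algorithm with an explicit stack. path_sum_iterative is a hypothetical name. Unlike the defaultdict version, it deletes keys whose count reaches 0, so the map stays at O(h) entries:

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def path_sum_iterative(root: Optional[TreeNode], target_sum: int) -> int:
    """Prefix-sum DFS using an explicit stack instead of recursion."""
    prefix_count = {0: 1}  # the empty prefix lets paths start at the root
    total = 0
    # Each entry: (node, running sum through the node's parent, is this the exit visit?)
    stack: List[Tuple[TreeNode, int, bool]] = []
    if root is not None:
        stack.append((root, 0, False))
    while stack:
        node, parent_sum, leaving = stack.pop()
        current_sum = parent_sum + node.val
        if leaving:
            # Backtrack: drop this node's prefix and delete keys that reach zero
            prefix_count[current_sum] -= 1
            if prefix_count[current_sum] == 0:
                del prefix_count[current_sum]
            continue
        # Count paths ending at this node, then expose its prefix to descendants
        total += prefix_count.get(current_sum - target_sum, 0)
        prefix_count[current_sum] = prefix_count.get(current_sum, 0) + 1
        # The exit marker sits below the children, so it pops after both subtrees finish
        stack.append((node, parent_sum, True))
        if node.right is not None:
            stack.append((node.right, current_sum, False))
        if node.left is not None:
            stack.append((node.left, current_sum, False))
    return total


# A 5000-node chain of zeros: every downward path sums to 0.
head = None
for _ in range(5000):
    head = TreeNode(0, head)
print(path_sum_iterative(head, 0))  # 12502500, which is 5000 * 5001 // 2
```

The exit marker plays the role of the line after the two recursive calls. It is where the prefix is removed.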

Implementation

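from collections import defaultdict

class Solution:
    def pathSum(self, root: 'Optional[TreeNode]', targetSum: int) -> int:
        # prefix_count[s] = how many prefixes on the current root-to-node path sum to s
        prefix_count = defaultdict(int)
        prefix_count[0] = 1  # the empty prefix lets paths start at the root

        def dfs(node: 'Optional[TreeNode]', current_sum: int) -> int:
            if not node:
                return 0

            # Running sum from the root through this node
            current_sum += node.val

            # Paths ending here: earlier prefixes p with current_sum - p == targetSum
            total = prefix_count[current_sum - targetSum]

            # Make this prefix visible to the descendants
            prefix_count[current_sum] += 1

            total += dfs(node.left, current_sum)
            total += dfs(node.right, current_sum)

            # Backtrack so sibling subtrees do not see this prefix
            prefix_count[current_sum] -= 1

            return total

        return dfs(root, 0)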

Code Explanation

We use a hash map to count prefix sums:

prefix_count = defaultdict(int)
prefix_count[0] = 1

The initial 0 means there is one empty prefix before the root.

The DFS receives the current running sum:

def dfs(node, current_sum):

If the node is missing, it contributes no paths:

if not node:
    return 0

We include the current node:

current_sum += node.val

Now we count paths ending exactly at this node:

total = prefix_count[current_sum - targetSum]
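One Python detail about this lookup: reading a missing key from a defaultdict inserts that key with the default value, while dict.get does not. The lookup is still correct, because the inserted count is 0, but the map keeps growing. A small demonstration with the standard library's collections.defaultdict:

```python
from collections import defaultdict

counts = defaultdict(int)
counts[0] = 1

print(counts[5])         # 0
print(dict(counts))      # {0: 1, 5: 0} (the lookup added key 5)
print(counts.get(7, 0))  # 0
print(7 in counts)       # False (.get does not insert)
```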

Then we add the current prefix before going downward:

prefix_count[current_sum] += 1

Now the left and right children can use this prefix:

total += dfs(node.left, current_sum)
total += dfs(node.right, current_sum)

After both children are processed, we backtrack:

prefix_count[current_sum] -= 1

This removes the current node’s prefix from the active path.

Finally, the DFS returns how many valid paths were found in this subtree.

Testing

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def run_tests():
    s = Solution()

    root = TreeNode(
        10,
        TreeNode(
            5,
            TreeNode(
                3,
                TreeNode(3),
                TreeNode(-2),
            ),
            TreeNode(
                2,
                None,
                TreeNode(1),
            ),
        ),
        TreeNode(
            -3,
            None,
            TreeNode(11),
        ),
    )

    assert s.pathSum(root, 8) == 3

    root = TreeNode(
        5,
        TreeNode(
            4,
            TreeNode(
                11,
                TreeNode(7),
                TreeNode(2),
            ),
        ),
        TreeNode(
            8,
            TreeNode(13),
            TreeNode(
                4,
                TreeNode(5),
                TreeNode(1),
            ),
        ),
    )

    assert s.pathSum(root, 22) == 3

    assert s.pathSum(None, 0) == 0

    root = TreeNode(1, TreeNode(-1), TreeNode(1))
    assert s.pathSum(root, 0) == 1

    root = TreeNode(0, TreeNode(0), TreeNode(0))
    assert s.pathSum(root, 0) == 5

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
| --- | --- |
| Standard example | Checks paths starting below the root |
| Target 22 example | Checks longer downward paths |
| Empty tree | Checks the no-node case |
| Negative value | Checks prefix sums with subtraction |
| Zero values | Checks multiple overlapping valid paths |
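A randomized comparison against the brute-force count can catch mistakes, such as a missing backtrack, that hand-picked cases might miss. Below is a minimal sketch. random_tree, prefix_sum_count, brute_force_count, and cross_check are hypothetical names. The value ranges are kept small on purpose so that zeros, negatives, and repeated prefix sums show up often:

```python
import random
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prefix_sum_count(root: Optional[TreeNode], target: int) -> int:
    # Same prefix-sum DFS as the Implementation section
    prefix_count = defaultdict(int)
    prefix_count[0] = 1

    def dfs(node, current_sum):
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        prefix_count[current_sum] -= 1
        return total

    return dfs(root, 0)


def brute_force_count(root: Optional[TreeNode], target: int) -> int:
    # Start a downward walk from every node and count running sums equal to target
    def from_start(node, total):
        if node is None:
            return 0
        total += node.val
        hit = 1 if total == target else 0
        return hit + from_start(node.left, total) + from_start(node.right, total)

    if root is None:
        return 0
    return (from_start(root, 0)
            + brute_force_count(root.left, target)
            + brute_force_count(root.right, target))


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    # Random shape with small values, including negatives and zeros
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(rng.randint(-3, 3),
                    random_tree(left_size, rng),
                    random_tree(size - 1 - left_size, rng))


def cross_check(trials: int = 500, seed: int = 0) -> int:
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng.randint(0, 12), rng)
        target = rng.randint(-4, 4)
        assert prefix_sum_count(root, target) == brute_force_count(root, target)
    return trials


print(cross_check(), "random trees agree")
```

A fixed seed keeps any failure reproducible. Trees of up to 12 nodes keep the quadratic brute force cheap.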