Count downward paths in a binary tree whose values sum to targetSum using DFS and prefix sums.
Problem Restatement
We are given the root of a binary tree and an integer targetSum.
We need to count how many paths have node values that add up to targetSum.
A valid path follows these rules:
| Rule | Meaning |
|---|---|
| Can start anywhere | It does not have to start at the root |
| Can end anywhere | It does not have to end at a leaf |
| Must go downward | It can only move from parent to child |
So this path is valid:
```
parent -> child -> grandchild
```
But this path is not valid, because it moves upward:
```
left child -> parent -> right child
```
The official problem asks for the number of downward paths whose sum equals targetSum. The tree can be empty, and node values may be large or negative.
Input and Output
| Item | Meaning |
|---|---|
| Input | root of a binary tree and integer targetSum |
| Output | Number of valid downward paths |
| Path start | Any node |
| Path end | The start node itself or any node below it |
| Empty tree | Return 0 |
The node class is usually defined as:
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
```
Examples
Example 1:
```
root = [10, 5, -3, 3, 2, None, 11, 3, -2, None, 1]
targetSum = 8
```
The answer is 3. The three paths are:
```
5 -> 3
5 -> 2 -> 1
-3 -> 11
```
Example 2:
```
root = [5, 4, 8, 11, None, 13, 4, 7, 2, None, None, 5, 1]
targetSum = 22
```
The answer is 3. The three paths are:
```
5 -> 4 -> 11 -> 2
5 -> 8 -> 4 -> 5
4 -> 11 -> 7
```
Example 3:
```
root = None
targetSum = 0
```
The answer is 0. There are no nodes, so there are no paths.
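The examples write trees as level-order lists, with None marking a missing child. Below is a minimal sketch of a converter for that format. It assumes LeetCode's serialization, in which a missing node's children take no slots in later levels. build_tree is a hypothetical helper name and is not part of the problem's API.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical helper: builds a tree from a level-order list (None = no child).
def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their child slots
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Left child slot
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # Right child slot
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root
```

Combined with the Solution class from the Implementation section, Solution().pathSum(build_tree([10, 5, -3, 3, 2, None, 11, 3, -2, None, 1]), 8) should return 3.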
First Thought: Start DFS From Every Node
A direct approach is:
- Pick every node as a possible path start.
- From that node, try every downward path.
- Count paths whose sum equals targetSum.
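These three steps translate into two nested recursions. The sketch below shows only the brute-force idea, and count_from and path_sum_brute_force are hypothetical names. A walk must not stop when its running total first hits the target, because later zero or negative values can bring it back to the target.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical helper: downward paths that start exactly at `node`.
def count_from(node: Optional[TreeNode], remaining: int) -> int:
    if node is None:
        return 0
    remaining -= node.val
    # Count a hit here, but keep walking: values below may be zero or negative.
    found = 1 if remaining == 0 else 0
    return found + count_from(node.left, remaining) + count_from(node.right, remaining)


# Hypothetical name: try every node as a start.
def path_sum_brute_force(root: Optional[TreeNode], target_sum: int) -> int:
    if root is None:
        return 0
    return (
        count_from(root, target_sum)
        + path_sum_brute_force(root.left, target_sum)
        + path_sum_brute_force(root.right, target_sum)
    )
```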
This works, but it revisits the same subtree many times: a node at depth d is walked once from each of its d + 1 possible start nodes.
For a skewed tree, the time complexity becomes:
```
O(n^2)
```
We want to count all valid downward paths while visiting each node only once.
Key Insight
This problem is a tree version of subarray sum with prefix sums.
Suppose we are walking from the root to the current node and the running sum is:
```
current_sum
```
We want to know how many earlier points on the same root-to-current path had the prefix sum:
```
current_sum - targetSum
```
Why? If:
```
current_sum - old_prefix_sum = targetSum
```
then the path that starts just after that old prefix and ends at the current node has sum targetSum.
Rearrange:
```
old_prefix_sum = current_sum - targetSum
```
So during DFS, we keep a hash map:
```
prefix_count[prefix_sum] = frequency
```
This map stores prefix sums only along the current root-to-node path.
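On a single root-to-leaf path, the tree problem becomes the array question "how many contiguous subarrays sum to k?". Here is a minimal sketch of that array version for comparison. subarray_sum_count is our own name for it.

```python
from collections import defaultdict
from typing import List


# Hypothetical name: the array analogue of the tree problem.
def subarray_sum_count(nums: List[int], k: int) -> int:
    prefix_count = defaultdict(int)
    prefix_count[0] = 1  # the empty prefix before index 0
    current_sum = 0
    total = 0
    for x in nums:
        current_sum += x
        # Each earlier prefix equal to current_sum - k starts a subarray summing to k
        total += prefix_count[current_sum - k]
        prefix_count[current_sum] += 1
    return total
```

For the root-to-leaf path 10 -> 5 -> 3 -> 3 from Example 1, subarray_sum_count([10, 5, 3, 3], 8) counts the single run 5 -> 3. The tree DFS runs the same loop along each root-to-leaf path. The difference is that branches share their upper part, so DFS must undo each prefix when it leaves a node.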
Algorithm
Initialize:
```
prefix_count = {0: 1}
```
The prefix sum 0 represents the empty prefix before the root, so paths that start at the root are counted too.
Run DFS with:
```
dfs(node, current_sum)
```
For each node:
- Add the node value to current_sum.
- Count how many earlier prefixes equal current_sum - targetSum.
- Add the current prefix sum to the hash map.
- Recurse into the left and right children.
- Remove the current prefix sum from the hash map before returning.
The last step is important. Prefix sums from one branch must not affect another branch.
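A three-node tree is enough to see why. The sketch below runs the same prefix-sum DFS but adds a hypothetical backtrack flag that can switch the cleanup off. Take root 1 with left child 1, right child 2 and target 1. The correct count is 2: the two single-node paths with value 1. Without cleanup, the left child's prefix 2 is still in the map when the right child (running sum 3) looks up 3 - 1 = 2. That counts a left-to-right "path" that does not exist, giving 3.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical signature: backtrack=False deliberately skips the cleanup step.
def path_sum(root: Optional[TreeNode], target_sum: int, backtrack: bool = True) -> int:
    prefix_count = defaultdict(int)
    prefix_count[0] = 1

    def dfs(node: Optional[TreeNode], current_sum: int) -> int:
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target_sum]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum)
        total += dfs(node.right, current_sum)
        if backtrack:
            # Remove this node's prefix so sibling subtrees cannot see it
            prefix_count[current_sum] -= 1
        return total

    return dfs(root, 0)
```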
Correctness
At any point during DFS, prefix_count contains exactly the prefix sums on the current path from the root to the parent of the current node, together with the empty prefix 0, each counted with its multiplicity.
When we visit a node, we update current_sum to include that node.
Every downward path ending at the current node corresponds to exactly one earlier prefix sum p on the active path: the prefix just above the path's start node. That path has sum targetSum exactly when:
```
current_sum - p == targetSum
```
This is equivalent to:
```
p == current_sum - targetSum
```
So prefix_count[current_sum - targetSum] gives exactly the number of valid paths ending at the current node.
After counting paths ending at the current node, we add current_sum to prefix_count before processing children. This makes the prefix available to descendants: a path that starts at one of the current node's children corresponds to the old prefix current_sum.
After processing both children, we decrement the count for current_sum. This restores the hash map to its state before the visit, so when control returns to the parent, a sibling branch does not see prefixes from a path it is not on.
Because the DFS visits every node and counts exactly the valid paths ending at that node, the final total is the number of all valid downward paths.
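The argument can also be spot-checked empirically. The sketch below compares the prefix-sum DFS with the brute-force count from the first approach on random small trees. The tree sizes, the value range [-3, 3] and all helper names are our own assumptions. Zeros and negatives are included on purpose, since they are the cases where overlapping paths and subtraction matter.

```python
import random
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def brute_force(root: Optional[TreeNode], target: int) -> int:
    """Reference count (hypothetical helper): walk down from every node as a start."""
    def from_node(node: Optional[TreeNode], remaining: int) -> int:
        if node is None:
            return 0
        remaining -= node.val
        hit = 1 if remaining == 0 else 0
        return hit + from_node(node.left, remaining) + from_node(node.right, remaining)

    def every_start(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        return from_node(node, target) + every_start(node.left) + every_start(node.right)

    return every_start(root)


def prefix_sum_dfs(root: Optional[TreeNode], target: int) -> int:
    """Prefix-sum DFS with backtracking, following the algorithm above."""
    prefix_count = defaultdict(int)
    prefix_count[0] = 1

    def dfs(node: Optional[TreeNode], current_sum: int) -> int:
        if node is None:
            return 0
        current_sum += node.val
        total = prefix_count[current_sum - target]
        prefix_count[current_sum] += 1
        total += dfs(node.left, current_sum) + dfs(node.right, current_sum)
        prefix_count[current_sum] -= 1
        return total

    return dfs(root, 0)


def random_tree(rng: random.Random, size: int) -> Optional[TreeNode]:
    """Random shape with `size` nodes; values include zeros and negatives."""
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(
        rng.randint(-3, 3),
        random_tree(rng, left_size),
        random_tree(rng, size - 1 - left_size),
    )


def cross_check(trials: int = 500, seed: int = 0) -> bool:
    """Return True if both counts agree on every random case."""
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng, rng.randint(0, 12))
        target = rng.randint(-4, 4)
        if brute_force(root, target) != prefix_sum_dfs(root, target):
            return False
    return True
```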
Complexity
| Metric | Value | Why |
|---|---|---|
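| Time | O(n) | Each node is visited once, with O(1) average-time hash-map operations per node |
| Space | O(n) | The recursion stack uses O(h), but a defaultdict keeps keys whose count has dropped to 0 (and indexing a missing key inserts it), so the map can grow to one entry per distinct prefix sum seen |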
Here, n is the number of nodes.
h is the height of the tree.
In the worst case, h = n.
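For the hash map itself to stay proportional to h, entries must be deleted when their count reaches zero, and lookups must not create keys. A plain defaultdict with indexing does neither. Below is a minimal sketch of that variant, which holds at most h + 1 keys at any time. path_sum_compact is a hypothetical name.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical name: prefix-sum DFS whose map only holds prefixes on the active path.
def path_sum_compact(root: Optional[TreeNode], target_sum: int) -> int:
    prefix_count: Dict[int, int] = {0: 1}

    def dfs(node: Optional[TreeNode], current_sum: int) -> int:
        if node is None:
            return 0
        current_sum += node.val
        # .get() reads without inserting a key for a prefix never seen
        total = prefix_count.get(current_sum - target_sum, 0)
        prefix_count[current_sum] = prefix_count.get(current_sum, 0) + 1
        total += dfs(node.left, current_sum)
        total += dfs(node.right, current_sum)
        # Backtrack, and drop the key entirely once its count reaches zero
        prefix_count[current_sum] -= 1
        if prefix_count[current_sum] == 0:
            del prefix_count[current_sum]
        return total

    return dfs(root, 0)
```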
Implementation
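```python
from collections import defaultdict
from typing import Optional


class Solution:
    def pathSum(self, root: 'Optional[TreeNode]', targetSum: int) -> int:
        # Count of each prefix sum on the current root-to-node path.
        # The empty prefix (sum 0) lets paths start at the root.
        prefix_count = defaultdict(int)
        prefix_count[0] = 1

        def dfs(node: 'Optional[TreeNode]', current_sum: int) -> int:
            if not node:
                return 0
            current_sum += node.val
            # Paths ending here: earlier prefixes equal to current_sum - targetSum
            total = prefix_count[current_sum - targetSum]
            # Make this prefix visible to the node's descendants
            prefix_count[current_sum] += 1
            total += dfs(node.left, current_sum)
            total += dfs(node.right, current_sum)
            # Backtrack so sibling subtrees do not see this prefix
            prefix_count[current_sum] -= 1
            return total

        return dfs(root, 0)
```
Code Explanation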
We use a hash map to count prefix sums:
```python
prefix_count = defaultdict(int)
prefix_count[0] = 1
```
The initial 0 means there is one empty prefix before the root.
The DFS receives the current running sum:
```python
def dfs(node, current_sum):
```
If the node is missing, it contributes no paths:
```python
if not node:
    return 0
```
We include the current node:
```python
current_sum += node.val
```
Now we count paths ending exactly at this node:
```python
total = prefix_count[current_sum - targetSum]
```
Then we add the current prefix before going downward:
```python
prefix_count[current_sum] += 1
```
Now the left and right subtrees can use this prefix:
```python
total += dfs(node.left, current_sum)
total += dfs(node.right, current_sum)
```
After both children are processed, we backtrack:
```python
prefix_count[current_sum] -= 1
```
This removes the current node’s prefix from the active path.
Finally, the DFS returns how many valid paths were found in this subtree.
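Because h can equal n, a very deep tree can exceed CPython's default recursion limit, which is typically 1000 (see sys.getrecursionlimit()). One way around this replaces recursion with an explicit stack of enter and leave events. The sketch below assumes that approach, and path_sum_iterative is a hypothetical name. The leave event performs the same backtracking decrement as the recursive version.

```python
from collections import defaultdict
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Hypothetical name: same prefix-sum DFS, driven by an explicit stack.
def path_sum_iterative(root: Optional[TreeNode], target_sum: int) -> int:
    prefix_count = defaultdict(int)
    prefix_count[0] = 1
    total = 0
    # Entries are (node, sum_before_node, leaving). leaving=True marks the
    # moment both subtrees are done and the node's prefix must be removed.
    stack = [(root, 0, False)] if root else []
    while stack:
        node, prev_sum, leaving = stack.pop()
        current_sum = prev_sum + node.val
        if leaving:
            prefix_count[current_sum] -= 1
            continue
        total += prefix_count[current_sum - target_sum]
        prefix_count[current_sum] += 1
        # Push the leave marker first so it is popped after both subtrees
        stack.append((node, prev_sum, True))
        if node.right:
            stack.append((node.right, current_sum, False))
        if node.left:
            stack.append((node.left, current_sum, False))
    return total
```

Usage: a left-skewed chain of 3000 nodes, each with value 1, makes the recursive version fail under the default limit. path_sum_iterative finds 3000 single-node paths for target 1.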
Testing
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def run_tests():
    s = Solution()

    root = TreeNode(
        10,
        TreeNode(
            5,
            TreeNode(
                3,
                TreeNode(3),
                TreeNode(-2),
            ),
            TreeNode(
                2,
                None,
                TreeNode(1),
            ),
        ),
        TreeNode(
            -3,
            None,
            TreeNode(11),
        ),
    )
    assert s.pathSum(root, 8) == 3

    root = TreeNode(
        5,
        TreeNode(
            4,
            TreeNode(
                11,
                TreeNode(7),
                TreeNode(2),
            ),
        ),
        TreeNode(
            8,
            TreeNode(13),
            TreeNode(
                4,
                TreeNode(5),
                TreeNode(1),
            ),
        ),
    )
    assert s.pathSum(root, 22) == 3

    assert s.pathSum(None, 0) == 0

    root = TreeNode(1, TreeNode(-1), TreeNode(1))
    assert s.pathSum(root, 0) == 1

    root = TreeNode(0, TreeNode(0), TreeNode(0))
    assert s.pathSum(root, 0) == 5

    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| Standard example | Checks paths starting below root |
| Target 22 example | Checks longer downward paths |
| Empty tree | Checks no-node case |
| Negative value | Checks prefix sums with subtraction |
| Zero values | Checks multiple overlapping valid paths |