A clear explanation of finding whether two different nodes in a binary search tree sum to a target value.
Problem Restatement
We are given the root of a binary search tree and an integer k.
We need to return True if there are two different nodes in the tree whose values add up to k.
Otherwise, return False.
The two values must come from two different nodes. We cannot use the same node twice.
Input and Output
| Item | Meaning |
|---|---|
| Input | The root of a binary search tree and an integer k |
| Output | True if two different node values sum to k, otherwise False |
| Tree property | The input tree is a valid binary search tree |
| Node rule | The two values must come from different nodes |
Example function shape:
def findTarget(root: Optional[TreeNode], k: int) -> bool:
    ...
Examples
Consider this BST:
    5
   / \
  3   6
 / \   \
2   4   7
If k = 9, the answer is True.
There are several possible pairs:
2 + 7 = 9
3 + 6 = 9
Since at least one valid pair exists, we return True.
If k = 28, the answer is False.
No two node values in the tree add up to 28.
Another example:
  2
 / \
1   3
If k = 4, the answer is True, because:
1 + 3 = 4
If k = 1, the answer is False.
There is no pair of two different nodes that sums to 1.
First Thought: Convert to a List
Since the tree is a binary search tree, an inorder traversal gives the values in sorted order.
For this tree:
    5
   / \
  3   6
 / \   \
2   4   7
the inorder list is:
[2, 3, 4, 5, 6, 7]
Then the problem becomes the normal Two Sum problem on a sorted array.
We can use two pointers:
- Put left at the beginning.
- Put right at the end.
- If values[left] + values[right] == k, return True.
- If the sum is too small, move left one step to the right.
- If the sum is too large, move right one step to the left.
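The list-based approach above can be sketched roughly as follows. This is only an illustration of the first idea, not the solution developed in the rest of the article. The TreeNode class mirrors the usual LeetCode definition, and the names inorder_values and find_target_sorted are hypothetical.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node, mirroring the usual LeetCode definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values in sorted order via an inorder traversal."""
    values: List[int] = []

    def visit(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        visit(node.left)
        values.append(node.val)
        visit(node.right)

    visit(root)
    return values


def find_target_sorted(root: Optional[TreeNode], k: int) -> bool:
    """Two-pointer search over the sorted inorder list (hypothetical helper)."""
    values = inorder_values(root)
    left, right = 0, len(values) - 1
    # left < right guarantees the two values come from different nodes.
    while left < right:
        total = values[left] + values[right]
        if total == k:
            return True
        if total < k:
            left += 1   # sum too small: advance the left pointer
        else:
            right -= 1  # sum too large: pull the right pointer back
    return False
```

The loop condition left < right is what keeps a single node from being paired with itself.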
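This works, but it has to traverse the whole tree and store every value before the search can begin.
We can instead look for a pair during the traversal itself, using DFS and a hash set, and stop as soon as one is found.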
Key Insight
As we traverse the tree, suppose the current node value is x.
We need another value:
k - x
If we have already seen k - x, then we found two different nodes whose values sum to k.
This is the same idea as LeetCode 1: Two Sum.
The only difference is that the input is a tree instead of an array.
So we use a hash set called seen.
It stores node values we have already visited.
For each node:
- Compute the needed value.
- Check whether it is in seen.
- If yes, return True.
- Otherwise, add the current value and continue traversal.
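To make the connection to the array version concrete, here is a minimal sketch of the same four steps applied to a plain sequence of values. The function name has_pair_with_sum is hypothetical and is used only for illustration.

```python
from typing import Iterable


def has_pair_with_sum(values: Iterable[int], k: int) -> bool:
    """Return True if two different positions in values sum to k."""
    seen = set()
    for x in values:
        need = k - x          # compute the needed value
        if need in seen:      # was it seen at an earlier position?
            return True
        seen.add(x)           # otherwise remember x and keep going
    return False
```

The tree version replaces the for loop with a traversal. The body of the loop stays the same.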
Algorithm
Use DFS to visit every node.
At each node:
- If the node is None, return False.
- Let need = k - node.val.
- If need is in seen, return True.
- Add node.val to seen.
- Search the left subtree.
- Search the right subtree.
- Return whether either subtree contains a valid pair.
The check must happen before inserting the current value.
This prevents using the current node as its own pair when k = 2 * node.val.
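The effect of this ordering can be seen on a one-element input. The sketch below uses a plain list in place of a tree, and a hypothetical insert_first flag that deliberately reproduces the wrong order so the two behaviors can be compared.

```python
from typing import Iterable


def pair_exists(values: Iterable[int], k: int, insert_first: bool = False) -> bool:
    """Hash-set pair search; insert_first=True reproduces the buggy order."""
    seen = set()
    for x in values:
        if insert_first:
            seen.add(x)       # wrong: x can now match itself
        if k - x in seen:
            return True
        seen.add(x)           # right: insert only after the check
    return False


# A single value 5 with k = 10 must not count as a pair.
correct = pair_exists([5], 10)                      # check, then insert
buggy = pair_exists([5], 10, insert_first=True)     # insert, then check
```

Here correct is False, while buggy is True because the value 5 finds itself in the set.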
Correctness
The DFS visits every node in the tree.
When the algorithm is at a node with value x, the hash set seen contains values from nodes visited earlier. These are different nodes from the current node.
If k - x is already in seen, then there is an earlier node with value k - x. Therefore, the earlier node and the current node are two different nodes, and their values sum to k.
If k - x is not in seen, then no previously visited node can pair with the current node. The algorithm adds x to seen, so later nodes can pair with it.
Since every node is visited, every possible pair will eventually be checked when the later node in that pair is processed. Therefore, the algorithm returns True exactly when such a pair exists.
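One way to build confidence in this argument is to compare the algorithm against a brute-force check of every pair on small random trees. The sketch below is an assumption-laden illustration: the helpers insert, brute_force, and cross_check are hypothetical names, and find_target is a standalone copy of the DFS described above.

```python
import random
from itertools import combinations
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node, mirroring the usual LeetCode definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Insert val into a BST (duplicates are skipped) and return the root."""
    if root is None:
        return TreeNode(val)
    node = root
    while True:
        if val < node.val:
            if node.left is None:
                node.left = TreeNode(val)
                return root
            node = node.left
        elif val > node.val:
            if node.right is None:
                node.right = TreeNode(val)
                return root
            node = node.right
        else:
            return root


def find_target(root: Optional[TreeNode], k: int) -> bool:
    """DFS with a hash set, checking before inserting."""
    seen = set()

    def dfs(node: Optional[TreeNode]) -> bool:
        if node is None:
            return False
        if k - node.val in seen:
            return True
        seen.add(node.val)
        return dfs(node.left) or dfs(node.right)

    return dfs(root)


def brute_force(values: List[int], k: int) -> bool:
    """Check every unordered pair of distinct positions."""
    return any(a + b == k for a, b in combinations(values, 2))


def cross_check(trials: int = 200, seed: int = 0) -> bool:
    """Return True if both methods agree on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        values = rng.sample(range(-20, 21), rng.randint(0, 10))
        root = None
        for v in values:
            root = insert(root, v)
        k = rng.randint(-40, 40)
        if find_target(root, k) != brute_force(values, k):
            return False
    return True
```

Agreement on random inputs is not a proof, but a disagreement would point directly at a flaw in either the argument or the code.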
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(n) | The hash set may store all node values |
The recursion stack also uses space proportional to the tree height. In the worst case, that height can be O(n).
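In Python this matters in practice, because a degenerate tree that is deeper than the interpreter's recursion limit (1000 by default in CPython) makes a recursive DFS raise RecursionError. A possible workaround, sketched here under the assumption that the same check-then-insert logic is kept, is to replace recursion with an explicit stack. The name find_target_iterative is hypothetical.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node, mirroring the usual LeetCode definition."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_target_iterative(root: Optional[TreeNode], k: int) -> bool:
    """Same hash-set idea, with an explicit stack instead of recursion."""
    seen = set()
    stack: List[TreeNode] = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if k - node.val in seen:   # check before inserting
            return True
        seen.add(node.val)
        # Push right first so the left child is processed first.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return False
```

The asymptotic bounds are unchanged: O(n) time, and O(n) space for the set plus the explicit stack.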
Implementation
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def findTarget(self, root: Optional[TreeNode], k: int) -> bool:
        seen = set()

        def dfs(node: Optional[TreeNode]) -> bool:
            if node is None:
                return False

            need = k - node.val
            if need in seen:
                return True

            seen.add(node.val)
            return dfs(node.left) or dfs(node.right)

        return dfs(root)
Code Explanation
We create a set:
seen = set()
This stores values from nodes we have already visited.
The DFS function returns a boolean:
def dfs(node: Optional[TreeNode]) -> bool:
If the current node is missing, it cannot form a pair:
if node is None:
    return False
For a real node, we compute the value needed to reach k:
need = k - node.val
If that value has already appeared, then we found a valid pair:
if need in seen:
    return True
Then we add the current node value:
seen.add(node.val)
Finally, we continue into the left and right subtrees:
return dfs(node.left) or dfs(node.right)
The or short-circuits. If the left subtree already finds a pair, Python does not need to search the right subtree.
Testing
def run_tests():
    # Tree:
    #     5
    #    / \
    #   3   6
    #  / \   \
    # 2   4   7
    root = TreeNode(5)
    root.left = TreeNode(3, TreeNode(2), TreeNode(4))
    root.right = TreeNode(6, None, TreeNode(7))

    s = Solution()
    assert s.findTarget(root, 9) is True
    assert s.findTarget(root, 28) is False

    # Tree:
    #   2
    #  / \
    # 1   3
    root = TreeNode(2, TreeNode(1), TreeNode(3))
    assert s.findTarget(root, 4) is True
    assert s.findTarget(root, 1) is False
    assert s.findTarget(root, 3) is True

    # Single node cannot pair with itself.
    root = TreeNode(5)
    assert s.findTarget(root, 10) is False

    # Duplicate values in different nodes can form a valid pair.
    root = TreeNode(5)
    root.left = TreeNode(3)
    root.right = TreeNode(7)
    root.left.left = TreeNode(3)
    assert s.findTarget(root, 6) is True

    print("all tests passed")

run_tests()
Test meaning:
| Test | Why |
|---|---|
| k = 9 on the sample tree | Confirms a valid pair is detected |
| k = 28 on the sample tree | Confirms False is returned when no pair exists |
| Small balanced BST | Checks basic BST traversal |
| Single node with k = 10 | Confirms the same node is not reused |
| Duplicate values in different nodes | Confirms equal values can form a pair when they are different nodes |