
LeetCode 653: Two Sum IV - Input is a BST

A clear explanation of finding whether two different nodes in a binary search tree sum to a target value.

Problem Restatement

We are given the root of a binary search tree and an integer k.

We need to return True if there are two different nodes in the tree whose values add up to k.

Otherwise, return False.

The two values must come from two different nodes. We cannot use the same node twice.

Input and Output

Item | Meaning
Input | The root of a binary search tree and an integer k
Output | True if two different node values sum to k, otherwise False
Tree property | The input tree is a valid binary search tree
Node rule | The two values must come from different nodes

Example function shape:

def findTarget(root: Optional[TreeNode], k: int) -> bool:
    ...

Examples

Consider this BST:

        5
       / \
      3   6
     / \   \
    2   4   7

If k = 9, the answer is True.

There are three possible pairs:

2 + 7 = 9
3 + 6 = 9
4 + 5 = 9

Since at least one valid pair exists, we return True.

If k = 28, the answer is False.

No two node values in the tree add up to 28.

Another example:

    2
   / \
  1   3

If k = 4, the answer is True, because:

1 + 3 = 4

If k = 1, the answer is False.

There is no pair of two different nodes that sums to 1.

First Thought: Convert to a List

Since the tree is a binary search tree, an inorder traversal gives the values in sorted order.

For this tree:

        5
       / \
      3   6
     / \   \
    2   4   7

the inorder list is:

[2, 3, 4, 5, 6, 7]

Then the problem becomes the normal Two Sum problem on a sorted array.

We can use two pointers:

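  1. Put left at index 0.
  2. Put right at the last index.
  3. While left < right, compute values[left] + values[right].
  4. If the sum equals k, return True.
  5. If the sum is too small, move left one step to the right.
  6. If the sum is too large, move right one step to the left.
  7. If the pointers meet without finding a pair, return False.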
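A minimal sketch of this list-based approach, assuming the LeetCode-style TreeNode class used later in this article. The helper names inorder_values and find_target_two_pointers are hypothetical, not part of the problem's API:

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values in sorted order with an inorder traversal."""
    values: List[int] = []

    def walk(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        walk(node.left)
        values.append(node.val)
        walk(node.right)

    walk(root)
    return values


def find_target_two_pointers(root: Optional[TreeNode], k: int) -> bool:
    values = inorder_values(root)
    left, right = 0, len(values) - 1
    # left < right guarantees the two indices refer to different nodes.
    while left < right:
        total = values[left] + values[right]
        if total == k:
            return True
        if total < k:
            left += 1   # sum too small: move to a larger value
        else:
            right -= 1  # sum too large: move to a smaller value
    return False
```

The left < right condition is what keeps the two pointers on different nodes, which matters when k is twice some node value.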

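This works in O(n) time, but it must traverse the whole tree and store every value in an array before the search starts.

A DFS with a hash set can instead check for a pair during the traversal and stop as soon as one is found.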

Key Insight

As we traverse the tree, suppose the current node value is x.

We need another value:

k - x

If we have already seen k - x, then we found two different nodes whose values sum to k.

This is the same idea as LeetCode 1: Two Sum.

The only difference is that the input is a tree instead of an array.

So we use a hash set called seen.

It stores node values we have already visited.

For each node:

  1. Compute the needed value.
  2. Check whether it is in seen.
  3. If yes, return True.
  4. Otherwise, add the current value and continue traversal.
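Because the check only looks at values visited so far, the traversal order does not affect the answer. A minimal sketch of the same per-node steps applied to a plain sequence of values, as if they were streamed out of a tree traversal (the function name has_pair_sum is hypothetical):

```python
from typing import Iterable


def has_pair_sum(values: Iterable[int], k: int) -> bool:
    seen = set()
    for x in values:
        need = k - x
        # Look up before adding x, so x is only matched against earlier values.
        if need in seen:
            return True
        seen.add(x)
    return False


# Preorder values of the sample tree: 5, 3, 2, 4, 6, 7
print(has_pair_sum([5, 3, 2, 4, 6, 7], 9))
```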

Algorithm

Use DFS to visit every node.

At each node:

  1. If the node is None, return False.
  2. Let need = k - node.val.
  3. If need is in seen, return True.
  4. Add node.val to seen.
  5. Search the left subtree.
  6. Search the right subtree.
  7. Return whether either subtree contains a valid pair.

The check must happen before inserting the current value.

This prevents using the current node as its own pair when k = 2 * node.val.
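To see why the order matters, here is a deliberately buggy variant that inserts before checking, written over a flat list of values for brevity. The function name buggy_has_pair is hypothetical and exists only to show the failure:

```python
def buggy_has_pair(values, k):
    seen = set()
    for x in values:
        seen.add(x)          # BUG: x is inserted before the lookup
        if k - x in seen:    # x can now match itself when k == 2 * x
            return True
    return False


print(buggy_has_pair([5], 10))  # True, but a single node cannot form a pair
```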

Correctness

Unless it returns True early, the DFS visits every node in the tree.

When the algorithm is at a node with value x, the hash set seen contains values from nodes visited earlier. These are different nodes from the current node.

If k - x is already in seen, then there is an earlier node with value k - x. Therefore, the earlier node and the current node are two different nodes, and their values sum to k.

If k - x is not in seen, then no previously visited node can pair with the current node. The algorithm adds x to seen, so later nodes can pair with it.

If the algorithm never returns early, every node is visited, so every possible pair is checked when the later node of that pair is processed. If it does return early, it has found a valid pair. Therefore, the algorithm returns True exactly when such a pair exists.

Complexity

Metric | Value | Why
Time | O(n) | Each node is visited at most once
Space | O(n) | The hash set may store all node values

The recursion stack also uses space proportional to the tree height h. For a balanced tree h is O(log n), but for a skewed tree it can be O(n).
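If recursion depth is a concern, the same check-before-insert logic can run on an explicit stack. A minimal sketch, assuming the LeetCode-style TreeNode class; find_target_iterative and build_right_skewed are hypothetical helper names. On a right-skewed tree with thousands of nodes, the recursive version would likely raise RecursionError under CPython's default recursion limit (usually 1000), while this version does not recurse:

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_target_iterative(root: Optional[TreeNode], k: int) -> bool:
    seen = set()
    stack: List[TreeNode] = [root] if root is not None else []
    while stack:
        node = stack.pop()
        # Same rule as the recursive DFS: check first, then insert.
        if k - node.val in seen:
            return True
        seen.add(node.val)
        # Push children explicitly instead of making recursive calls.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return False


def build_right_skewed(n: int) -> Optional[TreeNode]:
    """Build the BST 0 -> 1 -> ... -> n-1 where every node is a right child."""
    root = None
    for v in reversed(range(n)):
        root = TreeNode(v, None, root)
    return root


deep = build_right_skewed(5000)
print(find_target_iterative(deep, 4998 + 4999))
```

The explicit stack still holds O(h) nodes in the worst case, so the space bound is unchanged; only the dependency on the interpreter's call stack goes away.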

Implementation

from typing import Optional

# Definition for a binary tree node.
# LeetCode supplies this class; it is defined here so the file runs standalone.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def findTarget(self, root: Optional[TreeNode], k: int) -> bool:
        seen = set()

        def dfs(node: Optional[TreeNode]) -> bool:
            if node is None:
                return False

            need = k - node.val

            if need in seen:
                return True

            seen.add(node.val)

            return dfs(node.left) or dfs(node.right)

        return dfs(root)

Code Explanation

We create a set:

seen = set()

This stores values from nodes we have already visited.

The DFS function returns a boolean:

def dfs(node: Optional[TreeNode]) -> bool:

If the current node is missing, it cannot form a pair:

if node is None:
    return False

For a real node, we compute the value needed to reach k:

need = k - node.val

If that value has already appeared, then we found a valid pair:

if need in seen:
    return True

Then we add the current node value:

seen.add(node.val)

Finally, we continue into the left and right subtrees:

return dfs(node.left) or dfs(node.right)

Python's or short-circuits: if dfs(node.left) returns True, dfs(node.right) is never called, so the search stops as soon as a pair is found.
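A small instrumented sketch makes the short-circuit visible by counting how many real nodes the DFS touches. It reuses the same DFS logic; the function name find_target_counting is hypothetical:

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_target_counting(root: Optional[TreeNode], k: int) -> Tuple[bool, int]:
    """Same DFS as the solution, but also reports how many nodes were visited."""
    seen = set()
    visited = 0

    def dfs(node: Optional[TreeNode]) -> bool:
        nonlocal visited
        if node is None:
            return False
        visited += 1  # count only real nodes
        if k - node.val in seen:
            return True
        seen.add(node.val)
        return dfs(node.left) or dfs(node.right)

    found = dfs(root)
    return found, visited


sample = TreeNode(5, TreeNode(3, TreeNode(2), TreeNode(4)), TreeNode(6, None, TreeNode(7)))
print(find_target_counting(sample, 9))   # (True, 4): nodes 6 and 7 are never visited
print(find_target_counting(sample, 28))  # (False, 6): no pair, so every node is visited
```

With k = 9 the DFS visits 5, 3, 2, then 4, where it finds 5 in seen and returns, so the right subtree is skipped entirely.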

Testing

def run_tests():
    # Tree:
    #         5
    #        / \
    #       3   6
    #      / \   \
    #     2   4   7
    root = TreeNode(5)
    root.left = TreeNode(3, TreeNode(2), TreeNode(4))
    root.right = TreeNode(6, None, TreeNode(7))

    s = Solution()

    assert s.findTarget(root, 9) is True
    assert s.findTarget(root, 28) is False

    # Tree:
    #     2
    #    / \
    #   1   3
    root = TreeNode(2, TreeNode(1), TreeNode(3))

    assert s.findTarget(root, 4) is True
    assert s.findTarget(root, 1) is False
    assert s.findTarget(root, 3) is True

    # Single node cannot pair with itself.
    root = TreeNode(5)

    assert s.findTarget(root, 10) is False

    # Duplicate values in different nodes can form a valid pair.
    root = TreeNode(5)
    root.left = TreeNode(3)
    root.right = TreeNode(7)
    root.left.left = TreeNode(3)

    assert s.findTarget(root, 6) is True

    print("all tests passed")

run_tests()

Test meaning:

Test | Why
k = 9 on the sample tree | Confirms a valid pair is detected
k = 28 on the sample tree | Confirms False is returned when no pair exists
Small balanced BST (k = 4, 1, 3) | Checks basic BST traversal
Single node with k = 10 | Confirms the same node is not reused
Duplicate values in different nodes | Confirms equal values can form a pair when they come from different nodes
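Hand-picked cases can be complemented with a randomized cross-check against a brute-force reference that tries every pair of distinct nodes. A minimal sketch, assuming random BSTs built from distinct values; it carries a compact copy of the DFS solution so it runs on its own, and the helper names insert, all_values, brute_force, and random_check are hypothetical:

```python
import random
from itertools import combinations
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_target(root: Optional[TreeNode], k: int) -> bool:
    """Compact copy of the hash-set DFS solution."""
    seen = set()

    def dfs(node: Optional[TreeNode]) -> bool:
        if node is None:
            return False
        if k - node.val in seen:
            return True
        seen.add(node.val)
        return dfs(node.left) or dfs(node.right)

    return dfs(root)


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Iterative BST insertion; assumes val is not already in the tree."""
    node = TreeNode(val)
    if root is None:
        return node
    cur = root
    while True:
        if val < cur.val:
            if cur.left is None:
                cur.left = node
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = node
                return root
            cur = cur.right


def all_values(root: Optional[TreeNode]) -> List[int]:
    out: List[int] = []
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        out.append(node.val)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    return out


def brute_force(root: Optional[TreeNode], k: int) -> bool:
    # combinations picks two different positions, i.e. two different nodes.
    return any(a + b == k for a, b in combinations(all_values(root), 2))


def random_check(trials: int = 300, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        root = None
        for v in rng.sample(range(-20, 21), rng.randint(0, 12)):
            root = insert(root, v)
        k = rng.randint(-40, 40)
        assert find_target(root, k) == brute_force(root, k), (k, all_values(root))
    return True


print(random_check())
```

The brute force costs O(n^2) per tree, which is fine for the small random trees used here.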