
LeetCode 776: Split BST

A clear explanation of splitting a binary search tree into two BSTs using recursion and pointer rewiring.

Problem Restatement

We are given the root of a binary search tree and an integer target.

We need to split the tree into two separate binary search trees:

Tree          Contains
First tree    All nodes with values <= target
Second tree   All nodes with values > target

The tree does not necessarily contain a node whose value equals target.

We should preserve the original structure as much as possible: if a child and its parent end up in the same result tree, the child must still have that parent. Each original node appears in exactly one of the two result trees.

Return the roots of the two trees:

[small_tree_root, large_tree_root]

Input and Output

Item            Meaning
Input           root, the root of a BST, and integer target
Output          Two roots: one for values <= target, one for values > target
Constraint      Nodes are reused by rewiring pointers
Tree property   Both returned trees must still be valid BSTs

Function shape:

class Solution:
    def splitBST(
        self,
        root: Optional[TreeNode],
        target: int
    ) -> list[Optional[TreeNode]]:
        ...

Examples

Consider this BST:

        4
      /   \
     2     6
    / \   / \
   1   3 5   7

With:

target = 2

The first tree should contain values <= 2:

    2
   /
  1

The second tree should contain values > 2:

      4
     / \
    3   6
       / \
      5   7

So the result is:

[[2, 1], [4, 3, 6, None, None, 5, 7]]

The exact serialized form depends on how the tree is printed, but logically the split is:

small = values <= 2
large = values > 2
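To make the serialized form concrete, here is a minimal sketch of a level-order serializer. It assumes the LeetCode-style convention of writing None for a missing child and trimming trailing None values; the helper name to_level_order is hypothetical, and TreeNode is redefined so the snippet runs on its own.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def to_level_order(root: Optional[TreeNode]) -> list[Optional[int]]:
    """Serialize a tree in level order, using None for missing children."""
    if root is None:
        return []
    out: list[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        # Enqueue both children, even missing ones, so gaps show up as None.
        queue.append(node.left)
        queue.append(node.right)
    # Drop trailing None placeholders, as LeetCode-style output does.
    while out and out[-1] is None:
        out.pop()
    return out


small = TreeNode(2, TreeNode(1))
large = TreeNode(4, TreeNode(3), TreeNode(6, TreeNode(5), TreeNode(7)))
print([to_level_order(small), to_level_order(large)])
# prints [[2, 1], [4, 3, 6, None, None, 5, 7]]
```

The None entries after 3 record its two missing children, which is why 5 and 7 appear only after them.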

First Thought: Rebuild Two Trees

A direct approach is:

  1. Traverse every node.
  2. If node.val <= target, insert it into a new small BST.
  3. Otherwise, insert it into a new large BST.

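This is easy to understand, but it visits every node and pays for a fresh BST insertion per node, even though the split only depends on a single root-to-leaf path.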

It also creates new structure instead of preserving the original tree links. The problem asks us to split the existing tree, so we should reuse the original nodes.
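For comparison, the rebuild approach might look like the sketch below. The names split_by_rebuilding and insert are hypothetical. Every node in the output is freshly allocated and the input tree is left untouched, which is exactly what the problem asks us to avoid.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Standard BST insertion that allocates a brand-new node.
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def split_by_rebuilding(root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
    small: Optional[TreeNode] = None
    large: Optional[TreeNode] = None
    stack = [root] if root is not None else []
    # Visit every node in preorder and copy its value into one of two new BSTs.
    while stack:
        node = stack.pop()
        if node.val <= target:
            small = insert(small, node.val)
        else:
            large = insert(large, node.val)
        # Push right first so the left child is processed first.
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return [small, large]
```

The values end up on the correct sides, but the cost is one visit plus one insertion per node, and none of the original nodes are reused.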

Problem With Rebuilding

Rebuilding ignores useful structure already present in the BST.

The BST property tells us:

Node relation    Value range
Left subtree     Smaller than the current node
Right subtree    Greater than the current node

So at each node, only one side can be mixed across the split boundary.

If root.val <= target, then:

root and root.left belong to the small tree

Only root.right may contain both small and large values.

If root.val > target, then:

root and root.right belong to the large tree

Only root.left may contain both small and large values.

This lets us recurse into only one child at each step.
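Because each step settles one child and descends into the other, the nodes the split actually touches form a single root-to-leaf path. The hypothetical helper split_path below only records that path, without rewiring anything, to make the one-branch-per-level claim visible.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_path(root: Optional[TreeNode], target: int) -> list[int]:
    # Record the values the split visits; nothing is rewired here.
    path = []
    node = root
    while node is not None:
        path.append(node.val)
        # <= target: node and node.left are settled as small; only the right child may be mixed.
        # > target: node and node.right are settled as large; only the left child may be mixed.
        node = node.right if node.val <= target else node.left
    return path


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, TreeNode(5), TreeNode(7)))
print(split_path(root, 2))  # prints [4, 2, 3]
```

On the example tree, the split inspects 4, 2, and 3 and never looks at 1, 5, 6, or 7.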

Key Insight

At each node, compare root.val with target.

There are two cases.

Case 1: root.val <= target

The current node belongs to the small tree.

Because this is a BST, every node in root.left is less than root.val, and therefore also <= target, so all of root.left belongs to the small tree.

The only uncertain part is root.right.

So we split root.right:

left_part, right_part = splitBST(root.right, target)

Here:

Result       Meaning
left_part    Nodes from root.right that are <= target
right_part   Nodes from root.right that are > target

left_part becomes the new root.right: its nodes are still <= target, and they are all greater than root.val, so they fit in root's right subtree.

Then root becomes the root of the small tree.

Return:

[root, right_part]

Case 2: root.val > target

The current node belongs to the large tree.

Because this is a BST, every node in root.right is also greater than root.val, so every node in root.right also belongs to the large tree.

The only uncertain part is root.left.

So we split root.left:

left_part, right_part = splitBST(root.left, target)

Here:

Result       Meaning
left_part    Nodes from root.left that are <= target
right_part   Nodes from root.left that are > target

right_part becomes the new root.left: its nodes are > target, and they are all less than root.val, so they fit in root's left subtree.

Then root becomes the root of the large tree.

Return:

[left_part, root]

Algorithm

Use recursion.

If root is None, return two empty trees:

[None, None]

Otherwise:

  1. If root.val <= target:

    1. Split root.right.
    2. Attach the small part of that split to root.right.
    3. Return [root, large_part].
  2. If root.val > target:

    1. Split root.left.
    2. Attach the large part of that split to root.left.
    3. Return [small_part, root].

Correctness

We prove that the algorithm returns two valid BSTs, where the first contains exactly all values <= target and the second contains exactly all values > target.

For the base case, if root is None, there are no nodes to split. Returning [None, None] is correct.

Now consider a non-empty tree.

If root.val <= target, then root belongs in the first tree. Since the input is a BST, every node in root.left has value less than root.val, so every node in root.left is also <= target. Therefore, root.left can stay attached to root.

The only subtree that may contain values on both sides of target is root.right. By recursion, splitting root.right gives two valid trees: one with values <= target, and one with values > target. We attach the first result to root.right, so root remains a valid BST containing only values <= target. The second result is returned as the large tree.

If root.val > target, then root belongs in the second tree. Since the input is a BST, every node in root.right has value greater than root.val, so every node in root.right is also > target. Therefore, root.right can stay attached to root.

The only subtree that may contain values on both sides of target is root.left. By recursion, splitting root.left gives two valid trees. We attach the second result to root.left, so root remains a valid BST containing only values > target. The first result is returned as the small tree.

In both cases, every node is returned in exactly one of the two trees, and the BST property is preserved.
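The argument can also be spot-checked empirically. The sketch below builds random BSTs, runs the recursive split from the Implementation section (copied so the snippet is self-contained), and checks three claims: both results are BSTs, the values partition at target, and every original node is reused exactly once. The helper names and the trials/seed parameters are illustrative assumptions, and a passing run is evidence, not a proof.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # Same recursive split as in the Implementation section.
    def splitBST(self, root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
        if root is None:
            return [None, None]
        if root.val <= target:
            small, large = self.splitBST(root.right, target)
            root.right = small
            return [root, large]
        small, large = self.splitBST(root.left, target)
        root.left = large
        return [small, root]


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Plain BST insertion, used only to build random inputs.
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def nodes(root: Optional[TreeNode]) -> list[TreeNode]:
    # Inorder list of node objects, so identity can be checked as well as values.
    if root is None:
        return []
    return nodes(root.left) + [root] + nodes(root.right)


def is_bst(root: Optional[TreeNode], lo: float = float("-inf"), hi: float = float("inf")) -> bool:
    # Every value must lie strictly inside the bounds inherited from its ancestors.
    if root is None:
        return True
    return (
        lo < root.val < hi
        and is_bst(root.left, lo, root.val)
        and is_bst(root.right, root.val, hi)
    )


def check(trials: int = 500, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        values = rng.sample(range(50), rng.randint(0, 15))
        root = None
        for v in values:
            root = insert(root, v)
        before = {id(n) for n in nodes(root)}
        target = rng.randint(-5, 55)

        small, large = Solution().splitBST(root, target)

        assert is_bst(small) and is_bst(large)
        assert [n.val for n in nodes(small)] == sorted(v for v in values if v <= target)
        assert [n.val for n in nodes(large)] == sorted(v for v in values if v > target)
        # Every original node appears exactly once, and no new nodes were created.
        after = [id(n) for n in nodes(small) + nodes(large)]
        assert len(after) == len(set(after)) and set(after) == before
    return True


print(check())  # prints True
```

The target range deliberately extends past both ends of the value range, so some trials send the whole tree to one side.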

Complexity

Metric   Value   Why
Time     O(h)    We recurse down only one branch at each node
Space    O(h)    Recursion stack depth is the tree height

Here, h is the height of the BST.

For a balanced tree, h = O(log n).

For a skewed tree, h = O(n).
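The O(h) space comes from the recursion stack, and on a skewed tree that matters in practice: CPython's default recursion limit is 1000 (see sys.getrecursionlimit()), so the recursive version can raise RecursionError on a long chain. A common alternative, not the approach described above, walks the same root-to-leaf path iteratively and keeps one "tail" pointer per result tree, which needs only O(1) extra space. split_bst_iterative is a hypothetical name for this sketch.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst_iterative(root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
    # Dummy anchors: small_dummy.right and large_dummy.left end up holding the two roots.
    small_dummy, large_dummy = TreeNode(), TreeNode()
    # small_tail's right pointer receives the next small node found on the path;
    # large_tail's left pointer receives the next large node.
    small_tail, large_tail = small_dummy, large_dummy

    node = root
    while node is not None:
        if node.val <= target:
            # node and node.left are small; keep walking right.
            small_tail.right = node
            small_tail = node
            node = node.right
        else:
            # node and node.right are large; keep walking left.
            large_tail.left = node
            large_tail = node
            node = node.left

    # Cut the links that may still point into the other tree.
    small_tail.right = None
    large_tail.left = None
    return [small_dummy.right, large_dummy.left]


# A right-skewed chain 0 -> 1 -> ... -> 9_999, far deeper than the default recursion limit.
head = TreeNode(0)
cur = head
for v in range(1, 10_000):
    cur.right = TreeNode(v)
    cur = cur.right

small, large = split_bst_iterative(head, 4_999)
print(small.val, large.val)  # prints 0 5000
```

Each tail attaches the next node of its kind found on the path, which is the same node the recursive version would return as the small or large root of the remaining subtree, so both versions produce the same shape.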

Implementation

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def splitBST(
        self,
        root: Optional[TreeNode],
        target: int
    ) -> list[Optional[TreeNode]]:
        # Empty tree: both halves are empty.
        if root is None:
            return [None, None]

        if root.val <= target:
            # root and root.left are small; only root.right may be mixed.
            small, large = self.splitBST(root.right, target)
            root.right = small
            return [root, large]

        # root and root.right are large; only root.left may be mixed.
        small, large = self.splitBST(root.left, target)
        root.left = large
        return [small, root]

Code Explanation

The empty tree case returns two empty roots:

if root is None:
    return [None, None]

When root.val <= target, the current node belongs to the small tree:

if root.val <= target:

Its left subtree is already safe to keep because all values there are smaller than root.val.

Only the right subtree may need splitting:

small, large = self.splitBST(root.right, target)

The small part of the right subtree still belongs under root:

root.right = small

Then root is the root of the small result:

return [root, large]

When root.val > target, the current node belongs to the large tree.

Only the left subtree may need splitting:

small, large = self.splitBST(root.left, target)

The large part of the left subtree still belongs under root:

root.left = large

Then root is the root of the large result:

return [small, root]

Testing

A useful local test needs helpers to build and inspect BSTs.

from typing import Optional

class TreeNode:
    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ):
        self.val = val
        self.left = left
        self.right = right

def inorder(root: Optional[TreeNode]) -> list[int]:
    if root is None:
        return []

    return inorder(root.left) + [root.val] + inorder(root.right)

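def collect(root: Optional[TreeNode]) -> list[int]:
    # Preorder traversal: gathers every value without relying on BST order,
    # so sorting its output checks membership independently of inorder().
    if root is None:
        return []

    return [root.val] + collect(root.left) + collect(root.right)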

def run_tests():
    s = Solution()

    root = TreeNode(
        4,
        TreeNode(2, TreeNode(1), TreeNode(3)),
        TreeNode(6, TreeNode(5), TreeNode(7)),
    )

    small, large = s.splitBST(root, 2)

    assert sorted(collect(small)) == [1, 2]
    assert sorted(collect(large)) == [3, 4, 5, 6, 7]

    assert inorder(small) == [1, 2]
    assert inorder(large) == [3, 4, 5, 6, 7]

    root = TreeNode(2, TreeNode(1), TreeNode(3))
    small, large = s.splitBST(root, 5)

    assert sorted(collect(small)) == [1, 2, 3]
    assert collect(large) == []

    root = TreeNode(2, TreeNode(1), TreeNode(3))
    small, large = s.splitBST(root, 0)

    assert collect(small) == []
    assert sorted(collect(large)) == [1, 2, 3]

    print("all tests passed")

run_tests()

Test coverage:

Test                            Why
Split inside the tree           Checks normal pointer rewiring
Target larger than all nodes    Entire tree goes to the small result
Target smaller than all nodes   Entire tree goes to the large result
Inorder checks                  Confirms each result remains a valid BST
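These tests check values and BST order, but not the structural requirement from the problem statement: a child that lands in the same result tree as its original parent must keep that parent. A minimal sketch of that check follows, assuming hypothetical helpers parent_map, tree_ids, and check_structure_kept, and copying Solution so the snippet runs alone.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # Same recursive split as in the Implementation section.
    def splitBST(self, root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
        if root is None:
            return [None, None]
        if root.val <= target:
            small, large = self.splitBST(root.right, target)
            root.right = small
            return [root, large]
        small, large = self.splitBST(root.left, target)
        root.left = large
        return [small, root]


def parent_map(root: Optional[TreeNode]) -> dict[int, int]:
    # Map id(child) -> id(parent) for every edge reachable from root.
    parents: dict[int, int] = {}
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[id(child)] = id(node)
                stack.append(child)
    return parents


def tree_ids(root: Optional[TreeNode]) -> set[int]:
    # ids of every node reachable from root.
    ids: set[int] = set()
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        ids.add(id(node))
        stack.extend(c for c in (node.left, node.right) if c is not None)
    return ids


def check_structure_kept(root: Optional[TreeNode], target: int) -> bool:
    before = parent_map(root)
    small, large = Solution().splitBST(root, target)
    after = {**parent_map(small), **parent_map(large)}
    small_ids, large_ids = tree_ids(small), tree_ids(large)
    for child, parent in before.items():
        same_tree = (child in small_ids and parent in small_ids) or (
            child in large_ids and parent in large_ids
        )
        # A child that stays with its original parent must still hang under it.
        if same_tree:
            assert after.get(child) == parent
    return True


def build() -> TreeNode:
    return TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, TreeNode(5), TreeNode(7)))


# Build a fresh tree for each target, because the split mutates its input.
for target in range(0, 9):
    assert check_structure_kept(build(), target)
print("parent links preserved")
```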