A clear explanation of splitting a binary search tree into two BSTs using recursion and pointer rewiring.
Problem Restatement
We are given the root of a binary search tree and an integer target.
We need to split the tree into two separate binary search trees:
| Tree | Contains |
|---|---|
| First tree | All nodes with values <= target |
| Second tree | All nodes with values > target |
The tree does not necessarily contain a node whose value equals target.
We should preserve as much of the original structure as possible: if a node and its original parent end up in the same result tree, that parent should still be its parent. Each original node appears in exactly one of the two result trees.
Return the roots of the two trees:
```
[small_tree_root, large_tree_root]
```
Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root of a BST, and integer target |
| Output | Two roots: one for values <= target, one for values > target |
| Constraint | Nodes are reused by rewiring pointers |
| Tree property | Both returned trees must still be valid BSTs |
Function shape:
```python
class Solution:
    def splitBST(
        self,
        root: Optional[TreeNode],
        target: int
    ) -> list[Optional[TreeNode]]:
        ...
```
Examples
Consider this BST:
```
        4
      /   \
     2     6
    / \   / \
   1   3 5   7
```
With:
```
target = 2
```
The first tree should contain values <= 2:
```
    2
   /
  1
```
The second tree should contain values > 2:
```
    4
   / \
  3   6
     / \
    5   7
```
So the result is:
```
[[2, 1], [4, 3, 6, None, None, 5, 7]]
```
The exact serialized form depends on how the tree is printed, but logically the split is:
```
small = values <= 2
large = values > 2
```
First Thought: Rebuild Two Trees
A direct approach is:
- Traverse every node.
- If node.val <= target, insert it into a new small BST.
- Otherwise, insert it into a new large BST.
This is easy to understand, but it does more work than needed: it touches all n nodes and performs a separate BST insertion for each one.
It also creates new structure instead of preserving the original tree links. The problem asks us to split the existing tree, so we should reuse the original nodes.
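For contrast, here is a minimal sketch of the rebuild approach, assuming a plain TreeNode class. The names insert and split_by_rebuilding are hypothetical helpers introduced for illustration. It allocates a fresh node for every value, so the original tree is left untouched and none of its links are reused.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Ordinary BST insertion that allocates a brand-new node for val.
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def split_by_rebuilding(
    root: Optional[TreeNode], target: int
) -> list[Optional[TreeNode]]:
    # Visit every original node (preorder via an explicit stack) and copy
    # its value into one of two new BSTs.
    small: Optional[TreeNode] = None
    large: Optional[TreeNode] = None
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if node.val <= target:
            small = insert(small, node.val)
        else:
            large = insert(large, node.val)
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return [small, large]


def inorder(root: Optional[TreeNode]) -> list[int]:
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)
```

Every one of the n values is inserted separately, and each insertion walks down the tree being built, so this costs far more than following a single root-to-leaf path.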
Problem With Rebuilding
Rebuilding ignores useful structure already present in the BST.
The BST property tells us:
| Node relation | Value range |
|---|---|
| Left subtree | Smaller than the current node |
| Right subtree | Greater than the current node |
So at each node, only one side can be mixed across the split boundary.
If root.val <= target, then root and root.left belong to the small tree. Only root.right may contain both small and large values.
If root.val > target, then root and root.right belong to the large tree. Only root.left may contain both small and large values.
This lets us recurse into only one child at each step.
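To see the one-child recursion concretely, the sketch below instruments the split with a hypothetical visited list that records each node the recursion touches and which tree that node joins. split_traced is an illustrative free-function version of the same recursion, not part of the original solution.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_traced(
    root: Optional[TreeNode], target: int, visited: list[tuple[int, str]]
) -> list[Optional[TreeNode]]:
    if root is None:
        return [None, None]
    if root.val <= target:
        # root and its whole left subtree join the small tree;
        # only the right child is explored.
        visited.append((root.val, "small"))
        small, large = split_traced(root.right, target, visited)
        root.right = small
        return [root, large]
    # root and its whole right subtree join the large tree;
    # only the left child is explored.
    visited.append((root.val, "large"))
    small, large = split_traced(root.left, target, visited)
    root.left = large
    return [small, root]


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
visited: list[tuple[int, str]] = []
split_traced(root, 2, visited)
print(visited)  # [(4, 'large'), (2, 'small'), (3, 'large')]
```

Nodes 1, 5, 6, and 7 are never visited: each one moves along with an ancestor that was placed as a whole.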
Key Insight
At each node, compare root.val with target.
There are two cases.
Case 1: root.val <= target
The current node belongs to the small tree.
Because this is a BST, every node in root.left is smaller than root.val, which is itself <= target, so every node in root.left also belongs to the small tree.
The only uncertain part is root.right.
So we split root.right:
```python
left_part, right_part = splitBST(root.right, target)
```
Here:
| Result | Meaning |
|---|---|
| left_part | Nodes from root.right that are <= target |
| right_part | Nodes from root.right that are > target |
left_part becomes the new root.right. Its values are greater than root.val and at most target, so it still fits as the right subtree of root.
Then root becomes the root of the small tree.
Return:
```
[root, right_part]
```
Case 2: root.val > target
The current node belongs to the large tree.
Because this is a BST, every node in root.right is greater than root.val, which is itself > target, so every node in root.right also belongs to the large tree.
The only uncertain part is root.left.
So we split root.left:
```python
left_part, right_part = splitBST(root.left, target)
```
Here:
| Result | Meaning |
|---|---|
| left_part | Nodes from root.left that are <= target |
| right_part | Nodes from root.left that are > target |
right_part becomes the new root.left. Its values are greater than target but smaller than root.val, so it still fits as the left subtree of root.
Then root becomes the root of the large tree.
Return:
```
[left_part, root]
```
Algorithm
Use recursion.
If root is None, return two empty trees:
```
[None, None]
```
Otherwise:
- If root.val <= target:
  - Split root.right into small_part and large_part.
  - Attach small_part as the new root.right.
  - Return [root, large_part].
- If root.val > target:
  - Split root.left into small_part and large_part.
  - Attach large_part as the new root.left.
  - Return [small_part, root].
Correctness
We prove that the algorithm returns two valid BSTs, where the first contains exactly all values <= target and the second contains exactly all values > target.
For the base case, if root is None, there are no nodes to split. Returning [None, None] is correct.
Now consider a non-empty tree.
If root.val <= target, then root belongs in the first tree. Since the input is a BST, every node in root.left has value less than root.val, so every node in root.left is also <= target. Therefore, root.left can stay attached to root.
The only subtree that may contain values on both sides of target is root.right. By recursion, splitting root.right gives two valid trees: one with values <= target, and one with values > target. We attach the first result to root.right, so root remains a valid BST containing only values <= target. The second result is returned as the large tree.
If root.val > target, then root belongs in the second tree. Since the input is a BST, every node in root.right has value greater than root.val, so every node in root.right is also > target. Therefore, root.right can stay attached to root.
The only subtree that may contain values on both sides of target is root.left. By recursion, splitting root.left gives two valid trees. We attach the second result to root.left, so root remains a valid BST containing only values > target. The first result is returned as the small tree.
In both cases, every input node appears in exactly one of the two returned trees, and both trees keep the BST property. By induction on the height of the tree, the algorithm is correct for every input.
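The argument above can also be spot-checked empirically. The sketch below assumes distinct integer values, builds random BSTs by repeated insertion, splits them with a free-function copy of the recursion, and asserts the claims: the two results partition the values around target, both are valid BSTs, and every original parent-child link whose endpoints land in the same result survives. The helper names (split_bst, is_bst, links, check_random_split) are illustrative.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst(root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
    # Same recursion as Solution.splitBST, written as a free function.
    if root is None:
        return [None, None]
    if root.val <= target:
        small, large = split_bst(root.right, target)
        root.right = small
        return [root, large]
    small, large = split_bst(root.left, target)
    root.left = large
    return [small, root]


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def values(root: Optional[TreeNode]) -> list[int]:
    # Inorder traversal.
    if root is None:
        return []
    return values(root.left) + [root.val] + values(root.right)


def is_bst(root: Optional[TreeNode], lo: Optional[int] = None,
           hi: Optional[int] = None) -> bool:
    # Every node must lie strictly between lo and hi (None means unbounded).
    if root is None:
        return True
    if (lo is not None and root.val <= lo) or (hi is not None and root.val >= hi):
        return False
    return is_bst(root.left, lo, root.val) and is_bst(root.right, root.val, hi)


def links(root: Optional[TreeNode]) -> set[tuple[int, int]]:
    # All (parent value, child value) pairs in the tree.
    if root is None:
        return set()
    here = {(root.val, c.val) for c in (root.left, root.right) if c is not None}
    return here | links(root.left) | links(root.right)


def check_random_split(rng: random.Random) -> None:
    vals = rng.sample(range(100), rng.randint(0, 20))
    root = None
    for v in vals:
        root = insert(root, v)
    target = rng.randint(-5, 105)
    before = links(root)

    small, large = split_bst(root, target)

    # Partition: each value lands in exactly one result, on the correct side.
    assert values(small) == sorted(v for v in vals if v <= target)
    assert values(large) == sorted(v for v in vals if v > target)
    # Both results satisfy the BST property.
    assert is_bst(small) and is_bst(large)
    # Links between nodes on the same side of target are preserved.
    after = links(small) | links(large)
    for parent, child in before:
        if (parent <= target) == (child <= target):
            assert (parent, child) in after


rng = random.Random(0)
for _ in range(1000):
    check_random_split(rng)
print("all random checks passed")
```

A random test is not a proof, but it catches rewiring mistakes (such as attaching the wrong half of the recursive result) quickly.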
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(h) | Each call does O(1) work and recurses into at most one child, so the calls follow a single root-to-leaf path |
| Space | O(h) | Recursion stack depth is the tree height |
Here, h is the height of the BST.
For a balanced tree, h = O(log n).
For a skewed tree, h = O(n).
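Because the recursion follows a single path, the same split can also be done iteratively with O(1) extra space. The sketch below is an alternative we are substituting for illustration, not the solution in this article: it walks down from the root and keeps two hypothetical "tail" pointers, one marking where the next small node should hang (as a right child) and one marking where the next large node should hang (as a left child). Two dummy nodes stand in for the roots that are not yet known. In Python this also avoids hitting the interpreter's recursion limit on very deep, skewed trees.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst_iterative(
    root: Optional[TreeNode], target: int
) -> list[Optional[TreeNode]]:
    small_dummy = TreeNode()  # small_dummy.right will hold the small root
    large_dummy = TreeNode()  # large_dummy.left will hold the large root
    small_tail = small_dummy  # deepest small node so far; next small node goes to .right
    large_tail = large_dummy  # deepest large node so far; next large node goes to .left

    node = root
    while node is not None:
        if node.val <= target:
            # node and its left subtree are small; keep exploring right.
            small_tail.right = node
            small_tail = node
            node = node.right
        else:
            # node and its right subtree are large; keep exploring left.
            large_tail.left = node
            large_tail = node
            node = node.left

    # Each tail still points at the next node on the path, which belongs to
    # the other tree (or is None), so cut those links.
    small_tail.right = None
    large_tail.left = None
    return [small_dummy.right, large_dummy.left]


def inorder(root: Optional[TreeNode]) -> list[int]:
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
small, large = split_bst_iterative(root, 2)
print(inorder(small), inorder(large))  # [1, 2] [3, 4, 5, 6, 7]
```

The time bound is unchanged at O(h); only the stack usage drops from O(h) to O(1).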
Implementation
```python
from typing import Optional


# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def splitBST(
        self,
        root: Optional[TreeNode],
        target: int
    ) -> list[Optional[TreeNode]]:
        # Empty subtree: nothing to split.
        if root is None:
            return [None, None]

        if root.val <= target:
            # root and root.left stay in the small tree; only root.right can be mixed.
            small, large = self.splitBST(root.right, target)
            root.right = small
            return [root, large]

        # root and root.right stay in the large tree; only root.left can be mixed.
        small, large = self.splitBST(root.left, target)
        root.left = large
        return [small, root]
```
Code Explanation
The empty tree case returns two empty roots:
```python
if root is None:
    return [None, None]
```
When root.val <= target, the current node belongs to the small tree:
```python
if root.val <= target:
```
Its left subtree is already safe to keep because all values there are smaller than root.val.
Only the right subtree may need splitting:
```python
small, large = self.splitBST(root.right, target)
```
The small part of the right subtree still belongs under root:
```python
root.right = small
```
Then root is the root of the small result:
```python
return [root, large]
```
When root.val > target, the current node belongs to the large tree.
Only the left subtree may need splitting:
```python
small, large = self.splitBST(root.left, target)
```
The large part of the left subtree still belongs under root:
```python
root.left = large
```
Then root is the root of the large result:
```python
return [small, root]
```
Testing
A useful local test needs a concrete TreeNode class and helpers to inspect each result.
```python
from typing import Optional


class TreeNode:
    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ):
        self.val = val
        self.left = left
        self.right = right


# Place the Solution class from the Implementation section here,
# after TreeNode, so its type hints resolve.


def inorder(root: Optional[TreeNode]) -> list[int]:
    # Left, node, right: a sorted result means a valid BST (distinct values).
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


def collect(root: Optional[TreeNode]) -> list[int]:
    # Preorder: gathers every value without relying on BST ordering.
    if root is None:
        return []
    return [root.val] + collect(root.left) + collect(root.right)


def run_tests():
    s = Solution()

    # Split inside the tree.
    root = TreeNode(
        4,
        TreeNode(2, TreeNode(1), TreeNode(3)),
        TreeNode(6, TreeNode(5), TreeNode(7)),
    )
    small, large = s.splitBST(root, 2)
    assert sorted(collect(small)) == [1, 2]
    assert sorted(collect(large)) == [3, 4, 5, 6, 7]
    assert inorder(small) == [1, 2]
    assert inorder(large) == [3, 4, 5, 6, 7]

    # Target larger than every value.
    root = TreeNode(2, TreeNode(1), TreeNode(3))
    small, large = s.splitBST(root, 5)
    assert sorted(collect(small)) == [1, 2, 3]
    assert collect(large) == []

    # Target smaller than every value.
    root = TreeNode(2, TreeNode(1), TreeNode(3))
    small, large = s.splitBST(root, 0)
    assert collect(small) == []
    assert sorted(collect(large)) == [1, 2, 3]

    print("all tests passed")


run_tests()
```
Test coverage:
| Test | Why |
|---|---|
| Split inside the tree | Checks normal pointer rewiring |
| Target larger than all nodes | Entire tree goes to the small result |
| Target smaller than all nodes | Entire tree goes to the large result |
| Inorder checks | Confirms each result remains a valid BST |
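These checks confirm membership and ordering but not shape, so a result with correctly sorted values could still have the wrong parent-child links. A stricter check, sketched below with a hypothetical level_order helper, serializes each result breadth-first (None for a missing child, trailing Nones dropped) and compares it with the expected output from the Examples section. It uses a free-function copy of the recursion, named split_bst here for illustration, so the snippet runs on its own.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def split_bst(root: Optional[TreeNode], target: int) -> list[Optional[TreeNode]]:
    # Same recursion as Solution.splitBST, written as a free function.
    if root is None:
        return [None, None]
    if root.val <= target:
        small, large = split_bst(root.right, target)
        root.right = small
        return [root, large]
    small, large = split_bst(root.left, target)
    root.left = large
    return [small, root]


def level_order(root: Optional[TreeNode]) -> list[Optional[int]]:
    # Breadth-first serialization: None marks a missing child,
    # and trailing Nones are trimmed.
    out: list[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
small, large = split_bst(root, 2)
print([level_order(small), level_order(large)])
# [[2, 1], [4, 3, 6, None, None, 5, 7]]
```

Matching the exact serialization also confirms that the untouched links (2 to 1, 4 to 6, 6 to 5, 6 to 7) survived and that the only new link is 4 to 3.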