A detailed guide to solving Recover Binary Search Tree with inorder traversal and two misplaced nodes.
Problem Restatement
We are given the root of a binary search tree.
Exactly two node values were swapped by mistake.
We need to recover the BST without changing its structure. That means we should only swap the two wrong values back.
The official statement adds one constraint: the tree contains between 2 and 1000 nodes.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a BST with two swapped values |
| Output | Nothing returned |
| Mutation | Modify the tree in-place |
| Rule | Do not change tree structure |
| Fix | Swap the two incorrect node values |
Function shape:

    class Solution:
        def recoverTree(self, root: Optional[TreeNode]) -> None:
            ...

Examples
Example 1:

    root = [1, 3, None, None, 2]

Tree:

      1
     /
    3
     \
      2

The node 3 cannot be in the left subtree of 1, because 3 > 1.
Swap 1 and 3.
Result:

    [3, 1, None, None, 2]

Example 2:

    root = [3, 1, 4, None, None, 2]

Tree:

      3
     / \
    1   4
       /
      2

The node 2 is in the right subtree of 3, but 2 < 3.
Swap 2 and 3.
Result:

    [2, 1, 4, None, None, 3]

First Thought: Store and Sort Values
A valid BST has an inorder traversal in increasing order.
So one simple solution is:
- Traverse the tree inorder and store all values.
- Sort those values.
- Traverse the tree inorder again.
- Rewrite each node value from the sorted list.
This works, but it rewrites every node even though only two values are wrong, and the sort adds O(n log n) time.
It also uses O(n) extra space.
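The baseline above can be sketched as follows. This is a minimal illustration, not a recommended solution; the name `recover_by_sorting` is made up for this example, and it stores node references during the first traversal so that a second traversal is not needed.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_by_sorting(root: Optional[TreeNode]) -> None:
    """Baseline: collect nodes inorder, sort the values, write them back."""
    nodes: List[TreeNode] = []

    def collect(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        collect(node.left)
        nodes.append(node)
        collect(node.right)

    collect(root)
    # Sorting costs O(n log n); the node list costs O(n) extra space.
    sorted_values = sorted(node.val for node in nodes)
    for node, value in zip(nodes, sorted_values):
        node.val = value  # every node is written, even the correct ones


# Example 2 from above: [3, 1, 4, None, None, 2]
root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))
recover_by_sorting(root)
print(root.val, root.right.left.val)  # 2 3
```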
Key Insight
Inorder traversal of a valid BST should be strictly increasing.
If two values are swapped, the inorder sequence will contain one or two order violations.
A violation happens when:

    prev.val > cur.val

Example with adjacent swapped values:

    correct: [1, 2, 3, 4]
    swapped: [1, 3, 2, 4]

There is one violation:

    3 > 2

The two wrong nodes are 3 and 2.
Example with non-adjacent swapped values:

    correct: [1, 2, 3, 4, 5]
    swapped: [1, 4, 3, 2, 5]

There are two violations:

    4 > 3
    3 > 2

The wrong nodes are the first larger node 4 and the last smaller node 2.
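The same rule can be tried on a plain list before applying it to a tree. The helper below is a sketch for illustration only; `find_swapped` is a hypothetical name, and it returns list indices rather than tree nodes.

```python
from typing import List, Optional, Tuple


def find_swapped(values: List[int]) -> Optional[Tuple[int, int]]:
    """Return indices (i, j) of the two swapped entries, or None if sorted."""
    first = None
    second = None
    for i in range(1, len(values)):
        if values[i - 1] > values[i]:  # order violation
            if first is None:
                first = i - 1  # larger element of the first violation
            second = i         # smaller element of the latest violation
    if first is None:
        return None
    return first, second


print(find_swapped([1, 3, 2, 4]))     # (1, 2)
print(find_swapped([1, 4, 3, 2, 5]))  # (1, 3)
```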
So during inorder traversal, we track:
| Variable | Meaning |
|---|---|
| prev | Previously visited node |
| first | First wrong node |
| second | Second wrong node |

When we find prev.val > cur.val:
- If first has not been set, set first = prev.
- Always set second = cur.
At the end, swap:

    first.val, second.val = second.val, first.val

Algorithm
Perform inorder traversal.
For each visited node cur:
- Compare it with prev.
- If prev.val > cur.val, record a violation.
- Update prev = cur.
After traversal, swap the values of first and second.
This restores the sorted inorder order.
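The steps above do not depend on recursion. As one possible variant, here is a sketch that runs the same bookkeeping over an explicit-stack inorder traversal; the function name `recover_tree_iterative` is an assumption made for this example, and the final guard is an extra safety check for input that is already a valid BST.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_tree_iterative(root: Optional[TreeNode]) -> None:
    """Inorder traversal with an explicit stack, tracking prev/first/second."""
    stack: List[TreeNode] = []
    prev = first = second = None
    cur = root
    while stack or cur is not None:
        # Walk down to the leftmost unvisited node.
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()
        # Visit step: compare with the previously visited node.
        if prev is not None and prev.val > cur.val:
            if first is None:
                first = prev
            second = cur
        prev = cur
        cur = cur.right
    # Guard: a valid BST produces no violation, so there is nothing to swap.
    if first is not None and second is not None:
        first.val, second.val = second.val, first.val


# Example 1 from above: [1, 3, None, None, 2]
root = TreeNode(1, TreeNode(3, None, TreeNode(2)))
recover_tree_iterative(root)
print(root.val, root.left.val, root.left.right.val)  # 3 1 2
```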
Walkthrough
Use:

    root = [3, 1, 4, None, None, 2]

The inorder traversal is:

    [1, 3, 2, 4]

Visit 1.

    prev = 1

Visit 3.

    1 < 3

No violation.

    prev = 3

Visit 2.

    3 > 2

Violation found.

Set:

    first = node 3
    second = node 2

Visit 4.

    2 < 4

No violation.

Swap first and second.

The inorder traversal becomes:

    [1, 2, 3, 4]

The BST is recovered.
Correctness
A valid BST has strictly increasing inorder traversal.
Since exactly two node values were swapped, the inorder sequence differs from a sorted sequence by exactly two misplaced values.
When the first violation prev.val > cur.val appears, prev must be one of the swapped nodes. It is too large for its current position, so the algorithm stores it as first.
The other swapped node is the value that appears too small. If the swapped nodes are adjacent in inorder order, this small value appears as cur in the same violation. If they are not adjacent, another violation appears later, and the correct small value appears as cur in the last violation.
The algorithm assigns second = cur for every violation, so after traversal second is the last too-small node.
Swapping first.val and second.val removes the inorder violations and restores the strictly increasing inorder sequence. A binary tree with strictly increasing inorder traversal and unchanged BST shape satisfies the BST property.
Therefore, the algorithm recovers the tree correctly.
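The argument can also be checked by brute force on small inputs. The sketch below works on lists that stand in for inorder sequences; `repair` and `check_all_swaps` are hypothetical helper names, and a passing check is supporting evidence for small sizes, not a proof.

```python
from itertools import combinations
from typing import List


def repair(values: List[int]) -> List[int]:
    """Apply the first-violation / last-violation rule to a copy of the list."""
    out = list(values)
    first = None
    second = None
    for i in range(1, len(out)):
        if out[i - 1] > out[i]:
            if first is None:
                first = i - 1
            second = i
    if first is not None:
        out[first], out[second] = out[second], out[first]
    return out


def check_all_swaps(n: int) -> bool:
    """Swap every pair of positions in [0, ..., n-1] and confirm repair undoes it."""
    base = list(range(n))
    for i, j in combinations(range(n), 2):
        swapped = list(base)
        swapped[i], swapped[j] = swapped[j], swapped[i]
        if repair(swapped) != base:
            return False
    return True


print(all(check_all_swaps(n) for n in range(2, 9)))  # True
```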
Complexity
Let:

    n = number of nodes
    h = height of the tree

| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack depth equals tree height |
The follow-up asks for O(1) extra space. That can be done with Morris inorder traversal, but the recursive version is usually the clearest first solution.
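To make the O(h) bound concrete, the sketch below measures the height of two shapes with 1000 nodes each: a left-leaning chain and a balanced tree. The helpers `height`, `left_chain`, and `balanced` are illustrative names, not part of the problem.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(root: Optional[TreeNode]) -> int:
    """Height counted in nodes, computed with an explicit stack."""
    best = 0
    stack: List[Tuple[TreeNode, int]] = [(root, 1)] if root is not None else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        if node.left is not None:
            stack.append((node.left, depth + 1))
        if node.right is not None:
            stack.append((node.right, depth + 1))
    return best


def left_chain(n: int) -> Optional[TreeNode]:
    """Valid BST in which every node has only a left child."""
    root = None
    for value in range(1, n + 1):
        root = TreeNode(value, left=root)
    return root


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced BST holding the values lo..hi."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


print(height(left_chain(1000)))   # 1000
print(height(balanced(1, 1000)))  # 10
```

On the chain, h equals n, so the recursive solution needs about 1000 nested calls. CPython's default recursion limit is commonly 1000 (see `sys.getrecursionlimit()`), so outside a judge that raises the limit, a tree this deep may call for a higher limit or a non-recursive traversal.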
Implementation

    from typing import Optional

    # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, val=0, left=None, right=None):
    #         self.val = val
    #         self.left = left
    #         self.right = right

    class Solution:
        def recoverTree(self, root: Optional[TreeNode]) -> None:
            prev = None    # previously visited node in inorder
            first = None   # larger misplaced node (prev of the first violation)
            second = None  # smaller misplaced node (cur of the last violation)

            def inorder(node: Optional[TreeNode]) -> None:
                nonlocal prev, first, second
                if node is None:
                    return

                inorder(node.left)

                # Visit step: an inversion marks a misplaced node.
                if prev is not None and prev.val > node.val:
                    if first is None:
                        first = prev
                    second = node
                prev = node

                inorder(node.right)

            inorder(root)
            first.val, second.val = second.val, first.val

Code Explanation
We track the previous inorder node:

    prev = None

The two nodes to fix are:

    first = None
    second = None

The traversal is normal inorder:

    inorder(node.left)
    visit node
    inorder(node.right)

During the visit step, detect an inversion:

    if prev is not None and prev.val > node.val:

On the first violation, store the larger misplaced node:

    if first is None:
        first = prev

For every violation, update the smaller misplaced node:

    second = node

Then move prev forward:

    prev = node

After traversal, swap the values:

    first.val, second.val = second.val, first.val

Morris Traversal Version
Morris traversal performs inorder traversal with O(1) extra space by temporarily threading the tree.

    class Solution:
        def recoverTree(self, root: Optional[TreeNode]) -> None:
            first = None
            second = None
            prev = None
            cur = root

            def visit(node: TreeNode) -> None:
                nonlocal first, second, prev
                if prev is not None and prev.val > node.val:
                    if first is None:
                        first = prev
                    second = node
                prev = node

            while cur:
                if cur.left is None:
                    visit(cur)
                    cur = cur.right
                else:
                    # Find the inorder predecessor of cur.
                    pred = cur.left
                    while pred.right and pred.right is not cur:
                        pred = pred.right

                    if pred.right is None:
                        # First arrival: thread back to cur, then go left.
                        pred.right = cur
                        cur = cur.left
                    else:
                        # Second arrival: remove the thread, visit, go right.
                        pred.right = None
                        visit(cur)
                        cur = cur.right

            first.val, second.val = second.val, first.val

This version restores every temporary thread before moving on.
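One way to gain confidence that no thread is left behind is to record every child link before and after the call and compare them. The sketch below restates the Morris traversal as a standalone function so that it runs on its own; `recover_tree_morris` and `snapshot_links` are names invented for this example.

```python
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_tree_morris(root: Optional[TreeNode]) -> None:
    """Morris inorder traversal with the same first/second bookkeeping."""
    first = second = prev = None
    cur = root
    while cur is not None:
        if cur.left is not None:
            pred = cur.left
            while pred.right is not None and pred.right is not cur:
                pred = pred.right
            if pred.right is None:
                pred.right = cur  # create the temporary thread
                cur = cur.left
                continue
            pred.right = None  # remove the thread on the second arrival
        # Visit cur: it has no left child, or its left subtree is finished.
        if prev is not None and prev.val > cur.val:
            if first is None:
                first = prev
            second = cur
        prev = cur
        cur = cur.right
    if first is not None and second is not None:
        first.val, second.val = second.val, first.val


def snapshot_links(root: Optional[TreeNode]) -> Dict[int, Tuple[Optional[int], Optional[int]]]:
    """Map each node's identity to the identities of its two children."""
    links: Dict[int, Tuple[Optional[int], Optional[int]]] = {}
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if id(node) in links:
            continue  # a leftover thread would lead back to a seen node
        left_id = id(node.left) if node.left is not None else None
        right_id = id(node.right) if node.right is not None else None
        links[id(node)] = (left_id, right_id)
        for child in (node.left, node.right):
            if child is not None:
                stack.append(child)
    return links


# Example 2 from above: [3, 1, 4, None, None, 2]
root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))
before = snapshot_links(root)
recover_tree_morris(root)
after = snapshot_links(root)
print(before == after)                # True
print(root.val, root.right.left.val)  # 2 3
```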
Testing

    class TreeNode:
        def __init__(self, val=0, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right

    def inorder_values(root):
        ans = []

        def dfs(node):
            if node is None:
                return
            dfs(node.left)
            ans.append(node.val)
            dfs(node.right)

        dfs(root)
        return ans

    def run_tests():
        s = Solution()

        # [1, 3, null, null, 2]
        root = TreeNode(1)
        root.left = TreeNode(3)
        root.left.right = TreeNode(2)
        s.recoverTree(root)
        assert inorder_values(root) == [1, 2, 3]

        # [3, 1, 4, null, null, 2]
        root = TreeNode(3)
        root.left = TreeNode(1)
        root.right = TreeNode(4)
        root.right.left = TreeNode(2)
        s.recoverTree(root)
        assert inorder_values(root) == [1, 2, 3, 4]

        # Root children swapped
        root = TreeNode(2)
        root.left = TreeNode(3)
        root.right = TreeNode(1)
        s.recoverTree(root)
        assert inorder_values(root) == [1, 2, 3]

        # Skewed shape: build a valid BST, then swap two values by hand
        root = TreeNode(1)
        root.right = TreeNode(3)
        root.right.left = TreeNode(2)
        root.val, root.right.val = root.right.val, root.val
        s.recoverTree(root)
        assert inorder_values(root) == [1, 2, 3]

        print("all tests passed")

    run_tests()

Test meaning:
| Test | Why |
|---|---|
| [1, 3, null, null, 2] | First example; inorder is [3, 2, 1], so there are two violations |
| [3, 1, 4, null, null, 2] | Second example; inorder is [1, 3, 2, 4], one violation between adjacent values |
| Root children swapped | Non-adjacent swapped values |
| Skewed shape | Checks traversal order and pointer handling |
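The tests above build each tree by hand. If you would rather write test inputs in the same list form as the examples, a small builder helps. The sketch below assumes the level-order convention in which None marks a missing child; `build_tree` is a hypothetical helper, not something the problem provides.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Next entry is the left child, the one after it the right child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect values in inorder (left, node, right)."""
    if root is None:
        return []
    return inorder_values(root.left) + [root.val] + inorder_values(root.right)


print(inorder_values(build_tree([1, 3, None, None, 2])))     # [3, 2, 1]
print(inorder_values(build_tree([3, 1, 4, None, None, 2])))  # [1, 3, 2, 4]
```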