
LeetCode 99: Recover Binary Search Tree

A detailed guide to solving Recover Binary Search Tree with inorder traversal and two misplaced nodes.

Problem Restatement

We are given the root of a binary search tree.

Exactly two node values were swapped by mistake.

We need to recover the BST without changing its structure. That means we should only swap the two wrong values back.

The official statement says exactly two nodes in a BST were swapped by mistake, and we must recover the tree without changing its structure. The number of nodes is between 2 and 1000.

Input and Output

| Item | Meaning |
|---|---|
| Input | Root of a BST with two swapped values |
| Output | Nothing (the method returns None) |
| Mutation | Modify the tree in-place |
| Rule | Do not change the tree structure |
| Fix | Swap the two incorrect node values back |

Function shape:

class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        ...

Examples

Example 1:

root = [1, 3, None, None, 2]

Tree:

    1
   /
  3
   \
    2

The node 3 cannot be in the left subtree of 1, because 3 > 1.

Swap 1 and 3.

Result:

[3, 1, None, None, 2]

Example 2:

root = [3, 1, 4, None, None, 2]

Tree:

    3
   / \
  1   4
     /
    2

The node 2 is in the right subtree of 3, but 2 < 3.

Swap 2 and 3.

Result:

[2, 1, 4, None, None, 3]
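The examples write trees in LeetCode's level-order list form, where None marks a missing child. The sketch below decodes that form into nodes and encodes it back, so the example results can be checked mechanically. `build_tree` and `to_list` are hypothetical helpers written for illustration. They assume LeetCode's convention of dropping trailing None entries.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list (None = missing child) into a tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next entry is this node's left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Encode a tree back into level-order form, trimming trailing Nones."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


root = build_tree([3, 1, 4, None, None, 2])
# Fix Example 2 by hand: swap 3 (the root) and 2 (the left child of 4).
root.val, root.right.left.val = root.right.left.val, root.val
print(to_list(root))  # [2, 1, 4, None, None, 3]
```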

First Thought: Store and Sort Values

A valid BST has an inorder traversal in increasing order.

So one simple solution is:

  1. Traverse the tree inorder and store all values.
  2. Sort those values.
  3. Traverse the tree inorder again.
  4. Rewrite each node value from the sorted list.

This works, but it rewrites many node values even though only two values are wrong.

It also takes O(n log n) time for the sort and uses O(n) extra space.
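A minimal sketch of this baseline follows. `recover_by_sorting` is a hypothetical name. It stores the nodes themselves in inorder order rather than only their values, which folds steps 3 and 4 into a single pass over the stored list.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_by_sorting(root: Optional[TreeNode]) -> None:
    """O(n) extra space baseline: collect inorder nodes, sort values, write back."""
    nodes: List[TreeNode] = []

    def inorder(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        inorder(node.left)
        nodes.append(node)
        inorder(node.right)

    inorder(root)
    values = sorted(node.val for node in nodes)
    # nodes is already in inorder order, so the i-th node gets the i-th smallest value.
    for node, value in zip(nodes, values):
        node.val = value


root = TreeNode(1, TreeNode(3, None, TreeNode(2)))  # Example 1
recover_by_sorting(root)
print(root.val, root.left.val, root.left.right.val)  # 3 1 2
```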

Key Insight

Inorder traversal of a valid BST should be strictly increasing.

If two values are swapped, the inorder sequence will contain one or two order violations.

A violation happens when:

prev.val > cur.val

Example with adjacent swapped values:

correct:  [1, 2, 3, 4]
swapped:  [1, 3, 2, 4]

There is one violation:

3 > 2

The two wrong nodes are 3 and 2.

Example with non-adjacent swapped values:

correct:  [1, 2, 3, 4, 5]
swapped:  [1, 4, 3, 2, 5]

There are two violations:

4 > 3
3 > 2

The wrong nodes are the larger node of the first violation (4) and the smaller node of the last violation (2).
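To see the one-or-two violation pattern without a tree, the sketch below works on a plain list that stands in for the inorder sequence. `violations` is a hypothetical helper that lists every adjacent pair with prev > cur.

```python
from typing import List, Tuple


def violations(seq: List[int]) -> List[Tuple[int, int]]:
    """Return every adjacent pair (prev, cur) with prev > cur."""
    return [(a, b) for a, b in zip(seq, seq[1:]) if a > b]


print(violations([1, 3, 2, 4]))     # [(3, 2)]
print(violations([1, 4, 3, 2, 5]))  # [(4, 3), (3, 2)]

pairs = violations([1, 4, 3, 2, 5])
# Larger value of the first violation, smaller value of the last one.
print(pairs[0][0], pairs[-1][1])    # 4 2
```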

So during inorder traversal, we track:

VariableMeaning
prevPreviously visited node
firstFirst wrong node
secondSecond wrong node

When we find prev.val > cur.val:

  1. If first has not been set, set first = prev.
  2. Always set second = cur.

At the end, swap:

first.val, second.val = second.val, first.val
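The same first/second rule can be tried on a list that stands in for the inorder sequence. In this sketch `recover_list` is a hypothetical helper that tracks indices instead of tree nodes. Unlike the tree version, it guards against a list with no violation, because the problem guarantee does not apply to arbitrary lists.

```python
from typing import List, Optional


def recover_list(seq: List[int]) -> None:
    """Apply the first/second rule to a list whose two values were swapped."""
    first: Optional[int] = None   # index of the too-large value (first violation)
    second: Optional[int] = None  # index of the too-small value (last violation)
    for i in range(1, len(seq)):
        if seq[i - 1] > seq[i]:
            if first is None:
                first = i - 1
            second = i
    if first is not None and second is not None:
        seq[first], seq[second] = seq[second], seq[first]


data = [1, 4, 3, 2, 5]
recover_list(data)
print(data)  # [1, 2, 3, 4, 5]
```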

Algorithm

Perform inorder traversal.

For each visited node cur:

  1. Compare it with prev.
  2. If prev.val > cur.val, record a violation.
  3. Update prev = cur.

After traversal, swap the values of first and second.

This restores the strictly increasing inorder sequence.
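The steps above do not depend on recursion. As a sketch of a variant not used elsewhere in this guide, the same detection rule can be driven by an explicit stack. `recover_tree_iterative` is a hypothetical name, and the block defines its own TreeNode so that it runs on its own.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def recover_tree_iterative(root: Optional[TreeNode]) -> None:
    """Same detection rule, driven by an explicit stack instead of recursion."""
    stack: List[TreeNode] = []
    prev: Optional[TreeNode] = None
    first: Optional[TreeNode] = None
    second: Optional[TreeNode] = None
    cur = root

    while stack or cur:
        # Walk as far left as possible, remembering the path.
        while cur:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()  # next node in inorder order

        if prev is not None and prev.val > cur.val:
            if first is None:
                first = prev
            second = cur

        prev = cur
        cur = cur.right

    # The problem guarantees exactly one swapped pair, so both nodes are set.
    first.val, second.val = second.val, first.val


root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))  # Example 2
recover_tree_iterative(root)
print(root.val, root.right.left.val)  # 2 3
```

The stack never holds more than h nodes, so the space is still O(h). The practical difference is that a deep, skewed tree cannot exhaust Python's call stack.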

Walkthrough

Use:

root = [3, 1, 4, None, None, 2]

The inorder traversal is:

[1, 3, 2, 4]

Visit 1.

prev = 1

Visit 3.

1 < 3

No violation.

prev = 3

Visit 2.

3 > 2

Violation found.

Set:

first = node 3
second = node 2

Visit 4.

2 < 4

No violation.

Swap the values of first and second (3 and 2).

The inorder traversal becomes:

[1, 2, 3, 4]

The BST is recovered.
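The walkthrough can be reproduced by instrumenting the recursive detection. `trace_recovery` is a hypothetical helper for illustration. It records one line per visited node, plus the final swap, and returns the log.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_recovery(root: Optional[TreeNode]) -> List[str]:
    """Run the recursive detection, logging each visit, then swap the values."""
    log: List[str] = []
    prev: Optional[TreeNode] = None
    first: Optional[TreeNode] = None
    second: Optional[TreeNode] = None

    def inorder(node: Optional[TreeNode]) -> None:
        nonlocal prev, first, second
        if node is None:
            return
        inorder(node.left)
        if prev is not None and prev.val > node.val:
            if first is None:
                first = prev
            second = node
            log.append(f"visit {node.val}: {prev.val} > {node.val}, violation")
        else:
            log.append(f"visit {node.val}: no violation")
        prev = node
        inorder(node.right)

    inorder(root)
    log.append(f"swap {first.val} and {second.val}")
    first.val, second.val = second.val, first.val
    return log


root = TreeNode(3, TreeNode(1), TreeNode(4, TreeNode(2)))
for line in trace_recovery(root):
    print(line)
# visit 1: no violation
# visit 3: no violation
# visit 2: 3 > 2, violation
# visit 4: no violation
# swap 3 and 2
```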

Correctness

A valid BST has strictly increasing inorder traversal.

Since exactly two node values were swapped, the inorder sequence differs from a sorted sequence by exactly two misplaced values.

When the first violation prev.val > cur.val appears, prev must be one of the swapped nodes. It is too large for its current position, so the algorithm stores it as first.

The other swapped node is the one whose value is too small. If the swapped nodes are adjacent in inorder order, this small value appears as cur in the same violation. If they are not adjacent, a second violation appears later, and the misplaced small value appears as cur in that last violation.

The algorithm assigns second = cur for every violation, so after traversal second is the last too-small node.

Swapping first.val and second.val puts both values back in their sorted positions, which removes every inorder violation. Any binary tree whose inorder traversal is strictly increasing satisfies the BST property, and the tree structure was never changed.

Therefore, the algorithm recovers the tree correctly.
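The argument only uses the inorder sequence, so it can be spot-checked on lists. The sketch below swaps every pair of positions in a sorted list of distinct values and confirms that the first/second rule undoes the swap. This is an exhaustive check on small inputs, not a proof. `recover_list` and `check_all_swaps` are hypothetical helpers.

```python
from itertools import combinations
from typing import List


def recover_list(seq: List[int]) -> None:
    """First/second rule on a list with exactly one swapped pair."""
    first = second = -1
    for i in range(1, len(seq)):
        if seq[i - 1] > seq[i]:
            if first == -1:
                first = i - 1
            second = i
    seq[first], seq[second] = seq[second], seq[first]


def check_all_swaps(n: int) -> bool:
    """Swap every pair of positions in range(n) and confirm the rule undoes it."""
    for i, j in combinations(range(n), 2):
        seq = list(range(n))
        seq[i], seq[j] = seq[j], seq[i]
        recover_list(seq)
        if seq != list(range(n)):
            return False
    return True


print(all(check_all_swaps(n) for n in range(2, 9)))  # True
```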

Complexity

Let:

n = number of nodes
h = height of the tree

| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | The recursion stack grows to the height of the tree, which is O(n) for a skewed tree |

The follow-up asks for O(1) extra space. That can be done with Morris inorder traversal, but the recursive version is usually the clearest first solution.
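To make the O(h) bound concrete, the sketch below measures the deepest chain of active inorder calls on a balanced tree and on a skewed tree of the same size. `max_inorder_depth`, `balanced`, and `right_skewed` are hypothetical helpers. The depth count includes the calls made on None children.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_inorder_depth(root: Optional[TreeNode]) -> int:
    """Deepest stack of active inorder() calls, counting calls on None children."""
    deepest = 0

    def inorder(node: Optional[TreeNode], depth: int) -> None:
        nonlocal deepest
        deepest = max(deepest, depth)
        if node is None:
            return
        inorder(node.left, depth + 1)
        inorder(node.right, depth + 1)

    inorder(root, 1)
    return deepest


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Height-balanced BST holding lo..hi."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


def right_skewed(n: int) -> Optional[TreeNode]:
    """BST holding 1..n where every node has only a right child."""
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, None, root)
    return root


print(max_inorder_depth(balanced(1, 15)))   # 5
print(max_inorder_depth(right_skewed(15)))  # 16
```

On a skewed tree near the 1000-node upper bound, the recursive version needs roughly 1000 nested calls. That is at the edge of CPython's default recursion limit (`sys.getrecursionlimit()` returns 1000 by default), which is one practical argument for the iterative or Morris form.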

Implementation

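from typing import Optional

# Definition for a binary tree node.
# LeetCode predefines this class; when running locally, define it before Solution.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        prev = None    # last node visited in inorder order
        first = None   # larger node of the first inversion
        second = None  # smaller node of the most recent inversion

        def inorder(node: Optional[TreeNode]) -> None:
            nonlocal prev, first, second

            if node is None:
                return

            inorder(node.left)

            # Visit step: an inversion means prev and node are out of order.
            if prev is not None and prev.val > node.val:
                if first is None:
                    first = prev
                second = node

            prev = node

            inorder(node.right)

        inorder(root)

        # Exactly two values were swapped, so first and second are both set.
        first.val, second.val = second.val, first.val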

Code Explanation

We track the previous inorder node:

prev = None

The two nodes to fix are:

first = None
second = None

The traversal is normal inorder:

inorder(node.left)
visit node
inorder(node.right)

During the visit step, detect an inversion:

if prev is not None and prev.val > node.val:

On the first violation, store the larger misplaced node:

if first is None:
    first = prev

For every violation, update the smaller misplaced node:

second = node

Then move prev forward:

prev = node

After traversal, swap the values:

first.val, second.val = second.val, first.val

Morris Traversal Version

Morris traversal performs inorder traversal with O(1) extra space by temporarily threading the tree.

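from typing import Optional

class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        first = None
        second = None
        prev = None
        cur = root

        def visit(node: TreeNode) -> None:
            nonlocal first, second, prev

            # Same inversion check as the recursive version.
            if prev is not None and prev.val > node.val:
                if first is None:
                    first = prev
                second = node

            prev = node

        while cur:
            if cur.left is None:
                # No left subtree: visit cur, then go right
                # (possibly along a thread back to an ancestor).
                visit(cur)
                cur = cur.right
            else:
                # Find cur's inorder predecessor: the rightmost node of its left subtree.
                pred = cur.left

                while pred.right and pred.right is not cur:
                    pred = pred.right

                if pred.right is None:
                    # First arrival: thread the predecessor back to cur, then go left.
                    pred.right = cur
                    cur = cur.left
                else:
                    # Second arrival: the left subtree is done, so remove the thread and visit cur.
                    pred.right = None
                    visit(cur)
                    cur = cur.right

        first.val, second.val = second.val, first.val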

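Each temporary thread is removed on the second arrival at the node it points to, so the original structure is fully restored by the time the loop ends. The traversal must run to completion. Stopping early would leave threads in the tree.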
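To check the claim that no threads survive, the sketch below runs a plain Morris inorder traversal that collects values instead of recovering anything. It compares a snapshot of every node's child links before and after the traversal. `morris_inorder` and `links` are hypothetical helpers, not part of the solution.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def morris_inorder(root: Optional[TreeNode]) -> List[int]:
    """Inorder values using only a few pointers (plus the output list)."""
    out: List[int] = []
    cur = root
    while cur:
        if cur.left is None:
            out.append(cur.val)
            cur = cur.right
        else:
            pred = cur.left
            while pred.right and pred.right is not cur:
                pred = pred.right
            if pred.right is None:
                pred.right = cur   # create a temporary thread back to cur
                cur = cur.left
            else:
                pred.right = None  # second arrival: remove the thread
                out.append(cur.val)
                cur = cur.right
    return out


def links(root: Optional[TreeNode]) -> Dict[int, Tuple[Optional[int], Optional[int]]]:
    """Map id(node) -> (id(left), id(right)) for every reachable node."""
    snapshot: Dict[int, Tuple[Optional[int], Optional[int]]] = {}
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        if id(node) in snapshot:  # guard against cycles from leftover threads
            continue
        snapshot[id(node)] = (
            id(node.left) if node.left else None,
            id(node.right) if node.right else None,
        )
        for child in (node.left, node.right):
            if child:
                stack.append(child)
    return snapshot


root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, TreeNode(5), TreeNode(7)))
before = links(root)
print(morris_inorder(root))   # [1, 2, 3, 4, 5, 6, 7]
print(links(root) == before)  # True
```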

Testing

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder_values(root):
    ans = []

    def dfs(node):
        if node is None:
            return

        dfs(node.left)
        ans.append(node.val)
        dfs(node.right)

    dfs(root)
    return ans

def run_tests():
    s = Solution()

    root = TreeNode(1)
    root.left = TreeNode(3)
    root.left.right = TreeNode(2)
    s.recoverTree(root)
    assert inorder_values(root) == [1, 2, 3]

    root = TreeNode(3)
    root.left = TreeNode(1)
    root.right = TreeNode(4)
    root.right.left = TreeNode(2)
    s.recoverTree(root)
    assert inorder_values(root) == [1, 2, 3, 4]

    root = TreeNode(2)
    root.left = TreeNode(3)
    root.right = TreeNode(1)
    s.recoverTree(root)
    assert inorder_values(root) == [1, 2, 3]

    root = TreeNode(1)
    root.right = TreeNode(3)
    root.right.left = TreeNode(2)
    root.val, root.right.val = root.right.val, root.val
    s.recoverTree(root)
    assert inorder_values(root) == [1, 2, 3]

    print("all tests passed")

run_tests()

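Test meaning:

| Test | Why |
|---|---|
| [1, 3, null, null, 2] | Example 1: inorder is [3, 2, 1], a non-adjacent swap with two inversions |
| [3, 1, 4, null, null, 2] | Example 2: inorder is [1, 3, 2, 4], an adjacent swap with one inversion |
| [2, 3, 1] | The root's two children are swapped: inorder is [3, 2, 1], non-adjacent values |
| One child per level (root, right child, then its left child) | The root and its right child are swapped; checks traversal order on a path-shaped tree |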
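The hand-written tests cover a few fixed shapes. As an extra sketch, a randomized check can build many BSTs from distinct random values, swap two random nodes, recover, and compare against the sorted values. It repeats the recursive Solution so that the block runs alone. `insert`, `inorder_nodes`, and `random_trial` are hypothetical helpers, and the fixed seed is an arbitrary choice that keeps failures reproducible.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # Recursive version from the Implementation section.
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        prev = first = second = None

        def inorder(node):
            nonlocal prev, first, second
            if node is None:
                return
            inorder(node.left)
            if prev is not None and prev.val > node.val:
                if first is None:
                    first = prev
                second = node
            prev = node
            inorder(node.right)

        inorder(root)
        first.val, second.val = second.val, first.val


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain BST insert, used only to build random test trees."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder_nodes(root: Optional[TreeNode]) -> List[TreeNode]:
    """Nodes in inorder order (structure is never changed by recovery)."""
    out: List[TreeNode] = []

    def dfs(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        dfs(node.left)
        out.append(node)
        dfs(node.right)

    dfs(root)
    return out


def random_trial(rng: random.Random, n: int) -> bool:
    values = rng.sample(range(-1000, 1000), n)  # distinct values
    root = None
    for v in values:
        root = insert(root, v)
    nodes = inorder_nodes(root)
    a, b = rng.sample(nodes, 2)  # two distinct nodes
    a.val, b.val = b.val, a.val  # break the BST
    Solution().recoverTree(root)
    return [node.val for node in nodes] == sorted(values)


rng = random.Random(0)
print(all(random_trial(rng, rng.randint(2, 50)) for _ in range(500)))  # True
```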