
LeetCode 538: Convert BST to Greater Tree

A clear explanation of converting a BST into a greater tree using reverse inorder traversal and a running sum.

Problem Restatement

We are given the root of a Binary Search Tree.

We need to convert it into a greater tree.

For every node, its new value should be:

original node value + sum of all values greater than it
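Here is the same rule as a formula. T stands for the set of keys in the tree, and new(x) is the updated value of key x. Both names are introduced only for this line.

```latex
\text{new}(x) = x + \sum_{y \in T,\; y > x} y
```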

The tree structure does not change. Only node values change.

A BST has this ordering property:

| Location | Value rule |
| --- | --- |
| Left subtree | Smaller than the current node |
| Current node | Between the left and right subtree values |
| Right subtree | Greater than the current node |

Because of this property, inorder traversal visits values from smallest to largest. Reverse inorder traversal visits values from largest to smallest. The official problem asks for each key to become itself plus all keys greater than it.

Input and Output

| Item | Meaning |
| --- | --- |
| Input | The root of a Binary Search Tree |
| Output | The same root after modifying node values |
| Tree structure | Unchanged |
| Node update rule | node.val = original node.val + sum(values greater than original node.val) |

Function shape:

def convertBST(root: Optional[TreeNode]) -> Optional[TreeNode]:
    ...

Examples

Consider this BST:

      5
     / \
    2   13

The values in sorted order are:

[2, 5, 13]

Now update each value:

| Original value | Greater values | New value |
| --- | --- | --- |
| 13 | none | 13 |
| 5 | 13 | 5 + 13 = 18 |
| 2 | 5, 13 | 2 + 5 + 13 = 20 |

The result is:

      18
     /  \
    20   13

For the larger example:

root = [4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8]

The output is:

[30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]

First Thought: Collect, Sort, and Replace

A simple approach is:

  1. Traverse the tree and collect all values.
  2. Sort the values.
  3. For each original value, compute the sum of values greater than or equal to it.
  4. Traverse the tree again and replace each value.
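Here is a minimal Python sketch of these four steps. `convert_bst_by_sorting` is a hypothetical name. The lookup table assumes all keys are distinct, so each value maps to exactly one sum.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_bst_by_sorting(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # 1. Collect every value (any traversal order works here).
    values: List[int] = []

    def collect(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        values.append(node.val)
        collect(node.left)
        collect(node.right)

    collect(root)

    # 2. Sort ascending: O(n log n).
    values.sort()

    # 3. suffix[i] = values[i] + values[i+1] + ..., the sum of values >= values[i].
    suffix = [0] * (len(values) + 1)
    for i in range(len(values) - 1, -1, -1):
        suffix[i] = suffix[i + 1] + values[i]
    # Assumes distinct keys, so each value maps to a single sum.
    new_value: Dict[int, int] = {v: suffix[i] for i, v in enumerate(values)}

    # 4. Second traversal: replace each value using the lookup table.
    def replace(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        node.val = new_value[node.val]
        replace(node.left)
        replace(node.right)

    replace(root)
    return root


root = TreeNode(5, TreeNode(2), TreeNode(13))
convert_bst_by_sorting(root)
print(root.val, root.left.val, root.right.val)  # 18 20 13
```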

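This works, but it does extra work. Sorting costs O(n log n) time, and the approach needs O(n) extra space for the list of values.

The tree is already a BST, so we do not need a separate sort.

An inorder traversal of the BST already produces the values in sorted order.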

Key Insight

Inorder traversal of a BST gives values from smallest to largest:

left -> node -> right

But we need to know the sum of values greater than the current node.

So we should visit larger values first.

That means we use reverse inorder traversal:

right -> node -> left
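You can see both orders by running them as generators on the small tree from the examples. The `inorder` and `reverse_inorder` helpers are illustrative names and are not part of the solution.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    # left -> node -> right: ascending order in a BST
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def reverse_inorder(node: Optional[TreeNode]) -> Iterator[int]:
    # right -> node -> left: descending order in a BST
    if node is None:
        return
    yield from reverse_inorder(node.right)
    yield node.val
    yield from reverse_inorder(node.left)


root = TreeNode(5, TreeNode(2), TreeNode(13))
print(list(inorder(root)))          # [2, 5, 13]
print(list(reverse_inorder(root)))  # [13, 5, 2]
```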

During this traversal, we keep a running sum of all values already visited.

Because we visit nodes from largest to smallest, the running sum after adding the current node's original value equals that value plus the sum of all greater values. That is exactly the node's new value.
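In other words, the new values are the prefix sums of the descending sequence. The sketch below uses `itertools.accumulate` on the values of the 5/2/13 example. It works on a plain list, not a tree, and only illustrates the arithmetic.

```python
from itertools import accumulate

# The example BST's values in reverse inorder (descending) order.
descending = [13, 5, 2]

# Running totals: each entry is a value plus every larger value before it.
new_values = list(accumulate(descending))

print(new_values)                         # [13, 18, 20]
print(dict(zip(descending, new_values)))  # {13: 13, 5: 18, 2: 20}
```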

Algorithm

Maintain:

running_sum

Initially:

running_sum = 0

Then run reverse inorder DFS:

  1. Visit the right subtree.
  2. Add the current node’s original value to running_sum.
  3. Replace the current node’s value with running_sum.
  4. Visit the left subtree.
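The sketch below runs the four steps on the 5/2/13 example and records each (original, new) pair, so you can follow them in order. `convert_with_trace` is a hypothetical instrumented helper. It is the same traversal with one extra list append.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_with_trace(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    """Run the four steps and record (original, new) for each visited node."""
    running_sum = 0
    trace: List[Tuple[int, int]] = []

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal running_sum
        if node is None:
            return
        dfs(node.right)                     # 1. larger values first
        original = node.val
        running_sum += original             # 2. add the original value
        node.val = running_sum              # 3. overwrite with the running sum
        trace.append((original, node.val))
        dfs(node.left)                      # 4. smaller values last

    dfs(root)
    return trace


root = TreeNode(5, TreeNode(2), TreeNode(13))
for original, new in convert_with_trace(root):
    print(f"{original} -> {new}")
# 13 -> 13
# 5 -> 18
# 2 -> 20
```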

The right subtree contains larger values, so it must be processed before the current node.

The left subtree contains smaller values, so it must be processed after the current node.

Correctness

Reverse inorder traversal visits BST values in descending order.

When the algorithm reaches a node, every node with a greater value has already been visited. No smaller node has been visited yet.

The variable running_sum stores the sum of all values already visited. Therefore, before updating the current node, running_sum is exactly the sum of all values greater than the current node’s original value.

The algorithm then adds the current node’s original value to running_sum and stores the result in the node.

So the new value becomes:

original node value + sum of all greater values

This is exactly the required greater tree value.

Since each node is visited exactly once, each node is updated exactly once, with the correct value.
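You can also check the argument empirically. The sketch below builds random BSTs and computes each node's expected value directly from the definition: its original value plus every strictly greater key. It then compares that value with the reverse inorder result. The helper names, the key range -50..50, and the 200 trials are arbitrary choices for illustration. Keys are drawn without replacement, so they are distinct.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Standard iterative BST insertion (keys assumed distinct).
    new = TreeNode(val)
    if root is None:
        return new
    node = root
    while True:
        if val < node.val:
            if node.left is None:
                node.left = new
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = new
                return root
            node = node.right


def convert_bst(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # Reverse inorder with a running sum, as in the algorithm above.
    running_sum = 0

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal running_sum
        if node is None:
            return
        dfs(node.right)
        running_sum += node.val
        node.val = running_sum
        dfs(node.left)

    dfs(root)
    return root


def preorder(node: Optional[TreeNode]) -> List[int]:
    # Preorder positions identify nodes, since the structure never changes.
    if node is None:
        return []
    return [node.val] + preorder(node.left) + preorder(node.right)


def check_once(rng: random.Random) -> bool:
    keys = rng.sample(range(-50, 51), rng.randint(0, 20))  # distinct keys
    root = None
    for k in keys:
        root = bst_insert(root, k)
    before = preorder(root)
    # The definition: original value plus every strictly greater key.
    expected = [v + sum(k for k in keys if k > v) for v in before]
    convert_bst(root)
    return preorder(root) == expected


rng = random.Random(0)
print(all(check_once(rng) for _ in range(200)))  # True
```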

Complexity

Let n be the number of nodes and h be the tree height.

| Metric | Value | Why |
| --- | --- | --- |
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack depth is at most h |

In the worst case, a skewed tree has height n, so the stack uses O(n) space.

For a balanced tree, the height is O(log n).
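The sketch below shows the gap between the two cases. It measures the height, counted in nodes, of a right-skewed chain and of a tree built by repeatedly picking the middle key. The builders `skewed` and `balanced` are hypothetical helpers. For the middle-key construction, the height works out to `n.bit_length()`, which is floor(log2 n) + 1.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def height(node: Optional[TreeNode]) -> int:
    # Nodes on the longest root-to-leaf path; 0 for an empty tree.
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


def skewed(n: int) -> Optional[TreeNode]:
    # Keys 1..n where every node has only a right child: a chain of height n.
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, None, root)
    return root


def balanced(keys: List[int]) -> Optional[TreeNode]:
    # Middle key as root, then recurse on each half of the sorted list.
    if not keys:
        return None
    mid = len(keys) // 2
    return TreeNode(keys[mid], balanced(keys[:mid]), balanced(keys[mid + 1:]))


n = 500
print(height(skewed(n)))                        # 500
print(height(balanced(list(range(1, n + 1)))))  # 9, i.e. n.bit_length()
```

Counting edges instead of nodes would subtract one from both numbers. The O(h) bound is the same either way.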

Implementation

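# Definition for a binary tree node (LeetCode provides this class).
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

# Postpone annotation evaluation so TreeNode may be defined later in the file.
from __future__ import annotations

from typing import Optional

class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        # Sum of the original values of all nodes visited so far.
        running_sum = 0

        def dfs(node: Optional[TreeNode]) -> None:
            nonlocal running_sum

            if node is None:
                return

            # Larger values live in the right subtree; process them first.
            dfs(node.right)

            # node.val still holds the original value here.
            running_sum += node.val
            node.val = running_sum

            # Smaller values come last; their sums include this node.
            dfs(node.left)

        dfs(root)
        return root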

Code Explanation

We keep the cumulative sum in:

running_sum = 0

The DFS visits the right subtree first:

dfs(node.right)

In a BST, the right subtree contains larger values. Those values must be added before updating the current node.

Then we process the current node:

running_sum += node.val
node.val = running_sum

This works because node.val still holds the original value at the moment we add it.

Finally, we visit the left subtree:

dfs(node.left)

The left subtree contains smaller values. Their new values must include the current node's original value and every larger value, and running_sum already holds all of these.

Iterative Version

We can also write reverse inorder traversal with a stack.

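# Postpone annotation evaluation so TreeNode may be defined later in the file.
from __future__ import annotations

from typing import Optional

class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        running_sum = 0
        stack = []  # nodes whose right side is being processed
        node = root

        while stack or node:
            # Walk right as far as possible, saving nodes to revisit.
            while node:
                stack.append(node)
                node = node.right

            # The popped node holds the largest value not yet updated.
            node = stack.pop()

            running_sum += node.val
            node.val = running_sum

            # Continue with the smaller values in the left subtree.
            node = node.left

        return root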

This version follows the same order:

right -> node -> left

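It manages the traversal with an explicit stack instead of recursion. The stack still holds at most h nodes, but deep trees no longer risk hitting Python's recursion limit.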
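To check that order, the sketch below keeps the same stack loop but records each popped value instead of updating it. It runs on the larger example tree. `reverse_inorder_iterative` is an illustrative name.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def reverse_inorder_iterative(root: Optional[TreeNode]) -> List[int]:
    # Same loop as the iterative convertBST, but records values instead.
    order: List[int] = []
    stack: List[TreeNode] = []
    node = root
    while stack or node:
        while node:            # walk right as far as possible
            stack.append(node)
            node = node.right
        node = stack.pop()     # largest value not yet visited
        order.append(node.val)
        node = node.left       # then the smaller values
    return order


root = TreeNode(
    4,
    TreeNode(1, TreeNode(0), TreeNode(2, None, TreeNode(3))),
    TreeNode(6, TreeNode(5), TreeNode(7, None, TreeNode(8))),
)
print(reverse_inorder_iterative(root))  # [8, 7, 6, 5, 4, 3, 2, 1, 0]
```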

Testing

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def tree_to_tuple(root):
    if root is None:
        return None

    return (
        root.val,
        tree_to_tuple(root.left),
        tree_to_tuple(root.right),
    )

def run_tests():
    s = Solution()

    root = TreeNode(5, TreeNode(2), TreeNode(13))
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        18,
        (20, None, None),
        (13, None, None),
    )

    root = TreeNode(0, None, TreeNode(1))
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        1,
        None,
        (1, None, None),
    )

    root = TreeNode(
        4,
        TreeNode(
            1,
            TreeNode(0),
            TreeNode(2, None, TreeNode(3)),
        ),
        TreeNode(
            6,
            TreeNode(5),
            TreeNode(7, None, TreeNode(8)),
        ),
    )
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        30,
        (36, (36, None, None), (35, None, (33, None, None))),
        (21, (26, None, None), (15, None, (8, None, None))),
    )

    root = None
    assert s.convertBST(root) is None

    print("all tests passed")

run_tests()

| Test | Why |
| --- | --- |
| Small BST | Checks basic greater tree conversion |
| Right-only tree | Checks descending traversal behavior |
| Larger sample | Checks multiple levels and mixed children |
| Empty tree | Checks null root handling |
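The tests above only use non-negative keys. If negative keys are valid input, the cases below are worth adding, because then the running sum can lower a node's value. The block repeats the recursive Solution and tree_to_tuple so it runs on its own. The trees themselves are made-up examples.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        # Recursive reverse inorder with a running sum, as above.
        running_sum = 0

        def dfs(node: Optional[TreeNode]) -> None:
            nonlocal running_sum
            if node is None:
                return
            dfs(node.right)
            running_sum += node.val
            node.val = running_sum
            dfs(node.left)

        dfs(root)
        return root


def tree_to_tuple(root):
    if root is None:
        return None
    return (root.val, tree_to_tuple(root.left), tree_to_tuple(root.right))


s = Solution()

# Mixed signs: -2 becomes -2 + 0 + 3 = 1.
mixed = s.convertBST(TreeNode(0, TreeNode(-2), TreeNode(3)))
print(tree_to_tuple(mixed))  # (3, (1, None, None), (3, None, None))

# All negative: adding greater (negative) keys makes values smaller.
negative = s.convertBST(TreeNode(-2, TreeNode(-4), TreeNode(-1)))
print(tree_to_tuple(negative))  # (-3, (-7, None, None), (-1, None, None))
```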