A clear explanation of converting a BST into a greater tree using reverse inorder traversal and a running sum.
Problem Restatement
We are given the root of a Binary Search Tree.
We need to convert it into a greater tree.
For every node, its new value should be:
```
original node value + sum of all values greater than it
```
The tree structure does not change. Only node values change.
A BST has this ordering property:
| Location | Value Rule |
|---|---|
| Left subtree | Smaller than the current node |
| Current node | Between left and right subtree values |
| Right subtree | Greater than the current node |
Because of this property, inorder traversal visits values from smallest to largest, and reverse inorder traversal visits them from largest to smallest. The official problem asks for each key to become itself plus all keys greater than it, and it guarantees that all node values are unique.
Input and Output
| Item | Meaning |
|---|---|
| Input | The root of a Binary Search Tree |
| Output | The same root after modifying node values |
| Tree structure | Unchanged |
| Node update rule | node.val = original node.val + sum(values greater than original node.val) |
Function shape:
```python
def convertBST(root: Optional[TreeNode]) -> Optional[TreeNode]:
    ...
```
Examples
Consider this BST:
```
    5
   / \
  2   13
```
The values in sorted order are:
```
[2, 5, 13]
```
Now update each value:
| Original Value | Greater Values | New Value |
|---|---|---|
| 13 | none | 13 |
| 5 | 13 | 18 |
| 2 | 5, 13 | 20 |
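The New Value column can be checked directly from the definition. This is a minimal sketch that works on a plain list of the example's values (the names `values` and `new_values` are illustrative only), not on a tree:

```python
# Compute each new value straight from the definition:
# the value itself plus every value strictly greater than it.
values = [2, 5, 13]

new_values = {v: v + sum(u for u in values if u > v) for v in values}
print(new_values)  # {2: 20, 5: 18, 13: 13}
```

This is quadratic in the number of values, so it is only useful as a hand check of the table.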
The result is:
```
    18
   /  \
  20   13
```
For the larger example, written in level order with None marking a missing child:
```python
root = [4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8]
```
The output is:
```python
[30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]
```
First Thought: Collect, Sort, and Replace
A simple approach is:
- Traverse the tree and collect all values.
- Sort the values.
- For each original value, compute the sum of values greater than or equal to it.
- Traverse the tree again and replace each value.
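The steps above can be sketched as follows. The function name `convert_by_sorting` is hypothetical, and the sketch assumes distinct node values (as the problem guarantees) so that a dictionary keyed by value is enough to map each original value to its new value:

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_by_sorting(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Collect, sort, and replace (hypothetical helper; assumes distinct values)."""
    values: List[int] = []

    def collect(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        values.append(node.val)
        collect(node.left)
        collect(node.right)

    collect(root)
    values.sort()

    # Walk the sorted values from largest to smallest, building suffix sums:
    # new_value[v] = v + sum of all values greater than v.
    new_value: Dict[int, int] = {}
    suffix = 0
    for v in reversed(values):
        suffix += v
        new_value[v] = suffix

    def replace(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        node.val = new_value[node.val]  # children still hold original values here
        replace(node.left)
        replace(node.right)

    replace(root)
    return root


root = TreeNode(5, TreeNode(2), TreeNode(13))
convert_by_sorting(root)
print(root.val, root.left.val, root.right.val)  # 18 20 13
```

The sort costs O(n log n) and the dictionary costs O(n) extra space, which is the extra work the next paragraph refers to.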
This works, but it does extra work.
The tree is already a BST, so we do not need to sort values separately.
The BST can give us values in sorted order through traversal.
Key Insight
Inorder traversal of a BST gives values from smallest to largest:
```
left -> node -> right
```
But we need to know the sum of values greater than the current node.
So we should visit larger values first.
That means we use reverse inorder traversal:
```
right -> node -> left
```
During this traversal, we keep a running sum of all values already visited.
Since we visit values from largest to smallest, just before we reach a node the running sum holds exactly the values greater than it. After we add the node's original value, the running sum equals that value plus all greater values, which is the new value we want.
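A small sketch makes the two visiting orders and the running sum concrete on the three-node example. The generator names `inorder` and `reverse_inorder` are illustrative, not from the original:

```python
from itertools import accumulate
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    # left -> node -> right: ascending order in a BST
    if node is not None:
        yield from inorder(node.left)
        yield node.val
        yield from inorder(node.right)


def reverse_inorder(node: Optional[TreeNode]) -> Iterator[int]:
    # right -> node -> left: descending order in a BST
    if node is not None:
        yield from reverse_inorder(node.right)
        yield node.val
        yield from reverse_inorder(node.left)


root = TreeNode(5, TreeNode(2), TreeNode(13))
print(list(inorder(root)))                      # [2, 5, 13]
print(list(reverse_inorder(root)))              # [13, 5, 2]
print(list(accumulate(reverse_inorder(root))))  # [13, 18, 20]
```

The running sums [13, 18, 20] are exactly the new values of 13, 5, and 2 from the example table.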
Algorithm
Maintain:
```python
running_sum = 0
```
Then run a reverse inorder DFS. At each node:
- Visit the right subtree.
- Add the current node's original value to running_sum.
- Replace the current node's value with running_sum.
- Visit the left subtree.
The right subtree contains larger values, so it must be processed before the current node.
The left subtree contains smaller values, so it must be processed after the current node.
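To see the four steps in action, here is a sketch that performs the same update on the small example and records each visit. The function `convert_with_trace` is a hypothetical instrumented variant that returns (original value, running_sum) pairs in visiting order:

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_with_trace(root: Optional[TreeNode]) -> List[Tuple[int, int]]:
    """Reverse inorder update that also records (original value, running_sum) per visit."""
    steps: List[Tuple[int, int]] = []
    running_sum = 0

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal running_sum
        if node is None:
            return
        dfs(node.right)               # 1. larger values first
        original = node.val
        running_sum += original       # 2. add the original value
        node.val = running_sum        # 3. overwrite with the running sum
        steps.append((original, running_sum))
        dfs(node.left)                # 4. smaller values last

    dfs(root)
    return steps


root = TreeNode(5, TreeNode(2), TreeNode(13))
for original, total in convert_with_trace(root):
    print(f"visit {original}: running_sum = {total}")
# visit 13: running_sum = 13
# visit 5: running_sum = 18
# visit 2: running_sum = 20
```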
Correctness
Reverse inorder traversal visits BST values in descending order.
When the algorithm reaches a node, every node with a greater value has already been visited. No smaller node has been visited yet.
The variable running_sum stores the sum of all values already visited. Therefore, before updating the current node, running_sum is exactly the sum of all values greater than the current node’s original value.
The algorithm then adds the current node’s original value to running_sum and stores the result in the node.
So the new value becomes:
```
original node value + sum of all greater values
```
This is exactly the required greater tree value.
Since every node is visited once, every node is updated correctly.
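The same argument can be written as an invariant. Let v_1 > v_2 > ... > v_n be the original values in the order reverse inorder visits them (distinct, as the problem guarantees), and let S_k be running_sum right after the k-th visit. Because the values are visited in descending order, the values greater than v_k are exactly v_1, ..., v_{k-1}:

```latex
\begin{aligned}
S_0 &= 0, \qquad S_k = S_{k-1} + v_k \quad (k = 1, \dots, n),\\
\text{new value of } v_k &= S_k = v_k + \sum_{i=1}^{k-1} v_i = v_k + \sum_{u > v_k} u .
\end{aligned}
```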
Complexity
Let n be the number of nodes and h be the tree height.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack height |
In the worst case, a skewed tree has height n, so the recursion stack uses O(n) space.
For a balanced tree, the height is O(log n).
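The O(h) space bound can be observed by instrumenting the recursion. This sketch uses hypothetical helpers `build_right_skewed`, `build_balanced`, and `convert_and_measure`; the last one performs the same update and also returns the deepest recursion level reached on a real node, which equals the tree height:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_right_skewed(n: int) -> TreeNode:
    """BST 1 -> 2 -> ... -> n where every node has only a right child (n >= 1)."""
    root = TreeNode(1)
    node = root
    for v in range(2, n + 1):
        node.right = TreeNode(v)
        node = node.right
    return root


def build_balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Height-balanced BST holding lo..hi (inclusive), rooted at the middle value."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, build_balanced(lo, mid - 1), build_balanced(mid + 1, hi))


def convert_and_measure(root: Optional[TreeNode]) -> int:
    """Reverse inorder update that also reports the deepest level reached on a real node."""
    running_sum = 0
    deepest = 0

    def dfs(node: Optional[TreeNode], depth: int) -> None:
        nonlocal running_sum, deepest
        if node is None:
            return
        deepest = max(deepest, depth)
        dfs(node.right, depth + 1)
        running_sum += node.val
        node.val = running_sum
        dfs(node.left, depth + 1)

    dfs(root, 1)
    return deepest


print(convert_and_measure(build_right_skewed(500)))  # 500: skewed, height equals n
print(convert_and_measure(build_balanced(1, 511)))   # 9: 511 = 2**9 - 1 nodes, perfectly balanced
```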
Implementation
```python
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
from typing import Optional


class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        running_sum = 0  # sum of all original values visited so far

        def dfs(node: Optional[TreeNode]) -> None:
            nonlocal running_sum
            if node is None:
                return
            dfs(node.right)          # larger values first
            running_sum += node.val  # node.val still holds the original value here
            node.val = running_sum
            dfs(node.left)           # smaller values last

        dfs(root)
        return root
```
Code Explanation
We keep the cumulative sum in:
```python
running_sum = 0
```
The DFS visits the right subtree first:
```python
dfs(node.right)
```
In a BST, the right subtree contains larger values. Those values must be added before updating the current node.
Then we process the current node:
```python
running_sum += node.val
node.val = running_sum
```
This works because node.val still holds the original value at the moment we add it.
Finally, we visit the left subtree:
```python
dfs(node.left)
```
The left subtree contains smaller values, so their new values must include the current node and all larger nodes, which running_sum already holds.
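The visiting order is the whole trick. This sketch (the function `accumulate_in_order` and helper `small_bst` are hypothetical) runs the same running-sum update in both orders on the small example: right-first produces the greater tree, while ordinary left-first inorder produces each value plus all smaller values instead:

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def accumulate_in_order(root: Optional[TreeNode], reverse: bool) -> Optional[TreeNode]:
    """Running-sum update: reverse=True visits right -> node -> left, else left -> node -> right."""
    running_sum = 0

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal running_sum
        if node is None:
            return
        first, second = (node.right, node.left) if reverse else (node.left, node.right)
        dfs(first)
        running_sum += node.val
        node.val = running_sum
        dfs(second)

    dfs(root)
    return root


def small_bst() -> TreeNode:
    return TreeNode(5, TreeNode(2), TreeNode(13))


greater = accumulate_in_order(small_bst(), reverse=True)
print(greater.left.val, greater.val, greater.right.val)  # 20 18 13
smaller = accumulate_in_order(small_bst(), reverse=False)
print(smaller.left.val, smaller.val, smaller.right.val)  # 2 7 20
```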
Iterative Version
We can also write reverse inorder traversal with a stack.
```python
from typing import Optional


class Solution:
    def convertBST(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        running_sum = 0
        stack = []
        node = root

        while stack or node:
            # Walk right as far as possible, saving the path.
            while node:
                stack.append(node)
                node = node.right

            # The top of the stack is the largest value not yet visited.
            node = stack.pop()
            running_sum += node.val
            node.val = running_sum

            # Smaller values live in the left subtree.
            node = node.left

        return root
```
This version follows the same order:
```
right -> node -> left
```
It simply manages traversal manually instead of using recursion.
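One practical reason to prefer the stack version in Python: the recursive version uses one call frame per tree level, and CPython's default recursion limit (reported by `sys.getrecursionlimit()`, usually 1000) can be exceeded by a skewed tree a few thousand nodes deep. This sketch restates the stack version as a standalone function `convert_iterative` (a hypothetical name) and runs it on a 5000-node right-skewed BST built by a hypothetical helper `build_right_skewed`:

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def convert_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Reverse inorder with an explicit stack, same logic as the Solution above."""
    running_sum = 0
    stack: List[TreeNode] = []
    node = root
    while stack or node is not None:
        while node is not None:  # descend toward larger values
            stack.append(node)
            node = node.right
        node = stack.pop()       # largest value not yet visited
        running_sum += node.val
        node.val = running_sum
        node = node.left         # then the smaller values
    return root


def build_right_skewed(n: int) -> TreeNode:
    """BST 1 -> 2 -> ... -> n using only right children (n >= 1)."""
    root = TreeNode(1)
    node = root
    for v in range(2, n + 1):
        node.right = TreeNode(v)
        node = node.right
    return root


n = 5000
root = convert_iterative(build_right_skewed(n))
print(root.val)  # 12502500, the sum 1 + 2 + ... + 5000
```

The stack still grows to O(h) entries, but it lives on the heap rather than the interpreter's call stack.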
Testing
```python
# When running locally, place this TreeNode definition above the Solution class:
# the type hints in convertBST refer to TreeNode when the class is created.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_to_tuple(root):
    # Nested (val, left, right) tuples make whole-tree comparisons easy.
    if root is None:
        return None
    return (
        root.val,
        tree_to_tuple(root.left),
        tree_to_tuple(root.right),
    )


def run_tests():
    s = Solution()

    root = TreeNode(5, TreeNode(2), TreeNode(13))
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        18,
        (20, None, None),
        (13, None, None),
    )

    root = TreeNode(0, None, TreeNode(1))
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        1,
        None,
        (1, None, None),
    )

    root = TreeNode(
        4,
        TreeNode(
            1,
            TreeNode(0),
            TreeNode(2, None, TreeNode(3)),
        ),
        TreeNode(
            6,
            TreeNode(5),
            TreeNode(7, None, TreeNode(8)),
        ),
    )
    result = s.convertBST(root)
    assert tree_to_tuple(result) == (
        30,
        (36, (36, None, None), (35, None, (33, None, None))),
        (21, (26, None, None), (15, None, (8, None, None))),
    )

    root = None
    assert s.convertBST(root) is None

    print("all tests passed")


run_tests()
```
| Test | Why |
|---|---|
| Small BST | Checks basic greater tree conversion |
| Right-only tree | Checks descending traversal behavior |
| Larger sample | Checks multiple levels and mixed children |
| Empty tree | Checks null root handling |
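The larger example can also be checked in its level-order list form. This sketch assumes the list format shown in the examples (level order, None for a missing child, trailing None entries omitted); `build_tree`, `to_level_order`, and `convert_bst` are hypothetical helpers, with `convert_bst` restating the recursive Solution as a plain function:

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize back to level order, dropping trailing None entries."""
    result: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


def convert_bst(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Reverse inorder running-sum update, same logic as Solution.convertBST."""
    running_sum = 0

    def dfs(node: Optional[TreeNode]) -> None:
        nonlocal running_sum
        if node is None:
            return
        dfs(node.right)
        running_sum += node.val
        node.val = running_sum
        dfs(node.left)

    dfs(root)
    return root


root = build_tree([4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8])
print(to_level_order(convert_bst(root)))
# [30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]
```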