A clear explanation of finding the kth smallest value in a binary search tree using inorder traversal.
Problem Restatement
We are given the root of a binary search tree and an integer k.
We need to return the kth smallest value among all node values in the tree.
The counting is 1-indexed. So:
- `k = 1` means the smallest value
- `k = 2` means the second smallest value
- `k = 3` means the third smallest value

LeetCode states that `1 <= k <= n <= 10^4`, where `n` is the number of nodes in the tree. Node values are between 0 and 10^4. The follow-up asks how to optimize when the BST is modified often and kth-smallest queries are frequent.
Input and Output
| Item | Meaning |
|---|---|
| Input | root of a binary search tree and integer k |
| Output | The kth smallest node value |
| Counting | 1-indexed |
| Constraint | 1 <= k <= number of nodes |
Function shape:
```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        ...
```

Examples
Example 1:
```
Input: root = [3,1,4,null,2], k = 1
Output: 1
```
The tree is:
```
    3
   / \
  1   4
   \
    2
```
The values in sorted order are `[1, 2, 3, 4]`, so the first smallest value is `1`.

Example 2:
```
Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3
```
The tree is:
```
        5
       / \
      3   6
     / \
    2   4
   /
  1
```
The values in sorted order are `[1, 2, 3, 4, 5, 6]`, so the third smallest value is `3`.

First Thought: Collect and Sort
A direct solution is:
- Traverse the whole tree.
- Store every value in an array.
- Sort the array.
- Return `values[k - 1]`.
```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        values = []

        def dfs(node):
            # Any traversal order works here, because the values are sorted afterwards.
            if node is None:
                return
            values.append(node.val)
            dfs(node.left)
            dfs(node.right)

        dfs(root)
        values.sort()
        return values[k - 1]
```

This works in O(n log n) time and O(n) extra space, but it ignores the binary search tree property.
A BST already stores values in an ordered structure.
Key Insight
In a binary search tree, every node satisfies:

`left subtree values < node value < right subtree values`

So an inorder traversal visits values in ascending order.
Inorder traversal means visiting, at every node:
- the left subtree
- the current node
- the right subtree

Therefore, if we perform an inorder traversal and count visited nodes, the kth visited node is the answer.
We do not need to sort anything.
We also do not need to visit the whole tree if we find the answer early.
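The same idea can also be written recursively with a generator that stops after the kth value. This is a minimal sketch and not the article's implementation. `inorder` and `kth_smallest_recursive` are hypothetical helper names. Each tree level adds one nested generator frame, so a skewed tree with around 10^4 nodes can exceed Python's default recursion limit of 1000. The iterative stack version in the next section avoids that problem.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Yield BST values in ascending order: left subtree, node, right subtree."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> int:
    # islice skips the first k - 1 values and returns the kth one. The generator
    # is never resumed afterwards, so nodes after the answer are not visited.
    return next(islice(inorder(root), k - 1, None))


# Example 1: [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_recursive(root, 1))  # prints 1
print(kth_smallest_recursive(root, 3))  # prints 3
```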
Algorithm
Use iterative inorder traversal with a stack.
Maintain:
| Variable | Meaning |
|---|---|
| `stack` | Nodes waiting to be visited |
| `cur` | Current node |
| `k` | Number of nodes still to visit, counting the answer itself |
Steps:
- Start with `cur = root`.
- Move left as far as possible, pushing each node onto `stack`.
- Pop one node from `stack`.
- This popped node holds the next smallest value.
- Decrease `k`.
- If `k == 0`, return this node's value.
- Move to the popped node's right child.
- Repeat.
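To show the steps in action, the sketch below runs the same loop on Example 2 and records the stack contents after every pop. `trace` is a hypothetical instrumented helper added only for illustration. It stops where the real algorithm would return.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace(root: Optional[TreeNode], k: int) -> List[Tuple[int, List[int]]]:
    """Run the stack-based inorder loop, recording (popped value, stack after pop)."""
    steps = []
    stack = []
    cur = root
    while cur is not None or stack:
        # Walk left as far as possible, pushing every node on the way.
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        # The top of the stack is the next smallest unvisited value.
        cur = stack.pop()
        steps.append((cur.val, [node.val for node in stack]))
        # Count this visit and stop at the kth one.
        k -= 1
        if k == 0:
            break
        # Continue with the right subtree of the popped node.
        cur = cur.right
    return steps


# Example 2: [5,3,6,2,4,null,null,1], k = 3
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
for value, remaining in trace(root, 3):
    print(f"popped {value}, stack now {remaining}")
# popped 1, stack now [5, 3, 2]
# popped 2, stack now [5, 3]
# popped 3, stack now [5]
```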
Correctness
The algorithm performs an inorder traversal of the BST.
For any node, all values in its left subtree are smaller than the node’s value, and all values in its right subtree are larger than the node’s value.
So visiting left subtree first, then the node, then the right subtree visits values in strictly increasing order.
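The stack simulates this traversal.
A node is popped only after its entire left subtree has been visited. Every node still on the stack is an ancestor that has the popped node in its left subtree, so each of those ancestors is larger. The popped node's right subtree is also larger and has not been visited yet.
So each time we pop from the stack, we visit the next smallest unvisited node.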
The first popped node is the smallest value.
The second popped node is the second smallest value.
In general, the kth popped node is the kth smallest value.
The algorithm decreases k exactly once per visited node. When k becomes 0, the current node is exactly the original kth visited node in ascending order.
Therefore, returning cur.val at that moment is correct.
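The correctness argument can also be spot-checked empirically. This minimal sketch builds random BSTs by inserting distinct values in `[0, 10^4]` in random order. It then confirms that the stack-based traversal returns `sorted(values)[k - 1]` for every valid `k`. `bst_insert`, `kth_smallest`, and `check` are hypothetical helper names. A passing check is evidence, not a proof.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain iterative BST insert; assumes values are distinct."""
    node = TreeNode(val)
    if root is None:
        return node
    cur = root
    while True:
        if val < cur.val:
            if cur.left is None:
                cur.left = node
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = node
                return root
            cur = cur.right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    """Iterative inorder traversal that stops at the kth popped node."""
    stack, cur = [], root
    while cur is not None or stack:
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()
        k -= 1
        if k == 0:
            return cur.val
        cur = cur.right
    raise ValueError("k is larger than the number of nodes")


def check(trials: int = 200, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 50)
        values = rng.sample(range(10_001), n)  # distinct values in [0, 10^4]
        root = None
        for v in values:
            root = bst_insert(root, v)
        expected = sorted(values)
        for k in range(1, n + 1):
            assert kth_smallest(root, k) == expected[k - 1]
    print("inorder order matches sorted order")


check()
```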
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(h + k) | We walk down the tree height, then visit k nodes |
| Space | O(h) | The stack stores at most one path from root to leaf |
Here, h is the tree height.
In the worst case, h = O(n) for a skewed tree.
For a balanced tree, h = O(log n).
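To make the O(h + k) bound concrete, this sketch counts stack pushes and pops on two BSTs that hold the values 1 to 15. One is perfectly balanced (height 4, counted in nodes). The other is a left-skewed chain (height 15). `balanced`, `left_chain`, and `count_work` are hypothetical helpers added for illustration.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def balanced(vals: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from a sorted list, using the middle element as root."""
    if not vals:
        return None
    mid = len(vals) // 2
    return TreeNode(vals[mid], balanced(vals[:mid]), balanced(vals[mid + 1:]))


def left_chain(vals: List[int]) -> Optional[TreeNode]:
    """Build a left-skewed BST: each larger value becomes the parent of the previous root."""
    root = None
    for v in vals:  # vals is ascending
        root = TreeNode(v, left=root)
    return root


def count_work(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Return (pushes, pops) performed by the stack-based search for the kth value."""
    pushes = pops = 0
    stack, cur = [], root
    while cur is not None or stack:
        while cur is not None:
            stack.append(cur)
            pushes += 1
            cur = cur.left
        cur = stack.pop()
        pops += 1
        k -= 1
        if k == 0:
            break
        cur = cur.right
    return pushes, pops


vals = list(range(1, 16))
print(count_work(balanced(vals), 1))    # (4, 1): only the left spine of a height-4 tree
print(count_work(left_chain(vals), 1))  # (15, 1): the left spine is the whole tree
print(count_work(balanced(vals), 15))   # (15, 15): k = n touches every node once
```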
Implementation
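```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        cur = root
        while cur is not None or stack:
            # Push the left spine of cur; the last node pushed is the smallest unvisited one.
            while cur is not None:
                stack.append(cur)
                cur = cur.left
            # Visit the next node in ascending order.
            cur = stack.pop()
            k -= 1
            if k == 0:
                return cur.val
            # Its right subtree holds the next larger values.
            cur = cur.right
```

Code Explanation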
We start at the root:
```python
stack = []
cur = root
```
The outer loop continues while there are still nodes to process:
```python
while cur is not None or stack:
```
Then we go as far left as possible:
```python
while cur is not None:
    stack.append(cur)
    cur = cur.left
```
This brings us to the smallest unvisited node.
Then we pop from the stack:
```python
cur = stack.pop()
```
This node is the next value in ascending order, so we count it:
```python
k -= 1
```
If this is the kth smallest value, we return it:
```python
if k == 0:
    return cur.val
```
Otherwise, we move to the right subtree:
```python
cur = cur.right
```
The right subtree contains the next larger values. They are all smaller than any ancestor still on the stack, so they must be visited before that ancestor is popped.
Follow-Up: Frequent Updates and Queries
If the BST is modified often and kth-smallest queries are frequent, plain inorder traversal may be too slow, because every query costs O(h + k), which can approach O(n).
A better design is to store subtree sizes in each node.
For each node:

`node.size = number of nodes in the subtree rooted at node = 1 + size(left) + size(right)`

Then kth-smallest can be found like this:
- Let `left_size` be the size of the left subtree (0 if it is empty).
- If `k == left_size + 1`, the current node is the answer.
- If `k <= left_size`, go left.
- Otherwise, go right with `k = k - left_size - 1`.
This gives O(h) query time.
If the tree is balanced, this becomes O(log n).
Insert and delete operations must update subtree sizes along the changed path.
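Here is a minimal sketch of the size-augmented design. It assumes a plain, unbalanced BST with distinct values. `SizedNode`, `insert`, `delete`, and `kth_smallest` are hypothetical names used for illustration. To guarantee h = O(log n) under arbitrary updates, you would also need a self-balancing tree such as an AVL or red-black tree, whose rotations must recompute `size` as well. This sketch omits that.

```python
from typing import Optional


class SizedNode:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["SizedNode"] = None
        self.right: Optional["SizedNode"] = None
        self.size = 1  # number of nodes in the subtree rooted here


def size(node: Optional[SizedNode]) -> int:
    return node.size if node is not None else 0


def insert(node: Optional[SizedNode], val: int) -> SizedNode:
    """Insert val (assumed absent) and fix sizes on the way back up."""
    if node is None:
        return SizedNode(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.size = 1 + size(node.left) + size(node.right)
    return node


def delete(node: Optional[SizedNode], val: int) -> Optional[SizedNode]:
    """Delete val if present and fix sizes along the changed path."""
    if node is None:
        return None
    if val < node.val:
        node.left = delete(node.left, val)
    elif val > node.val:
        node.right = delete(node.right, val)
    else:
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: copy the inorder successor's value, then delete the successor.
        succ = node.right
        while succ.left is not None:
            succ = succ.left
        node.val = succ.val
        node.right = delete(node.right, succ.val)
    node.size = 1 + size(node.left) + size(node.right)
    return node


def kth_smallest(node: Optional[SizedNode], k: int) -> int:
    """Follow one root-to-node path using subtree sizes: O(h) per query."""
    while node is not None:
        left_size = size(node.left)
        if k == left_size + 1:
            return node.val
        if k <= left_size:
            node = node.left
        else:
            k -= left_size + 1
            node = node.right
    raise ValueError("k is out of range")


root = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)
print(kth_smallest(root, 3))  # prints 3
root = delete(root, 2)
print(kth_smallest(root, 3))  # prints 4 (values are now 1, 3, 4, 5, 6)
```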
Testing
```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Define the Solution class from the Implementation section here,
# after TreeNode, so that its type annotations resolve.


def build_tree(values):
    """Build a tree from LeetCode's level-order list, where None marks a missing child."""
    if not values:
        return None
    root = TreeNode(values[0])
    q = deque([root])
    i = 1
    while q and i < len(values):
        node = q.popleft()
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            q.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            q.append(node.right)
        i += 1
    return root


def run_tests():
    s = Solution()

    root = build_tree([3, 1, 4, None, 2])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([5, 3, 6, 2, 4, None, None, 1])
    assert s.kthSmallest(root, 3) == 3

    root = build_tree([2, 1, 3])
    assert s.kthSmallest(root, 2) == 2

    root = build_tree([1])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([4, 2, 5, 1, 3])
    assert s.kthSmallest(root, 5) == 5

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| `[3,1,4,null,2]`, k = 1 | Smallest value is the root's left child, reached by walking left |
| `[5,3,6,2,4,null,null,1]`, k = 3 | Example 2 from the problem; the answer is an internal node |
| `[2,1,3]`, k = 2 | Root is the answer |
| `[1]`, k = 1 | Single-node tree |
| `[4,2,5,1,3]`, k = 5 | Largest value; k = n, so every node is visited |