A clear explanation of designing an iterator over a BST using controlled inorder traversal with a stack.
Problem Restatement
We need to implement a class, BSTIterator, that represents an iterator over the inorder traversal of a binary search tree.
For a BST, inorder traversal visits values in ascending order.
The class must support:
| Method | Meaning |
|---|---|
| BSTIterator(root) | Initialize the iterator with the root of the BST |
| next() | Move to the next value and return it |
| hasNext() | Return whether there is still a next value |
The pointer starts before the smallest element, so the first call to next() returns the smallest value in the BST.
LeetCode states that next() calls are always valid, meaning next() is called only when a next value exists.
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a binary search tree |
| Output from next() | Next smallest value |
| Output from hasNext() | Boolean |
| Traversal order | Inorder: left, root, right |
| Important property | Inorder traversal of a BST is sorted ascending |
Example class shape:
```python
class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        ...

    def next(self) -> int:
        ...

    def hasNext(self) -> bool:
        ...
```

Examples
Consider this BST:
```
    7
   / \
  3   15
     /  \
    9    20
```

Its inorder traversal is:

```python
[3, 7, 9, 15, 20]
```

Example calls:
```python
iterator = BSTIterator(root)
iterator.next()     # 3
iterator.next()     # 7
iterator.hasNext()  # True
iterator.next()     # 9
iterator.hasNext()  # True
iterator.next()     # 15
iterator.hasNext()  # True
iterator.next()     # 20
iterator.hasNext()  # False
```

First Thought: Flatten the Tree
The simplest solution is to perform the full inorder traversal during initialization.
Store all values in a list:

```python
self.values = [3, 7, 9, 15, 20]
```

Then next() just returns the next list element.
```python
class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        self.values = []
        self.index = 0

        def inorder(node):
            if node is None:
                return
            inorder(node.left)
            self.values.append(node.val)
            inorder(node.right)

        inorder(root)

    def next(self) -> int:
        value = self.values[self.index]
        self.index += 1
        return value

    def hasNext(self) -> bool:
        return self.index < len(self.values)
```

This works.
But it stores every value up front, so it uses O(n) space.
We can do better by simulating inorder traversal lazily.
Key Insight
Inorder traversal visits:

```
left subtree -> node -> right subtree
```

The next smallest node is always the leftmost unvisited node.
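So we keep a stack of nodes that have been reached but not yet returned. Every node on the stack is waiting for the unvisited part of its left subtree, which sits above it on the stack, to be returned first.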
At initialization, push the whole left path from the root:

```
7 -> 3
```

The top of the stack is 3, the smallest value.
When next() pops a node:
- Return that node’s value.
- If the node has a right child, push the left path of that right child.
This exactly simulates recursive inorder traversal.
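For comparison, the recursive traversal that the stack simulates can be written as a Python generator. The generator is lazy in the same way: each call to the built-in next(gen) resumes where the previous call stopped. This is only a sketch for comparison and is not the approach used in this article. The name inorder_values is hypothetical. The chain of suspended generator frames still holds O(h) state, which the explicit stack makes visible.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder_values(node: Optional[TreeNode]) -> Iterator[int]:
    # Hypothetical helper: lazily yields values in inorder (left, node, right).
    if node is None:
        return
    yield from inorder_values(node.left)   # the whole left subtree first
    yield node.val                         # then the node itself
    yield from inorder_values(node.right)  # then the right subtree


root = TreeNode(7, TreeNode(3), TreeNode(15, TreeNode(9), TreeNode(20)))
gen = inorder_values(root)
print(next(gen))  # 3
print(next(gen))  # 7
print(list(gen))  # [9, 15, 20]
```

The stack-based iterator replaces these implicit suspended frames with explicit TreeNode references. Its state is therefore easy to inspect, and it does not depend on generator machinery.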
Algorithm
Maintain:

```python
self.stack
```

Define a helper:

```python
_push_left(node)
```

It pushes node, then node.left, then node.left.left, until it reaches None.
Constructor:
- Create an empty stack.
- Push the left path from root.
next():
- Pop the top node.
- Save its value.
- Push the left path from its right child.
- Return the saved value.
hasNext():
- Return whether the stack is non-empty.
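The steps above can be sketched as a small standalone function that records the stack after initialization and after each next() step, so you can compare it against the walkthrough that follows. The name trace_stack and its snapshot format (node values, bottom of the stack first) are hypothetical and exist only for illustration.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trace_stack(root: Optional[TreeNode]) -> List[List[int]]:
    # Hypothetical helper: returns stack snapshots as value lists, bottom first.
    stack: List[TreeNode] = []

    def push_left(node: Optional[TreeNode]) -> None:
        # Push node, node.left, node.left.left, ... until None.
        while node:
            stack.append(node)
            node = node.left

    push_left(root)  # constructor step
    snapshots = [[n.val for n in stack]]
    while stack:  # hasNext()
        node = stack.pop()     # next(): pop the top node
        push_left(node.right)  # prepare the leftmost node of its right subtree
        snapshots.append([n.val for n in stack])
    return snapshots


root = TreeNode(7, TreeNode(3), TreeNode(15, TreeNode(9), TreeNode(20)))
for snapshot in trace_stack(root):
    print(snapshot)
# [7, 3]
# [7]
# [15, 9]
# [15]
# [20]
# []
```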
Walkthrough
Use:

```
    7
   / \
  3   15
     /  \
    9    20
```

Initialization pushes the left path:

```
stack = [7, 3]
```

Top is 3.

Call next():

```
pop 3
return 3
```

Node 3 has no right child. Stack:

```
[7]
```

Call next():

```
pop 7
return 7
```

Node 7 has right child 15. Push the left path from 15:

```
15 -> 9
```

Stack:

```
[15, 9]
```

Call next():

```
pop 9
return 9
```

Node 9 has no right child. Stack:

```
[15]
```

Call next():

```
pop 15
return 15
```

Node 15 has right child 20. Push the left path from 20:

```
20
```

Stack:

```
[20]
```

Call next():

```
pop 20
return 20
```

The stack becomes empty. Now:

```
hasNext() == False
```

Correctness
The stack invariant is:

While unvisited nodes remain, the stack is non-empty and its top is the next unvisited node in inorder order. Once every node has been visited, the stack is empty.
During initialization, the algorithm pushes the left path from the root. The leftmost node is the smallest node in the BST, and it becomes the top of the stack.
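When next() pops a node, that node is returned. Its left subtree has already been fully handled: every node in that subtree was pushed above it on the stack, so all of them were popped and returned before it.

After returning a node, the next inorder nodes from its right subtree must begin at the leftmost node of that right subtree. The algorithm pushes exactly that left path. If the node has no right child, nothing is pushed. The new top is then the nearest ancestor whose left subtree contains the popped node, which is exactly the next node in inorder order. If no such ancestor exists, the stack is empty because the popped node was the largest.

Therefore, after every next() call, the stack invariant is restored.

hasNext() returns True exactly when the stack is non-empty, that is, when at least one unvisited node remains. Therefore, calling next() while hasNext() is True returns all BST values in ascending order.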
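The invariant can also be checked empirically. The sketch below is a test harness, not part of the solution. It builds random BSTs by repeated insertion and, before each next() call, asserts that the top of the stack holds the smallest value not yet returned. The helpers insert and check_invariant are hypothetical. The iterator is a copy of the stack-based design that calls _push_left unconditionally, since the helper already ignores None.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        self.stack: List[TreeNode] = []
        self._push_left(root)

    def _push_left(self, node: Optional[TreeNode]) -> None:
        while node:
            self.stack.append(node)
            node = node.left

    def next(self) -> int:
        node = self.stack.pop()
        self._push_left(node.right)  # no-op when there is no right child
        return node.val

    def hasNext(self) -> bool:
        return len(self.stack) > 0


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    # Hypothetical helper: plain (unbalanced) BST insertion.
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def check_invariant(values: List[int]) -> None:
    # Hypothetical harness: the stack top must always be the next inorder value.
    root = None
    for v in values:
        root = insert(root, v)
    remaining = sorted(values)
    it = BSTIterator(root)
    while remaining:
        assert it.hasNext()
        assert it.stack[-1].val == remaining[0]  # the invariant itself
        assert it.next() == remaining.pop(0)
    assert not it.hasNext()


for _ in range(200):
    check_invariant(random.sample(range(1000), random.randint(0, 50)))
print("invariant held")
```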
Complexity
Let h be the height of the tree.
| Operation | Time | Why |
|---|---|---|
| Constructor | O(h) | Pushes the initial left path |
| next() | Amortized O(1) | Each node is pushed and popped once total |
| hasNext() | O(1) | Checks whether the stack is empty |
| Metric | Value | Why |
|---|---|---|
| Space | O(h) | The stack stores at most one root-to-leaf path |
A single next() call can push several nodes, so its worst-case time is O(h). Across all calls, each node is pushed once and popped once, giving amortized O(1) per next().
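The amortized bound can be observed by counting work. The sketch below uses a hypothetical CountingIterator, an instrumented copy of the stack-based iterator. It tallies every push, the peak stack size, and the largest number of pushes done inside a single next() call. The test tree (also built by a hypothetical helper, left_chain) has a root 0 whose right subtree is a left-leaning chain 5, 4, 3, 2, 1. One next() call therefore pushes five nodes, while the total number of pushes still equals the six nodes in the tree.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class CountingIterator:
    # Hypothetical instrumented copy of the stack-based BSTIterator.
    def __init__(self, root: Optional[TreeNode]):
        self.stack: List[TreeNode] = []
        self.pushes = 0           # total pushes over the iterator's lifetime
        self.max_stack = 0        # peak stack size (the O(h) space bound)
        self.max_step_pushes = 0  # most pushes done inside one next() call
        self._push_left(root)

    def _push_left(self, node: Optional[TreeNode]) -> int:
        count = 0
        while node:
            self.stack.append(node)
            node = node.left
            count += 1
        self.pushes += count
        self.max_stack = max(self.max_stack, len(self.stack))
        return count

    def next(self) -> int:
        node = self.stack.pop()
        step = self._push_left(node.right)
        self.max_step_pushes = max(self.max_step_pushes, step)
        return node.val

    def hasNext(self) -> bool:
        return len(self.stack) > 0


def left_chain(values: List[int]) -> Optional[TreeNode]:
    # Hypothetical helper: values[0] is the top; each later value is the left child.
    root = None
    for v in reversed(values):
        root = TreeNode(v, left=root)
    return root


root = TreeNode(0, right=left_chain([5, 4, 3, 2, 1]))
it = CountingIterator(root)
out = []
while it.hasNext():
    out.append(it.next())
print(out, it.pushes, it.max_stack, it.max_step_pushes)
# [0, 1, 2, 3, 4, 5] 6 5 5
```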
Implementation
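```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        self.stack = []
        # Push the left path so the smallest node is on top.
        self._push_left(root)

    def _push_left(self, node: Optional[TreeNode]) -> None:
        # Push node, node.left, node.left.left, ... until None.
        while node:
            self.stack.append(node)
            node = node.left

    def next(self) -> int:
        # The top of the stack is the next node in inorder order.
        node = self.stack.pop()
        # Its successors in the right subtree start at that subtree's leftmost node.
        if node.right:
            self._push_left(node.right)
        return node.val

    def hasNext(self) -> bool:
        return len(self.stack) > 0
```

Code Explanation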
The constructor initializes an empty stack:

```python
self.stack = []
```

Then it pushes the left path from the root, so the smallest node ends up on top:

```python
self._push_left(root)
```

The helper walks down the left chain:

```python
while node:
    self.stack.append(node)
    node = node.left
```

This places the smallest available node at the top of the stack.

In next(), we pop that node:

```python
node = self.stack.pop()
```

If the node has a right subtree, the next values from that subtree start at its leftmost node:

```python
if node.right:
    self._push_left(node.right)
```

Then return the popped value:

```python
return node.val
```

The hasNext() method checks whether any prepared node remains:

```python
return len(self.stack) > 0
```

Testing
```python
from typing import Optional

# Assumes the BSTIterator class from the Implementation section is defined.


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect(iterator):
    result = []
    while iterator.hasNext():
        result.append(iterator.next())
    return result


def run_tests():
    root = TreeNode(
        7,
        TreeNode(3),
        TreeNode(15, TreeNode(9), TreeNode(20)),
    )
    iterator = BSTIterator(root)
    assert iterator.next() == 3
    assert iterator.next() == 7
    assert iterator.hasNext() is True
    assert iterator.next() == 9
    assert iterator.hasNext() is True
    assert iterator.next() == 15
    assert iterator.hasNext() is True
    assert iterator.next() == 20
    assert iterator.hasNext() is False

    root = TreeNode(1)
    assert collect(BSTIterator(root)) == [1]

    root = TreeNode(2, TreeNode(1), TreeNode(3))
    assert collect(BSTIterator(root)) == [1, 2, 3]

    root = TreeNode(3, TreeNode(2, TreeNode(1)), None)
    assert collect(BSTIterator(root)) == [1, 2, 3]

    root = TreeNode(1, None, TreeNode(2, None, TreeNode(3)))
    assert collect(BSTIterator(root)) == [1, 2, 3]

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [7, 3, 15, None, None, 9, 20] | Standard example |
| Single node | Smallest tree |
| Balanced BST | Normal inorder traversal |
| Left-skewed BST | Stack starts with many left nodes |
| Right-skewed BST | Each next() pushes one right child |
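The table above does not cover one case: a very deep tree. The sketch below is an assumption-labeled stress test. The helper skewed is hypothetical and builds chains iteratively, so construction itself never recurses. Because BSTIterator uses an explicit stack, it traverses a 10,000-node chain without recursion. By contrast, the recursive flatten from the first approach would exceed CPython's default recursion limit of 1000 at that depth.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        self.stack: List[TreeNode] = []
        self._push_left(root)

    def _push_left(self, node: Optional[TreeNode]) -> None:
        while node:
            self.stack.append(node)
            node = node.left

    def next(self) -> int:
        node = self.stack.pop()
        if node.right:
            self._push_left(node.right)
        return node.val

    def hasNext(self) -> bool:
        return len(self.stack) > 0


def skewed(n: int, side: str) -> Optional[TreeNode]:
    # Hypothetical helper: builds a chain of n nodes iteratively (no recursion).
    # side="left": root is n - 1 and each left child is one smaller.
    # side="right": root is 0 and each right child is one larger.
    root = None
    if side == "left":
        for v in range(n):
            root = TreeNode(v, left=root)
    else:
        for v in reversed(range(n)):
            root = TreeNode(v, right=root)
    return root


def collect(iterator: BSTIterator) -> List[int]:
    result = []
    while iterator.hasNext():
        result.append(iterator.next())
    return result


n = 10_000
assert collect(BSTIterator(skewed(n, "left"))) == list(range(n))
assert collect(BSTIterator(skewed(n, "right"))) == list(range(n))
print("deep trees passed")
```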