
LeetCode 230: Kth Smallest Element in a BST

A clear explanation of finding the kth smallest value in a binary search tree using inorder traversal.

Problem Restatement

We are given the root of a binary search tree and an integer k.

We need to return the kth smallest value among all node values in the tree.

The counting is 1-indexed. So:

k = 1 means the smallest value
k = 2 means the second smallest value
k = 3 means the third smallest value

LeetCode states that 1 <= k <= n <= 10^4, where n is the number of nodes in the tree. Node values are between 0 and 10^4. The follow-up asks how to optimize when the BST is modified often and kth-smallest queries are frequent.

Input and Output

Item | Meaning
--- | ---
Input | root of a binary search tree and integer k
Output | The kth smallest node value
Counting | 1-indexed
Constraint | 1 <= k <= number of nodes

Function shape:

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        ...

Examples

Example 1:

Input:  root = [3,1,4,null,2], k = 1
Output: 1

The tree is:

    3
   / \
  1   4
   \
    2

The values in sorted order are:

[1, 2, 3, 4]

The smallest value (k = 1) is:

1

Example 2:

Input:  root = [5,3,6,2,4,null,null,1], k = 3
Output: 3

The tree is:

        5
       / \
      3   6
     / \
    2   4
   /
  1

The values in sorted order are:

[1, 2, 3, 4, 5, 6]

The third smallest value is:

3

First Thought: Collect and Sort

A direct solution is:

  1. Traverse the whole tree.
  2. Store every value in an array.
  3. Sort the array.
  4. Return values[k - 1].

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        values = []

        # Any traversal order works here, because we sort afterwards.
        def dfs(node):
            if node is None:
                return

            values.append(node.val)
            dfs(node.left)
            dfs(node.right)

        dfs(root)
        values.sort()  # O(n log n)

        # k is 1-indexed; Python lists are 0-indexed.
        return values[k - 1]

This works in O(n log n) time and O(n) extra space, but it ignores the binary search tree property.

A BST already stores values in an ordered structure.

Key Insight

In a binary search tree:

left subtree values < root value < right subtree values

So an inorder traversal visits values in ascending order.

Inorder traversal means:

left subtree
current node
right subtree

Therefore, if we perform inorder traversal and count visited nodes, the kth visited node is the answer.

We do not need to sort anything.

We also do not need to visit the whole tree if we find the answer early.
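A compact way to see both points (inorder order is sorted, and the walk can stop early) is a lazy recursive inorder generator. This is a minimal sketch, not the approach used later in this article; the names inorder and kth_smallest_lazy are hypothetical, and TreeNode mirrors LeetCode's usual definition.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    """Lazily yield values in inorder: left subtree, node, right subtree."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest_lazy(root: Optional[TreeNode], k: int) -> int:
    # next() takes one value from the slice, which pulls exactly k values
    # from the generator; later nodes in inorder are never visited.
    return next(islice(inorder(root), k - 1, None))


# Example 1 tree: [3,1,4,null,2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(list(inorder(root)))         # [1, 2, 3, 4]
print(kth_smallest_lazy(root, 1))  # 1
```

Nested generators still consume call-stack depth, so a heavily skewed tree may hit Python's recursion limit; the explicit stack used below avoids that.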

Algorithm

Use iterative inorder traversal with a stack.

Maintain:

Variable | Meaning
--- | ---
stack | Nodes waiting to be visited
cur | Current node
k | Number of nodes still to visit, counting the answer itself

Steps:

  1. Start with cur = root.
  2. Move left as far as possible, pushing nodes into stack.
  3. Pop one node from stack.
  4. This popped node is the next smallest value.
  5. Decrease k.
  6. If k == 0, return this node’s value.
  7. Move to the popped node’s right child.
  8. Repeat.
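To make the steps concrete, here is a lightly instrumented copy of the loop that also records the order in which nodes are popped, run on Example 2. The name kth_smallest_traced and the ValueError for an out-of-range k are illustration-only additions; the problem itself guarantees 1 <= k <= n.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_traced(root: Optional[TreeNode], k: int) -> Tuple[int, List[int]]:
    """Stack-based inorder walk that also returns every popped value."""
    stack: List[TreeNode] = []
    popped: List[int] = []
    cur = root
    while cur is not None or stack:
        # Step 2: push cur and its whole left spine.
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        # Steps 3-5: the popped node is the next smallest value; count it.
        cur = stack.pop()
        popped.append(cur.val)
        k -= 1
        # Step 6: stop as soon as the kth node has been popped.
        if k == 0:
            return cur.val, popped
        # Step 7: continue with the right subtree.
        cur = cur.right
    raise ValueError("k is larger than the number of nodes")


# Example 2 tree: [5,3,6,2,4,null,null,1]
root = TreeNode(5,
                TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)),
                TreeNode(6))
print(kth_smallest_traced(root, 3))  # (3, [1, 2, 3])
```

The first inner loop pushes 5, 3, 2, 1; the next three pops yield 1, 2, 3, and the walk stops without ever touching 4, 5 or 6.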

Correctness

The algorithm performs an inorder traversal of the BST.

For any node, all values in its left subtree are smaller than the node’s value, and all values in its right subtree are larger than the node’s value.

So visiting left subtree first, then the node, then the right subtree visits values in strictly increasing order.

The stack simulates this traversal.

Each time we pop from the stack, we visit the next smallest unvisited node.

The first popped node is the smallest value.

The second popped node is the second smallest value.

In general, the kth popped node is the kth smallest value.

The algorithm decreases k exactly once per visited node, so k reaches 0 exactly when the kth node has been popped, and that node holds the kth smallest value.

Therefore, returning cur.val at that moment is correct.
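The argument can also be checked empirically with a randomized comparison against the collect-and-sort approach. This sketch assumes distinct values inserted into a plain, unbalanced BST; bst_insert, kth_smallest and check are hypothetical helper names, and kth_smallest repeats the stack loop from the Implementation section as a standalone function.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain unbalanced BST insert, iterative to avoid recursion limits."""
    node = TreeNode(val)
    if root is None:
        return node
    cur = root
    while True:
        if val < cur.val:
            if cur.left is None:
                cur.left = node
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = node
                return root
            cur = cur.right


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    # Same stack-based inorder walk as the Implementation section.
    stack = []
    cur = root
    while cur is not None or stack:
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()
        k -= 1
        if k == 0:
            return cur.val
        cur = cur.right
    raise ValueError("k out of range")


def check(trials: int = 200, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        # Distinct values within the problem's value range 0..10^4.
        values = rng.sample(range(10_001), rng.randint(1, 50))
        root = None
        for v in values:
            root = bst_insert(root, v)
        expected = sorted(values)
        for k in range(1, len(values) + 1):
            assert kth_smallest(root, k) == expected[k - 1]
    return True


print(check())  # True
```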

Complexity

Metric | Value | Why
--- | --- | ---
Time | O(h + k) | We walk down at most the tree height, then visit k nodes
Space | O(h) | The stack stores at most one path from root to leaf

Here, h is the tree height.

In the worst case, h = O(n) for a skewed tree.

For a balanced tree, h = O(log n).
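To see the O(h) space bound concretely, this sketch reports the peak stack size of the walk on a left-skewed tree (h = n) and on a perfectly balanced tree with 1023 nodes (h = log2(1024) = 10). The helpers left_skewed, balanced and kth_with_peak_stack are hypothetical, written only for this measurement.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_with_peak_stack(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    """Stack-based inorder walk that also returns the largest stack size seen."""
    stack = []
    peak = 0
    cur = root
    while cur is not None or stack:
        while cur is not None:
            stack.append(cur)
            peak = max(peak, len(stack))
            cur = cur.left
        cur = stack.pop()
        k -= 1
        if k == 0:
            return cur.val, peak
        cur = cur.right
    raise ValueError("k out of range")


def left_skewed(n: int) -> Optional[TreeNode]:
    """Root n, whose left child is n - 1, and so on down to 1: height n."""
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Balanced BST over lo..hi inclusive, using the middle value as root."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


print(kth_with_peak_stack(left_skewed(1000), 1))  # (1, 1000)
print(kth_with_peak_stack(balanced(1, 1023), 1))  # (1, 10)
```

With k = 1 the walk pushes the entire left spine before its first pop, which is the worst case for the left-skewed tree.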

Implementation

class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        cur = root

        while cur is not None or stack:
            while cur is not None:
                stack.append(cur)
                cur = cur.left

            cur = stack.pop()
            k -= 1

            if k == 0:
                return cur.val

            cur = cur.right

Code Explanation

We start at the root:

stack = []
cur = root

The outer loop continues while there are still nodes to process:

while cur is not None or stack:

Then we go as far left as possible:

while cur is not None:
    stack.append(cur)
    cur = cur.left

This leaves the smallest unvisited node on top of the stack.

Then we pop from the stack:

cur = stack.pop()

This node is the next value in ascending order.

So we count it:

k -= 1

If this is the kth smallest value, return it:

if k == 0:
    return cur.val

Otherwise, move to the right subtree:

cur = cur.right

The right subtree holds the values that come immediately after cur.val in sorted order, before any ancestor still waiting on the stack.

Follow-Up: Frequent Updates and Queries

If the BST is modified often and kth-smallest queries are frequent, plain inorder traversal may be too slow: each query costs O(h + k), which can reach O(n).

A better design is to store subtree sizes in each node.

For each node:

node.size = number of nodes in the subtree rooted at node = 1 + size(node.left) + size(node.right), where an empty subtree has size 0

Then kth-smallest can be found like this:

  1. Let left_size be the size of the left subtree (0 if there is no left child).
  2. If k == left_size + 1, the current node is the answer.
  3. If k <= left_size, go left.
  4. Otherwise, go right with k = k - left_size - 1.
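A minimal sketch of this size-based query, assuming a hypothetical SizedNode class whose size field is computed when the node is constructed from already-built children. It covers only the query and does not handle insertion or deletion.

```python
from typing import Optional


class SizedNode:
    """Hypothetical BST node that also stores the size of its subtree."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        self.size = 1 + size(left) + size(right)


def size(node: Optional[SizedNode]) -> int:
    return node.size if node is not None else 0


def kth_smallest_sized(root: Optional[SizedNode], k: int) -> int:
    node = root
    while node is not None:
        left_size = size(node.left)
        if k == left_size + 1:
            return node.val      # step 2: the current node is the answer
        if k <= left_size:
            node = node.left     # step 3: the answer is in the left subtree
        else:
            k -= left_size + 1   # step 4: skip the left subtree and this node
            node = node.right
    raise ValueError("k out of range")


# Example 2 tree, built bottom-up so each size comes from its children.
root = SizedNode(5,
                 SizedNode(3, SizedNode(2, SizedNode(1)), SizedNode(4)),
                 SizedNode(6))
print(root.size)                                          # 6
print([kth_smallest_sized(root, k) for k in range(1, 7)])  # [1, 2, 3, 4, 5, 6]
```

Each iteration moves one level down, so a query touches at most h + 1 nodes regardless of k.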

This gives O(h) query time.

If the tree is balanced, this becomes O(log n).

Insert and delete operations must update subtree sizes along the changed path.
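Insertion is the simpler of the two updates: every node on the path from the root to the new leaf's parent gains exactly one descendant. This sketch assumes distinct values and an unbalanced BST; SizedNode, insert and recount are hypothetical names. Deletion follows the same idea (decrement along the path of the node actually removed), and a self-balancing tree would also need to recompute size for the nodes involved in each rotation; neither is shown here.

```python
from typing import List, Optional


class SizedNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.size = 1  # a new node is a leaf


def insert(root: Optional[SizedNode], val: int) -> SizedNode:
    """Insert val and add 1 to the size of every ancestor of the new leaf."""
    if root is None:
        return SizedNode(val)
    path: List[SizedNode] = []
    cur = root
    while cur is not None:
        if val == cur.val:
            return root  # duplicate: nothing inserted, sizes untouched
        path.append(cur)
        cur = cur.left if val < cur.val else cur.right
    parent = path[-1]
    if val < parent.val:
        parent.left = SizedNode(val)
    else:
        parent.right = SizedNode(val)
    # Only after the insert succeeds do we bump the sizes along the path.
    for node in path:
        node.size += 1
    return root


def recount(node: Optional[SizedNode]) -> int:
    """Recompute subtree sizes from scratch and check the stored ones."""
    if node is None:
        return 0
    n = 1 + recount(node.left) + recount(node.right)
    assert node.size == n, f"stale size at {node.val}"
    return n


root = None
for v in [5, 3, 6, 2, 4, 1, 3]:  # the final 3 is a duplicate
    root = insert(root, v)
print(root.size, root.left.size, recount(root))  # 6 4 6
```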

Testing

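from collections import deque
from typing import Optional

# The Solution class from the Implementation section must be defined after
# TreeNode (its type hints reference TreeNode) and before run_tests() is called.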

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values):
    if not values:
        return None

    root = TreeNode(values[0])
    q = deque([root])
    i = 1

    while q and i < len(values):
        node = q.popleft()

        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            q.append(node.left)
        i += 1

        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            q.append(node.right)
        i += 1

    return root


def run_tests():
    s = Solution()

    root = build_tree([3, 1, 4, None, 2])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([5, 3, 6, 2, 4, None, None, 1])
    assert s.kthSmallest(root, 3) == 3

    root = build_tree([2, 1, 3])
    assert s.kthSmallest(root, 2) == 2

    root = build_tree([1])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([4, 2, 5, 1, 3])
    assert s.kthSmallest(root, 5) == 5

    print("all tests passed")

run_tests()

Test | Why
--- | ---
[3,1,4,null,2], k = 1 | The answer is reached by moving left from the root
[5,3,6,2,4,null,null,1], k = 3 | LeetCode Example 2; the answer is an internal node
[2,1,3], k = 2 | Root is the answer
[1], k = 1 | Single-node tree
[4,2,5,1,3], k = 5 | Largest value