# LeetCode 230: Kth Smallest Element in a BST

## Problem Restatement

We are given the `root` of a binary search tree and an integer `k`.

We need to return the `k`th smallest value among all node values in the tree.

The counting is 1-indexed. So:

```text
k = 1 means the smallest value
k = 2 means the second smallest value
k = 3 means the third smallest value
```

LeetCode states that `1 <= k <= n <= 10^4`, where `n` is the number of nodes in the tree. Node values are between `0` and `10^4`. The follow-up asks how to optimize when the BST is modified often and kth-smallest queries are frequent.

## Input and Output

| Item | Meaning |
|---|---|
| Input | `root` of a binary search tree and integer `k` |
| Output | The `k`th smallest node value |
| Counting | 1-indexed |
| Constraint | `1 <= k <= number of nodes` |

Function shape:

```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        ...
```

## Examples

Example 1:

```text
Input:  root = [3,1,4,null,2], k = 1
Output: 1
```

The tree is:

```text
    3
   / \
  1   4
   \
    2
```

The values in sorted order are:

```text
[1, 2, 3, 4]
```

With `k = 1`, the answer is the smallest value:

```text
1
```

Example 2:

```text
Input:  root = [5,3,6,2,4,null,null,1], k = 3
Output: 3
```

The tree is:

```text
        5
       / \
      3   6
     / \
    2   4
   /
  1
```

The values in sorted order are:

```text
[1, 2, 3, 4, 5, 6]
```

The third smallest value is:

```text
3
```

## First Thought: Collect and Sort

A direct solution is:

1. Traverse the whole tree.
2. Store every value in an array.
3. Sort the array.
4. Return `arr[k - 1]`.

```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        values = []

        def dfs(node):
            if node is None:
                return

            values.append(node.val)
            dfs(node.left)
            dfs(node.right)

        dfs(root)
        values.sort()

        return values[k - 1]
```

This works, but it costs `O(n log n)` time and `O(n)` extra space, and it ignores the binary search tree property: a BST already stores its values in an ordered structure.

## Key Insight

In a binary search tree:

```text
left subtree values < root value < right subtree values
```

So an inorder traversal visits values in ascending order.

Inorder traversal means:

```text
left subtree
current node
right subtree
```

Therefore, if we perform inorder traversal and count visited nodes, the `k`th visited node is the answer.

We do not need to sort anything.

We also do not need to visit the whole tree if we find the answer early.
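
The idea can be sketched with a lazy recursive generator. This is only an illustration of the insight, not the stack-based solution developed in the following sections; the `Node` class and the names `inorder` and `kth_smallest` are assumptions made for this example.

```python
from itertools import islice
from typing import Iterator, Optional


class Node:
    """Minimal stand-in for LeetCode's TreeNode (assumed for this sketch)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[Node]) -> Iterator[int]:
    """Yield values lazily in ascending order: left, node, right."""
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest(root: Optional[Node], k: int) -> int:
    # islice skips the first k - 1 values; next() takes the k-th.
    # The generator is not advanced further, so later nodes are never visited.
    return next(islice(inorder(root), k - 1, None))


# Tree from Example 1: [3,1,4,null,2]
example = Node(3, Node(1, None, Node(2)), Node(4))
first_three = list(islice(inorder(example), 3))
```

Because the generator is consumed only up to the `k`th value, no sorting is needed and the traversal stops early.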

## Algorithm

Use iterative inorder traversal with a stack.

Maintain:

| Variable | Meaning |
|---|---|
| `stack` | Nodes waiting to be visited |
| `cur` | Current node |
| `k` | Number of nodes still to visit, including the answer |

Steps:

1. Start with `cur = root`.
2. Move left as far as possible, pushing nodes into `stack`.
3. Pop one node from `stack`.
4. This popped node is the next smallest value.
5. Decrease `k`.
6. If `k == 0`, return this node's value.
7. Move to the popped node's right child.
8. Repeat.

## Correctness

The algorithm performs an inorder traversal of the BST.

For any node, all values in its left subtree are smaller than the node's value, and all values in its right subtree are larger than the node's value.

So visiting left subtree first, then the node, then the right subtree visits values in strictly increasing order.

The stack simulates this traversal.

Each time we pop from the stack, we visit the next smallest unvisited node.

The first popped node is the smallest value.

The second popped node is the second smallest value.

In general, the `k`th popped node is the `k`th smallest value.

The algorithm decreases `k` exactly once per visited node. When `k` becomes `0`, the current node is exactly the original `k`th visited node in ascending order.

Therefore, returning `cur.val` at that moment is correct.
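
The claim that pops happen in ascending order can be checked directly. The sketch below runs the same stack loop described in the Algorithm section, but records every popped value instead of stopping at the `k`th one. The `Node` class and the function name `pop_order` are hypothetical, introduced only for this check.

```python
from typing import List, Optional


class Node:
    """Minimal stand-in for LeetCode's TreeNode (assumed for this sketch)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def pop_order(root: Optional[Node]) -> List[int]:
    """Return node values in the order the stack-based traversal pops them."""
    order: List[int] = []
    stack: List[Node] = []
    cur = root

    while cur is not None or stack:
        # Push the whole left spine starting at cur.
        while cur is not None:
            stack.append(cur)
            cur = cur.left

        cur = stack.pop()
        order.append(cur.val)  # record the visit instead of counting down k
        cur = cur.right

    return order


# Tree from Example 2: [5,3,6,2,4,null,null,1]
example = Node(5, Node(3, Node(2, Node(1)), Node(4)), Node(6))
order = pop_order(example)
```

For this tree, `order` matches the sorted list from Example 2, so `order[k - 1]` is the `k`th smallest value.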

## Complexity

| Metric | Value | Why |
|---|---|---|
| Time | `O(h + k)` | Reaching the smallest node takes up to `h` steps, then the traversal visits `k` nodes |
| Space | `O(h)` | The stack stores at most one path from root to leaf |

Here, `h` is the tree height.

In the worst case, `h = O(n)` for a skewed tree.

For a balanced tree, `h = O(log n)`.

## Implementation

```python
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        stack = []
        cur = root

        while cur is not None or stack:
            while cur is not None:
                stack.append(cur)
                cur = cur.left

            cur = stack.pop()
            k -= 1

            if k == 0:
                return cur.val

            cur = cur.right
```

## Code Explanation

We start at the root:

```python
stack = []
cur = root
```

The outer loop continues while there are still nodes to process:

```python
while cur is not None or stack:
```

Then we go as far left as possible:

```python
while cur is not None:
    stack.append(cur)
    cur = cur.left
```

This leaves the smallest unvisited node on top of the stack.

Then we pop from the stack:

```python
cur = stack.pop()
```

This node is the next value in ascending order.

So we count it:

```python
k -= 1
```

If this is the `k`th smallest value, return it:

```python
if k == 0:
    return cur.val
```

Otherwise, move to the right subtree:

```python
cur = cur.right
```

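If the right subtree is non-empty, it contains the next larger values. Otherwise `cur` becomes `None`, and the next pop takes the nearest unvisited ancestor from the stack.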

## Follow-Up: Frequent Updates and Queries

If the BST is modified often and kth-smallest queries are frequent, repeating an `O(h + k)` inorder traversal for every query is wasteful, since it can reach `O(n)` when `k` is large.

A better design is to store subtree sizes in each node.

For each node:

```text
node.size = number of nodes in the subtree rooted at node
```

Then kth-smallest can be found like this:

1. Let `left_size` be the size of the left subtree.
2. If `k == left_size + 1`, the current node is the answer.
3. If `k <= left_size`, go left.
4. Otherwise, go right with `k = k - left_size - 1`.

This gives `O(h)` query time.

If the tree is balanced, this becomes `O(log n)`.

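Insert and delete operations must update subtree sizes along the changed path. If the tree rebalances itself with rotations, the sizes of the rotated nodes must be recomputed as well.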

## Testing

```python
from collections import deque
from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values):
    if not values:
        return None

    root = TreeNode(values[0])
    q = deque([root])
    i = 1

    while q and i < len(values):
        node = q.popleft()

        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            q.append(node.left)
        i += 1

        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            q.append(node.right)
        i += 1

    return root
```

```python
def run_tests():
    s = Solution()

    root = build_tree([3, 1, 4, None, 2])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([5, 3, 6, 2, 4, None, None, 1])
    assert s.kthSmallest(root, 3) == 3

    root = build_tree([2, 1, 3])
    assert s.kthSmallest(root, 2) == 2

    root = build_tree([1])
    assert s.kthSmallest(root, 1) == 1

    root = build_tree([4, 2, 5, 1, 3])
    assert s.kthSmallest(root, 5) == 5

    print("all tests passed")

run_tests()
```

| Test | Why |
|---|---|
| `[3,1,4,null,2]`, `k = 1` | Answer is the leftmost node, reached by the initial left descent |
| `[5,3,6,2,4,null,null,1]`, `k = 3` | Answer is an internal node, returned after popping `1` and `2` |
| `[2,1,3]`, `k = 2` | Root is the answer |
| `[1]`, `k = 1` | Single-node tree |
| `[4,2,5,1,3]`, `k = 5` | `k = n`, so the largest value is returned after visiting every node |
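
Beyond fixed cases, the traversal can be cross-checked against the collect-and-sort approach on random trees. The sketch below is self-contained so it can run on its own: `bst_insert` and `cross_check` are hypothetical helpers, and `kth_smallest` repeats the same stack loop as the Implementation section as a plain function.

```python
import random


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bst_insert(root, val):
    """Plain (unbalanced) BST insertion; helper assumed for this sketch."""
    if root is None:
        return TreeNode(val)
    node = root
    while True:
        if val < node.val:
            if node.left is None:
                node.left = TreeNode(val)
                return root
            node = node.left
        else:
            if node.right is None:
                node.right = TreeNode(val)
                return root
            node = node.right


def kth_smallest(root, k):
    """Same stack-based inorder loop as the Implementation section."""
    stack, cur = [], root
    while cur is not None or stack:
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        cur = stack.pop()
        k -= 1
        if k == 0:
            return cur.val
        cur = cur.right
    raise ValueError("k exceeds the number of nodes")


def cross_check(trials=200, max_nodes=30, seed=0):
    """Compare every valid k against the sorted list of inserted values."""
    rng = random.Random(seed)  # seeded so failures are reproducible
    for _ in range(trials):
        n = rng.randint(1, max_nodes)
        values = rng.sample(range(10**4 + 1), n)  # distinct, within the stated value range
        root = None
        for v in values:
            root = bst_insert(root, v)
        expected = sorted(values)
        for k in range(1, n + 1):
            assert kth_smallest(root, k) == expected[k - 1]
    return True


checked = cross_check()
```

Random insertion orders produce a mix of bushy and skewed trees, which exercises both the left descent and the move into right subtrees.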

