# k-d Tree Search

A k-d tree organizes points in k-dimensional space using recursive axis-aligned partitions. Each level of the tree splits the space along one coordinate axis.

It supports exact membership queries as well as range and nearest-neighbor search.

## Problem

Given a set of points in k-dimensional space and a query point or region, find one of the following:

* whether the query point is in the set
* all points inside the query region
* the stored point nearest to the query point

## Structure

Each node stores:

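| field | meaning                                                                  |
| ----- | ------------------------------------------------------------------------ |
| point | k-dimensional point stored at this node                                  |
| axis  | splitting dimension                                                      |
| left  | subtree of points whose coordinate on `axis` is less than `point[axis]`  |
| right | subtree of points whose coordinate on `axis` is greater than or equal to `point[axis]` |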

The splitting axis cycles:

$$
axis = depth \bmod k
$$
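The node layout and the cycling axis can be made concrete with a short Python sketch. It assumes one common construction, splitting at the median point on each axis (taking the upper median when the count is even), which the text above does not prescribe. Ties go to the right subtree, matching the exact-search pseudocode. `Point`, `build`, and the type hints are illustrative assumptions, not part of the original implementation.

```python
from typing import Optional, Sequence, Tuple

Point = Tuple[float, ...]


class Node:
    """A k-d tree node: the stored point plus the axis it splits on."""

    def __init__(self, point: Point, axis: int) -> None:
        self.point = point
        self.axis = axis
        self.left: Optional["Node"] = None   # points with p[axis] < point[axis]
        self.right: Optional["Node"] = None  # points with p[axis] >= point[axis]


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    """Build a balanced tree by splitting at the (upper) median on each axis."""
    if not points:
        return None
    k = len(points[0])
    axis = depth % k  # the splitting axis cycles with depth
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    # Step left past equal coordinates so every left point is strictly smaller.
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node
```

On the six example points used later in this section, this places (7, 2) at the root with `axis` 0 (x) and (5, 4) as its left child with `axis` 1 (y). Stepping `mid` left past equal coordinates is what keeps the strict "left is smaller" rule true when coordinates repeat.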

## Exact Search Algorithm

Exact search descends the tree like a binary search tree lookup, comparing only the coordinate on the current splitting axis.

```text id="kd0k7q"
kd_search(node, target, depth):
    if node is null:
        return false

    if node.point == target:
        return true

    axis = depth mod k

    if target[axis] < node.point[axis]:
        return kd_search(node.left, target, depth + 1)
    else:
        return kd_search(node.right, target, depth + 1)
```

## Range Search Algorithm

Range search collects all points inside an axis-aligned hyperrectangle, given by per-axis lower and upper bounds `range.min` and `range.max`.

```text id="kd1m8v"
kd_range_search(node, range, depth, result):
    if node is null:
        return

    if node.point inside range:
        add node.point to result

    axis = depth mod k

    if range.min[axis] <= node.point[axis]:
        kd_range_search(node.left, range, depth + 1, result)

    if node.point[axis] <= range.max[axis]:
        kd_range_search(node.right, range, depth + 1, result)
```
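A direct Python translation of the range-search pseudocode is sketched below. It assumes the query hyperrectangle is passed as two corner points `lo` and `hi` (standing in for `range.min` and `range.max`, with bounds treated as inclusive) and that the tree comes from a median-split `build` helper. It reads the stored `axis` field instead of passing `depth`, which is equivalent when axes cycle. All of these are illustrative choices, not a fixed API.

```python
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, ...]


class Node:
    def __init__(self, point: Point, axis: int) -> None:
        self.point, self.axis = point, axis
        self.left: Optional["Node"] = None   # p[axis] < point[axis]
        self.right: Optional["Node"] = None  # p[axis] >= point[axis]


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    """Median-split build (an assumed construction, used only for the demo)."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1  # keep every left point strictly smaller on this axis
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node


def range_search(node: Optional[Node], lo: Point, hi: Point, result: List[Point]) -> None:
    """Append to result every point p with lo[i] <= p[i] <= hi[i] on all axes."""
    if node is None:
        return
    p = node.point
    if all(lo_i <= c <= hi_i for lo_i, c, hi_i in zip(lo, p, hi)):
        result.append(p)
    a = node.axis
    # The left subtree holds points smaller than p[a]; it can overlap only if lo[a] <= p[a].
    if lo[a] <= p[a]:
        range_search(node.left, lo, hi, result)
    # The right subtree holds points >= p[a]; it can overlap only if p[a] <= hi[a].
    if p[a] <= hi[a]:
        range_search(node.right, lo, hi, result)


root = build([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
found: List[Point] = []
range_search(root, (3, 1), (8, 5), found)
print(sorted(found))  # [(5, 4), (7, 2), (8, 1)]
```

In this query the subtree rooted at (2, 3) is entered but its children are pruned on x, and (9, 6)'s right side is pruned on y, so not every node is visited.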

## Example

2D points:

| point  |
| ------ |
| (2, 3) |
| (5, 4) |
| (9, 6) |
| (4, 7) |
| (8, 1) |
| (7, 2) |

Tree splits:

* depth 0: x axis
* depth 1: y axis
* depth 2: x axis

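Assume the tree is built by splitting at the median point on each axis, taking the upper median when the count is even. Then the root is (7, 2) (the median by x), its children are (5, 4) and (9, 6) (medians by y of each side), (2, 3) and (4, 7) hang below (5, 4), and (8, 1) is the left child of (9, 6).

Search for point `(5, 4)`:

| step | node   | axis | action                          |
| ---- | ------ | ---- | ------------------------------- |
| 1    | (7, 2) | x    | 5 < 7, go left                  |
| 2    | (5, 4) | y    | node equals target, return true |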
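The trace can be reproduced with the sketch below, which builds the example tree by median splits (an assumed construction) and records every node the exact search visits. `search_path` is a hypothetical helper written for this illustration.

```python
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, ...]


class Node:
    def __init__(self, point: Point, axis: int) -> None:
        self.point, self.axis = point, axis
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    """Median-split build; upper median when the count is even."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node


def search_path(node: Optional[Node], target: Point) -> List[Point]:
    """Return the points visited by exact search; the last equals target if found."""
    path: List[Point] = []
    while node is not None:
        path.append(node.point)
        if node.point == target:
            break
        a = node.axis
        # Same rule as kd_search: strictly smaller goes left, ties go right.
        node = node.left if target[a] < node.point[a] else node.right
    return path


points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
root = build(points)
print(search_path(root, (5, 4)))  # [(7, 2), (5, 4)]
```

A point absent from the tree, such as (6, 6), ends the walk at a leaf ((4, 7) here) without a match.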

## Correctness

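Each node partitions space into two half-spaces along its splitting axis. Every point $p$ in the left subtree satisfies:

$$
p[axis] < node[axis]
$$

and every point in the right subtree satisfies:

$$
p[axis] \ge node[axis]
$$

This matches the search code, which sends ties to the right. The invariant ensures that at each step the algorithm discards only regions that cannot contain the target or any point of the query range.

For exact search, if $target[axis] < node[axis]$ only the left subtree can contain the target, and otherwise only the right subtree can, so the recursion follows that single branch. For range search, the left subtree is visited only when $range.min[axis] \le node[axis]$ and the right subtree only when $node[axis] \le range.max[axis]$; a skipped subtree lies entirely outside the query region on that axis.

Thus every point inside the region is found, and because each visited point is tested against the range before it is added, no point outside the region is reported.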
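The invariant argument can also be checked mechanically. This sketch assumes the same median-split construction as a stand-in for however the tree is actually built, and the split convention the search code uses (ties go right). On random 2D points with many repeated coordinates, it verifies that every left subtree holds strictly smaller and every right subtree greater-or-equal coordinates on the node's axis, and that exact search then finds every stored point. `invariant_holds`, `points_in`, and `contains` are illustrative names.

```python
import random
from typing import Iterator, Optional, Sequence, Tuple

Point = Tuple[int, ...]


class Node:
    def __init__(self, point: Point, axis: int) -> None:
        self.point, self.axis = point, axis
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1  # keep the left side strictly smaller
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node


def points_in(node: Optional[Node]) -> Iterator[Point]:
    """Yield every point stored in the subtree rooted at node."""
    if node is not None:
        yield node.point
        yield from points_in(node.left)
        yield from points_in(node.right)


def invariant_holds(node: Optional[Node]) -> bool:
    """On each node's axis: left points are < and right points are >= its coordinate."""
    if node is None:
        return True
    a, c = node.axis, node.point[node.axis]
    return (
        all(p[a] < c for p in points_in(node.left))
        and all(p[a] >= c for p in points_in(node.right))
        and invariant_holds(node.left)
        and invariant_holds(node.right)
    )


def contains(node: Optional[Node], target: Point) -> bool:
    """Exact search: follow the only branch that can hold target."""
    while node is not None:
        if node.point == target:
            return True
        a = node.axis
        node = node.left if target[a] < node.point[a] else node.right
    return False


random.seed(0)
pts = [(random.randint(0, 9), random.randint(0, 9)) for _ in range(200)]
root = build(pts)
print(invariant_holds(root), all(contains(root, p) for p in pts))  # True True
```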

## Complexity

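Let $n$ be the number of points and $m$ the number of points a range query reports. The worst case arises when the tree is badly unbalanced (it can degenerate into a path) or, for nearest-neighbor search, when pruning fails and most nodes must be visited.

| operation        | average         | worst  |
| ---------------- | --------------- | ------ |
| search           | $O(\log n)$     | $O(n)$ |
| range query      | $O(\log n + m)$ | $O(n)$ |
| nearest neighbor | $O(\log n)$     | $O(n)$ |

For a balanced tree, a range query takes $O(n^{1-1/k} + m)$ time in the worst case.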
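The gap between the average and worst columns comes from tree shape: exact search visits at most one node per level, so its cost is bounded by the tree's height. The sketch below contrasts a median-split build with naive one-at-a-time insertion of points that are already sorted on both axes, which degenerates into a chain. `build`, `insert`, and `height` are illustrative helpers, not part of the text above.

```python
from typing import Optional, Sequence, Tuple

Point = Tuple[int, ...]


class Node:
    def __init__(self, point: Point, axis: int) -> None:
        self.point, self.axis = point, axis
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    """Balanced build: split at the median on the cycling axis."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node


def insert(node: Optional[Node], point: Point, depth: int = 0) -> Node:
    """Naive insertion: walk down as exact search does and attach a new leaf."""
    if node is None:
        return Node(point, depth % len(point))
    a = node.axis
    if point[a] < node.point[a]:
        node.left = insert(node.left, point, depth + 1)
    else:
        node.right = insert(node.right, point, depth + 1)
    return node


def height(node: Optional[Node]) -> int:
    """Nodes on the longest root-to-leaf path; an empty tree has height 0."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


n = 200
pts = [(i, i) for i in range(n)]  # sorted on both axes

balanced = build(pts)
chain: Optional[Node] = None
for p in pts:
    chain = insert(chain, p)  # every new point lands in the right subtree

print(height(balanced), height(chain))  # 8 200
```

Exact search on `chain` may walk all 200 nodes, while on `balanced` it visits at most 8.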

Space complexity:

$$
O(n)
$$

## When to Use

k-d trees are useful when:

* the data has low to moderate dimension (typically 2D to 10D); very high-dimensional data is a poor fit
* spatial queries are frequent
* nearest-neighbor search is needed

Compared to range trees, k-d trees use less memory ($O(n)$ versus $O(n \log^{k-1} n)$ for a k-dimensional range tree) but have weaker worst-case query guarantees.
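Nearest-neighbor search is one of the main reasons to pick a k-d tree, but its procedure is not spelled out in this section. A common approach is branch and bound: descend toward the query as exact search does, then on the way back visit the other child only if the splitting plane is closer than the best distance found so far. The sketch below assumes Euclidean distance and a median-split `build`; `nearest` is an illustrative name, not a library API.

```python
import math
from typing import Optional, Sequence, Tuple

Point = Tuple[float, ...]


class Node:
    def __init__(self, point: Point, axis: int) -> None:
        self.point, self.axis = point, axis
        self.left: Optional["Node"] = None   # p[axis] < point[axis]
        self.right: Optional["Node"] = None  # p[axis] >= point[axis]


def build(points: Sequence[Point], depth: int = 0) -> Optional[Node]:
    """Median-split build, assumed here only to have a tree to query."""
    if not points:
        return None
    axis = depth % len(points[0])
    pts = sorted(points, key=lambda p: p[axis])
    mid = len(pts) // 2
    while mid > 0 and pts[mid - 1][axis] == pts[mid][axis]:
        mid -= 1
    node = Node(pts[mid], axis)
    node.left = build(pts[:mid], depth + 1)
    node.right = build(pts[mid + 1:], depth + 1)
    return node


def nearest(root: Optional[Node], target: Point) -> Tuple[Optional[Point], float]:
    """Branch-and-bound nearest-neighbor search under Euclidean distance."""
    best: Optional[Point] = None
    best_dist = math.inf

    def visit(node: Optional[Node]) -> None:
        nonlocal best, best_dist
        if node is None:
            return
        d = math.dist(node.point, target)
        if d < best_dist:
            best, best_dist = node.point, d
        a = node.axis
        diff = target[a] - node.point[a]
        # Search the side of the splitting plane that contains target first.
        near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
        visit(near)
        # Every point on the far side is at least |diff| away, so skip it
        # unless the plane is closer than the best match found so far.
        if abs(diff) < best_dist:
            visit(far)

    visit(root)
    return best, best_dist


root = build([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
print(nearest(root, (9, 2))[0])  # (8, 1)
```

For the query (9, 2) the search never enters the left half of the root, because the x-plane at 7 is farther away (2) than the best match (8, 1) at distance about 1.41.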

## Implementation

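```python id="kd2n7x"
class Node:
    def __init__(self, point, axis=0):
        self.point = point  # k-dimensional point, e.g. a tuple
        self.axis = axis    # splitting dimension (depth % k when axes cycle)
        self.left = None    # subtree with smaller coordinate on axis
        self.right = None   # subtree with greater or equal coordinate on axis

def kd_search(node, target, depth=0):
    """Return True if target is stored in the subtree rooted at node."""
    if node is None:
        return False

    # Whole-point comparison; use the same sequence type (e.g. tuples) for both.
    if node.point == target:
        return True

    # The splitting axis cycles with depth, matching how the tree was built.
    axis = depth % len(target)

    # Descend into the only side that can hold target; ties go right.
    if target[axis] < node.point[axis]:
        return kd_search(node.left, target, depth + 1)
    return kd_search(node.right, target, depth + 1)
```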

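```go id="kd3m8v"
// Point is a k-dimensional point with integer coordinates.
type Point []int

// Node is a k-d tree node. The splitting axis is not stored; it is
// derived from the depth as depth % k.
type Node struct {
	Point Point
	Left  *Node // points whose coordinate on the axis is smaller
	Right *Node // points whose coordinate on the axis is greater or equal
}

// KdSearch reports whether target is stored in the subtree rooted at node.
// Call it with depth 0 at the root.
func KdSearch(node *Node, target Point, depth int) bool {
	if node == nil {
		return false
	}

	// Compare every coordinate for an exact match.
	match := true
	for i := range target {
		if node.Point[i] != target[i] {
			match = false
			break
		}
	}

	if match {
		return true
	}

	axis := depth % len(target)

	// Descend into the only side that can hold target; ties go right.
	if target[axis] < node.Point[axis] {
		return KdSearch(node.Left, target, depth+1)
	}

	return KdSearch(node.Right, target, depth+1)
}
```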

