# Ball Tree Search

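A ball tree organizes points into a hierarchy of nested hyperspheres (balls). Each node represents a ball, defined by a center point and a radius, that covers all points in its subtree. Ball tree search walks this hierarchy and skips any ball that cannot contain a closer point than the best one found so far.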

It is commonly used for nearest neighbor search in metric spaces where distance computations are expensive.

## Problem

Given a set $S$ of points in a metric space with distance function $d$, and a query point $q$, find the nearest neighbor:

$$
\arg\min_{p \in S} d(p, q)
$$

## Structure

Each node stores:

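| field  | meaning                                                                    |
| ------ | -------------------------------------------------------------------------- |
| center | center of the ball                                                         |
| radius | covering radius: no point in the subtree is farther than this from center  |
| left   | child node covering one subset of the points                               |
| right  | child node covering the remaining points                                   |
| points | points stored directly in the node (leaves only)                           |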

All points in a subtree lie within a ball:

$$
d(p, \text{center}) \le \text{radius}
$$
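
One way to obtain nodes that satisfy this covering property is sketched below. The construction strategy is an assumption made for illustration: the center is the centroid of the points (which presumes coordinate vectors), the radius is the distance to the farthest point, and the points are split at the median along the axis with the largest spread. The names `build_ball_tree`, `leaf_size`, and `check_invariant` are hypothetical.

```python
import math

def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

class Node:
    def __init__(self, center, radius):
        self.center = center
        self.radius = radius
        self.left = None
        self.right = None
        self.points = None  # set on leaves only

def build_ball_tree(points, leaf_size=2):
    """Build a ball tree over a non-empty list of coordinate tuples (illustrative sketch)."""
    dim = len(points[0])
    # Assumed choice: centroid as center, farthest point defines the radius.
    center = tuple(sum(p[i] for p in points) / len(points) for i in range(dim))
    node = Node(center, max(dist(center, p) for p in points))

    if len(points) <= leaf_size:
        node.points = list(points)
        return node

    # Assumed choice: median split along the axis with the largest spread.
    spreads = [max(p[i] for p in points) - min(p[i] for p in points) for i in range(dim)]
    axis = spreads.index(max(spreads))
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    node.left = build_ball_tree(ordered[:mid], leaf_size)
    node.right = build_ball_tree(ordered[mid:], leaf_size)
    return node

def subtree_points(node):
    if node.points is not None:
        return node.points
    return subtree_points(node.left) + subtree_points(node.right)

def check_invariant(node):
    """True if every point in every subtree lies inside that subtree's ball."""
    ok = all(dist(p, node.center) <= node.radius + 1e-12 for p in subtree_points(node))
    if node.points is None:
        ok = ok and check_invariant(node.left) and check_invariant(node.right)
    return ok

root = build_ball_tree([(1, 1), (2, 2), (8, 8), (9, 9)])
print(root.center, check_invariant(root))
```

Other splitting rules, such as partitioning around two far-apart pivot points, give the same invariant; only the tightness of the balls changes.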

## Algorithm

The search maintains the best known distance and prunes nodes whose balls cannot improve the answer.

```text id="bt0k7q"
ball_tree_search(node, q, best):
    if node is null:
        return best

    if distance(q, node.center) - node.radius >= best:
        return best

    if node.is_leaf:
        for p in node.points:
            best = min(best, distance(q, p))
        return best

    if distance(q, node.left.center) < distance(q, node.right.center):
        best = ball_tree_search(node.left, q, best)
        best = ball_tree_search(node.right, q, best)
    else:
        best = ball_tree_search(node.right, q, best)
        best = ball_tree_search(node.left, q, best)

    return best
```

## Example

Points in 2D:

| point  |
| ------ |
| (1, 1) |
| (2, 2) |
| (8, 8) |
| (9, 9) |

Root ball:

* center: (5, 5)
* radius: $\sqrt{32} \approx 5.66$, the distance from the center to the farthest points, (1, 1) and (9, 9)

Query:

$$
q = (2, 3)
$$

Closest point found, at distance $1$:

$$
(2, 2)
$$
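
To see where pruning helps in this example, the sketch below assumes the root has two child balls, one over {(1, 1), (2, 2)} and one over {(8, 8), (9, 9)}, each centered at the midpoint of its two points. That split is an assumption for illustration; the example above only fixes the root ball.

```python
import math

def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

q = (2, 3)

# Assumed child balls: (center, radius, points).
left = ((1.5, 1.5), dist((1.5, 1.5), (1, 1)), [(1, 1), (2, 2)])
right = ((8.5, 8.5), dist((8.5, 8.5), (8, 8)), [(8, 8), (9, 9)])

best = math.inf
best_point = None
visited = []

# Visit the ball with the nearer center first.
balls = sorted([("left", left), ("right", right)], key=lambda item: dist(q, item[1][0]))
for name, (center, radius, points) in balls:
    lower_bound = dist(q, center) - radius
    if lower_bound >= best:
        continue  # no point in this ball can beat the current best
    visited.append(name)
    for p in points:
        d = dist(q, p)
        if d < best:
            best, best_point = d, p

print(best_point, best, visited)  # (2, 2) 1.0 ['left']
```

The right ball has a lower bound of about 7.81, which is well above the best distance of 1 found in the left ball, so its two points are never examined.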

## Correctness

Each node defines a metric ball that contains every point in its subtree. By the triangle inequality, any point $p$ in a ball with center $c$ and radius $r$ satisfies $d(q, p) \ge d(q, c) - d(c, p) \ge d(q, c) - r$. If this lower bound is at least the current best distance, no point in that subtree can improve the result, so pruning is safe.

The algorithm explores all nodes that could contain a closer point and checks all leaf points explicitly. Therefore, the minimum distance found is globally optimal.

## Complexity

Let $n$ be the number of points.

| case                                              | time        |
| ------------------------------------------------- | ----------- |
| average (balanced tree, low intrinsic dimension)  | $O(\log n)$ |
| worst case (little or no pruning)                 | $O(n)$      |

Space complexity:

$$
O(n)
$$

## When to Use

Ball trees are useful when:

* the data lives in a metric space (not necessarily Euclidean)
* distance computations are expensive
* nearest neighbor search is frequent
* the data has cluster structure that balls can enclose tightly

Compared to k-d trees, ball trees often perform better in higher dimensions, where axis-aligned splits become less effective. In very high dimensions, pruning weakens and both structures approach a linear scan.

## Implementation

```python id="bt1m8v"
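import math

def dist(a, b):
    """Euclidean distance between two equal-length coordinate sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

class Node:
    def __init__(self, center, radius):
        self.center = center  # center of the covering ball
        self.radius = radius  # no point in the subtree is farther than this from center
        self.left = None      # child nodes (internal nodes only)
        self.right = None
        self.points = None    # list of points (leaf nodes only)

def ball_tree_search(node, q, best=float("inf")):
    """Return the nearest-neighbor distance from q within the subtree, or best if nothing is closer."""
    if node is None:
        return best

    # Lower bound on the distance from q to any point inside this ball.
    if dist(q, node.center) - node.radius >= best:
        return best

    # Leaf: check every stored point.
    if node.points is not None:
        for p in node.points:
            best = min(best, dist(q, p))
        return best

    # Visit the child with the nearer center first so that best tightens early.
    # A missing child is skipped, so a node with a single child is still searched.
    children = [c for c in (node.left, node.right) if c is not None]
    children.sort(key=lambda c: dist(q, c.center))
    for child in children:
        best = ball_tree_search(child, q, best)

    return best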
```

```go id="bt2n7x"
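package balltree

import "math"

// Point is a coordinate vector.
type Point []float64

// dist returns the Euclidean distance between a and b, which must have equal length.
func dist(a, b Point) float64 {
	sum := 0.0
	for i := range a {
		d := a[i] - b[i]
		sum += d * d
	}
	return math.Sqrt(sum)
}

// Node is a ball that covers every point in its subtree.
type Node struct {
	Center Point
	Radius float64
	Left   *Node
	Right  *Node
	Points []Point // non-nil for leaves only
}

// BallTreeSearch returns the distance from q to its nearest neighbor in the
// subtree rooted at node, or best if no point is closer.
// Pass math.Inf(1) as best for the initial call.
func BallTreeSearch(node *Node, q Point, best float64) float64 {
	if node == nil {
		return best
	}

	// Lower bound on the distance from q to any point inside this ball.
	if dist(q, node.Center)-node.Radius >= best {
		return best
	}

	// Leaf: check every stored point.
	if node.Points != nil {
		for _, p := range node.Points {
			d := dist(q, p)
			if d < best {
				best = d
			}
		}
		return best
	}

	// Visit the child with the nearer center first so that best tightens early.
	// The nil check at the top handles a missing child.
	first, second := node.Left, node.Right
	if first != nil && second != nil && dist(q, second.Center) < dist(q, first.Center) {
		first, second = second, first
	}
	best = BallTreeSearch(first, q, best)
	return BallTreeSearch(second, q, best)
}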
```
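
Both implementations above return only the nearest distance, while the problem statement asks for the nearest point itself. A minimal Python sketch of a variant that tracks the point as well is shown below. The names `nearest` and `leaf` are hypothetical, and the small tree is built by hand from the example points with an assumed split into two leaves.

```python
import math

def dist(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

class Node:
    def __init__(self, center, radius):
        self.center = center
        self.radius = radius
        self.left = None
        self.right = None
        self.points = None

def nearest(node, q, best=(math.inf, None)):
    """Return (distance, point) for the nearest neighbor of q in the subtree."""
    if node is None:
        return best
    # Same pruning rule as before, applied to the distance part of best.
    if dist(q, node.center) - node.radius >= best[0]:
        return best
    if node.points is not None:
        for p in node.points:
            d = dist(q, p)
            if d < best[0]:
                best = (d, p)
        return best
    children = [c for c in (node.left, node.right) if c is not None]
    for child in sorted(children, key=lambda c: dist(q, c.center)):
        best = nearest(child, q, best)
    return best

def leaf(points):
    """Hypothetical helper: a leaf centered at the centroid of its points."""
    dim = len(points[0])
    center = tuple(sum(p[i] for p in points) / len(points) for i in range(dim))
    node = Node(center, max(dist(center, p) for p in points))
    node.points = points
    return node

# Hand-built tree over the example points (assumed split).
root = Node((5, 5), dist((5, 5), (1, 1)))
root.left = leaf([(1, 1), (2, 2)])
root.right = leaf([(8, 8), (9, 9)])

d, p = nearest(root, (2, 3))
print(p, d)  # (2, 2) 1.0
```

Carrying the pair `(distance, point)` through the recursion keeps the pruning test unchanged, since only the distance component is compared against the lower bound.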

