
Ball Tree Search

Search for nearest neighbors in metric space using hierarchical clustering with hyperspheres.

Ball tree search organizes points into a hierarchy of nested hyperspheres. Each node represents a ball defined by a center point and a radius, covering all points in its subtree.

It is commonly used for nearest neighbor search in metric spaces where distance computations are expensive.

Problem

Given a set $S$ of points in a metric space with distance $d$ and a query point $q$, find the nearest neighbor:

$$\arg\min_{p \in S} d(p, q)$$
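The definition translates directly into a linear scan that computes all $n$ distances; this is the baseline a ball tree tries to beat. A minimal Python sketch, assuming Euclidean distance and points stored as tuples (`nearest_brute_force` is a hypothetical helper, not part of this page's implementation):

```python
import math

def dist(a, b):
    # Euclidean distance (math.dist requires Python 3.8+); any metric d could be used.
    return math.dist(a, b)

def nearest_brute_force(points, q):
    """Return (point, distance) minimizing d(p, q) over all points: O(n) per query."""
    best_p = min(points, key=lambda p: dist(p, q))
    return best_p, dist(best_p, q)

S = [(1, 1), (2, 2), (8, 8), (9, 9)]
print(nearest_brute_force(S, (2, 3)))  # ((2, 2), 1.0)
```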

Structure

Each node stores:

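| field | meaning |
|---|---|
| center | representative point |
| radius | covering radius |
| left | left child (ball over a subset of the points) |
| right | right child (ball over the remaining points) |
| points | points stored at a leaf (leaves only) |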

All points $p$ in a subtree lie within its ball:

$$d(p, \text{center}) \le \text{radius}$$
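The page does not say how the balls are chosen. One common construction, shown here as an assumption rather than the method behind this page, takes the centroid as the center, uses the farthest point as the radius (which enforces the invariant above), and splits at the median of the coordinate with the widest spread. `build`, `leaf_size`, and `check_invariant` are hypothetical names; `Node` mirrors the fields in the table.

```python
import math

def dist(a, b):
    # Euclidean distance (math.dist requires Python 3.8+).
    return math.dist(a, b)

class Node:
    def __init__(self, center, radius):
        self.center = center  # representative point of the ball
        self.radius = radius  # covering radius
        self.left = None
        self.right = None
        self.points = None    # set only on leaves

def build(points, leaf_size=2):
    """Recursively build a ball tree; leaf_size must be at least 1."""
    dim = len(points[0])
    # Assumption: the center is the centroid of the points.
    center = tuple(sum(p[i] for p in points) / len(points) for i in range(dim))
    # The farthest point sets the radius, so d(p, center) <= radius for every p.
    node = Node(center, max(dist(p, center) for p in points))
    if len(points) <= leaf_size:
        node.points = list(points)
        return node
    # Assumption: split at the median of the coordinate with the largest spread.
    axis = max(range(dim), key=lambda i: max(p[i] for p in points) - min(p[i] for p in points))
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    node.left = build(ordered[:mid], leaf_size)
    node.right = build(ordered[mid:], leaf_size)
    return node

def check_invariant(node, tol=1e-9):
    """Return every point under node, asserting each lies inside node's ball."""
    if node.points is not None:
        pts = node.points
    else:
        pts = check_invariant(node.left, tol) + check_invariant(node.right, tol)
    assert all(dist(p, node.center) <= node.radius + tol for p in pts)
    return pts

root = build([(1, 1), (2, 2), (8, 8), (9, 9)])
print(root.center, round(root.radius, 3))  # (5.0, 5.0) 5.657
```

Centroids are only defined for vector data; in a general metric space the center is usually one of the points themselves.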

Algorithm

The search maintains the best known distance and prunes nodes whose balls cannot improve the answer.

ball_tree_search(node, q, best):
    if node is null:
        return best

    if distance(q, node.center) - node.radius >= best:
        return best

    if node.is_leaf:
        for p in node.points:
            best = min(best, distance(q, p))
        return best

    if distance(q, node.left.center) < distance(q, node.right.center):
        best = ball_tree_search(node.left, q, best)
        best = ball_tree_search(node.right, q, best)
    else:
        best = ball_tree_search(node.right, q, best)
        best = ball_tree_search(node.left, q, best)

    return best
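The pseudocode returns only the best distance, while the problem statement asks for the argmin. A small variation carries the point along with its distance. This is a Python sketch, not the page's implementation; `nearest` is a hypothetical name, and the two-level tree over (1, 1), (2, 2), (8, 8), (9, 9) is hand-built with an assumed split.

```python
import math

def dist(a, b):
    return math.dist(a, b)

class Node:
    def __init__(self, center, radius, left=None, right=None, points=None):
        self.center, self.radius = center, radius
        self.left, self.right = left, right
        self.points = points  # non-None only on leaves

def nearest(node, q, best=(None, math.inf)):
    """Like ball_tree_search, but carries (point, distance) instead of a bare distance."""
    if node is None:
        return best
    # Prune: no point in this ball is closer than dist(q, center) - radius.
    if dist(q, node.center) - node.radius >= best[1]:
        return best
    if node.points is not None:
        for p in node.points:
            d = dist(q, p)
            if d < best[1]:
                best = (p, d)
        return best
    # Visit the child whose center is closer to q first.
    children = sorted((c for c in (node.left, node.right) if c is not None),
                      key=lambda c: dist(q, c.center))
    for child in children:
        best = nearest(child, q, best)
    return best

# Hand-built tree; the split into {(1, 1), (2, 2)} and {(8, 8), (9, 9)} is an assumption.
r = math.sqrt(0.5)
left = Node((1.5, 1.5), r, points=[(1, 1), (2, 2)])
right = Node((8.5, 8.5), r, points=[(8, 8), (9, 9)])
root = Node((5, 5), math.sqrt(32), left, right)

print(nearest(root, (2, 3)))  # ((2, 2), 1.0)
```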

Example

Points in 2D: (1, 1), (2, 2), (8, 8), (9, 9).

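Root ball:

  • center: (5, 5)
  • radius: $4\sqrt{2} \approx 5.66$, the distance from the center to the farthest points (1, 1) and (9, 9), so the ball covers all four points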

Query:

$q = (2, 3)$

Closest point found:

$(2, 2)$, at distance 1
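The page does not specify how the root splits. Assuming its children are the balls over {(1, 1), (2, 2)} and {(8, 8), (9, 9)}, centered at their centroids (1.5, 1.5) and (8.5, 8.5) with radius $\sqrt{0.5}$ each, the search works out as:

$$
\begin{aligned}
\text{radius}_{\text{root}} &= d\big((5, 5), (1, 1)\big) = \sqrt{4^2 + 4^2} = 4\sqrt{2} \approx 5.66 \\
d\big(q, (1.5, 1.5)\big) &= \sqrt{0.5^2 + 1.5^2} = \sqrt{2.5} \approx 1.58 \\
d\big(q, (8.5, 8.5)\big) &= \sqrt{6.5^2 + 5.5^2} = \sqrt{72.5} \approx 8.51 \\
\text{best after left leaf} &= \min\big(d(q, (1, 1)),\, d(q, (2, 2))\big) = \min(\sqrt{5},\, 1) = 1 \\
\text{right lower bound} &= \sqrt{72.5} - \sqrt{0.5} \approx 7.81 \ge 1 \;\Rightarrow\; \text{prune}
\end{aligned}
$$

The left child is searched first because its center is closer to $q$, and the right subtree is skipped without computing distances to (8, 8) or (9, 9).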

Correctness

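Each node defines a metric ball that fully contains all points in its subtree. By the triangle inequality, every point $p$ in the ball satisfies $d(q, p) \ge d(q, \text{center}) - d(p, \text{center}) \ge d(q, \text{center}) - \text{radius}$, so $d(q, \text{center}) - \text{radius}$ is a lower bound on the distance from the query to any point in the subtree. If this bound is at least the current best distance, no point in that subtree can improve the result, so pruning is safe.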

The algorithm explores all nodes that could contain a closer point and checks all leaf points explicitly. Therefore, the minimum distance found is globally optimal.
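The pruning argument can be spot-checked by comparing the tree's answer with a linear scan on random data. A minimal Python sketch; the construction in `build` (centroid center, farthest-point radius, median split on the widest coordinate) is an assumption, since the page does not describe how the tree is built, and `build`, `search`, and `check` are hypothetical names.

```python
import math
import random

def dist(a, b):
    return math.dist(a, b)

class Node:
    def __init__(self, center, radius):
        self.center, self.radius = center, radius
        self.left = self.right = None
        self.points = None  # set only on leaves

def build(points, leaf_size=4):
    # Assumed construction: centroid center, farthest-point radius, median split.
    dim = len(points[0])
    center = tuple(sum(p[i] for p in points) / len(points) for i in range(dim))
    node = Node(center, max(dist(p, center) for p in points))
    if len(points) <= leaf_size:
        node.points = points
        return node
    axis = max(range(dim), key=lambda i: max(p[i] for p in points) - min(p[i] for p in points))
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    node.left = build(ordered[:mid], leaf_size)
    node.right = build(ordered[mid:], leaf_size)
    return node

def search(node, q, best=math.inf):
    # Same pruning rule as ball_tree_search: skip balls whose lower bound is >= best.
    if node is None or dist(q, node.center) - node.radius >= best:
        return best
    if node.points is not None:
        return min([best] + [dist(q, p) for p in node.points])
    for child in sorted((node.left, node.right), key=lambda c: dist(q, c.center)):
        best = search(child, q, best)
    return best

def check(trials=200, n=500, dim=3, seed=0):
    """Compare tree search with a linear scan on random points in the unit cube."""
    rng = random.Random(seed)
    pts = [tuple(rng.random() for _ in range(dim)) for _ in range(n)]
    root = build(pts)
    for _ in range(trials):
        q = tuple(rng.random() for _ in range(dim))
        expected = min(dist(q, p) for p in pts)
        # isclose tolerates floating-point rounding in the pruning bound.
        assert math.isclose(search(root, q), expected)
    return True

print(check())  # True
```

A passing check is evidence, not proof; the proof is the triangle-inequality bound.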

Complexity

Let $n$ be the number of points.

| case | time |
|---|---|
| average nearest neighbor query (balanced tree, low intrinsic dimension) | $O(\log n)$ |
| worst case | $O(n)$ |

Space complexity:

$$O(n)$$
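The gap between the average and worst cases comes from how much pruning the data allows. The sketch below counts how many stored points a query actually compares against. It is a Python illustration under stated assumptions: the data are two synthetic, far-apart clusters, the tree is hand-built with one leaf per cluster, and `search_counting` is a hypothetical name. On data without such structure, pruning is weaker and a query can approach the $O(n)$ worst case.

```python
import math
import random

class Node:
    def __init__(self, points, left=None, right=None, leaf=False):
        # Assumed ball: centroid center, farthest-point radius.
        self.center = tuple(sum(xs) / len(points) for xs in zip(*points))
        self.radius = max(math.dist(self.center, p) for p in points)
        self.left, self.right = left, right
        self.points = points if leaf else None

def search_counting(node, q, best, stats):
    """Ball tree search that tallies how many stored points are compared with q."""
    if node is None or math.dist(q, node.center) - node.radius >= best:
        return best
    if node.points is not None:
        stats["examined"] += len(node.points)
        return min([best] + [math.dist(q, p) for p in node.points])
    for child in sorted((node.left, node.right), key=lambda c: math.dist(q, c.center)):
        best = search_counting(child, q, best, stats)
    return best

# Two tight, far-apart clusters of 100 points each (synthetic data).
rng = random.Random(42)
cluster_a = [(rng.random(), rng.random()) for _ in range(100)]
cluster_b = [(100 + rng.random(), 100 + rng.random()) for _ in range(100)]
root = Node(cluster_a + cluster_b,
            left=Node(cluster_a, leaf=True),
            right=Node(cluster_b, leaf=True))

q = (0.5, 0.5)
stats = {"examined": 0}
d = search_counting(root, q, math.inf, stats)
print(stats["examined"], "of", len(cluster_a) + len(cluster_b))  # 100 of 200
```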

When to Use

Ball trees are useful when:

  • data lives in a general metric space (distances need not be Euclidean)
  • distance computations are expensive
  • nearest neighbor search is frequent
  • the data has meaningful cluster structure
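Because the search relies only on distances and the triangle inequality, the same logic works for non-vector data. A Python sketch, assuming strings compared by Levenshtein (edit) distance, which is a metric. Strings have no centroid, so each ball's center here is one of its own items, a choice made for illustration. The word lists, the hand-built split, and the names `edit_distance` and `nearest` are hypothetical.

```python
import math

def edit_distance(a, b):
    """Levenshtein distance: fewest single-character insertions, deletions, substitutions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # delete ca
                           cur[j - 1] + 1,             # insert cb
                           prev[j - 1] + (ca != cb)))  # substitute (or match)
        prev = cur
    return prev[-1]

class Node:
    def __init__(self, center, points, metric, left=None, right=None, leaf=False):
        # The center is one of the items itself; the radius covers every item below.
        self.center = center
        self.radius = max(metric(center, p) for p in points)
        self.left, self.right = left, right
        self.points = points if leaf else None

def nearest(node, q, metric, best=(None, math.inf)):
    """Return (item, distance) for the item closest to q under the given metric."""
    if node is None or metric(q, node.center) - node.radius >= best[1]:
        return best
    if node.points is not None:
        for p in node.points:
            d = metric(q, p)
            if d < best[1]:
                best = (p, d)
        return best
    children = sorted((c for c in (node.left, node.right) if c is not None),
                      key=lambda c: metric(q, c.center))
    for child in children:
        best = nearest(child, q, metric, best)
    return best

left_words = ["cat", "cart", "bat"]
right_words = ["dog", "dot", "dig"]
left = Node("cat", left_words, edit_distance, leaf=True)
right = Node("dog", right_words, edit_distance, leaf=True)
root = Node("cat", left_words + right_words, edit_distance, left, right)
print(nearest(root, "carts", edit_distance))  # ('cart', 1)
```

For the query "carts", the "dog" ball has lower bound $5 - 1 = 4$, which exceeds the best distance of 1 found in the "cat" ball, so its words are never compared.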

Compared to k-d trees, ball trees tend to perform better in higher dimensions, where axis-aligned splits become less effective.

Implementation

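import math

def dist(a, b):
    # Euclidean distance between two equal-length coordinate sequences.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

class Node:
    def __init__(self, center, radius):
        self.center = center   # representative point of the ball
        self.radius = radius   # max distance from center to any point in the subtree
        self.left = None       # left child Node (internal nodes only)
        self.right = None      # right child Node (internal nodes only)
        self.points = None     # list of points (leaves only)

def ball_tree_search(node, q, best=float("inf")):
    """Return the smallest distance from q to any point under node, or best if none is closer."""
    if node is None:
        return best

    # Every point in the ball is at least dist(q, center) - radius away from q,
    # so the subtree cannot beat best once that bound reaches it.
    if dist(q, node.center) - node.radius >= best:
        return best

    # Leaf: compare against every stored point.
    if node.points is not None:
        for p in node.points:
            best = min(best, dist(q, p))
        return best

    # Internal node: visit the child whose center is closer first, so best
    # shrinks early and the other child is more likely to be pruned.
    # A node with only one child is still searched.
    children = [c for c in (node.left, node.right) if c is not None]
    children.sort(key=lambda c: dist(q, c.center))
    for child in children:
        best = ball_tree_search(child, q, best)

    return best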
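
package balltree

import "math"

// Point is a coordinate vector; all points are assumed to have the same length.
type Point []float64

// dist returns the Euclidean distance between a and b.
func dist(a, b Point) float64 {
	sum := 0.0
	for i := range a {
		d := a[i] - b[i]
		sum += d * d
	}
	return math.Sqrt(sum)
}

// Node is a ball: every point in its subtree lies within Radius of Center.
// Leaves store Points; internal nodes store Left and Right children.
type Node struct {
	Center Point
	Radius float64
	Left   *Node
	Right  *Node
	Points []Point
}

// BallTreeSearch returns the smallest distance from q to any point under node,
// or best if no point is closer. Start with best = math.Inf(1).
func BallTreeSearch(node *Node, q Point, best float64) float64 {
	if node == nil {
		return best
	}

	// Prune: no point in this ball is closer than dist(q, Center) - Radius.
	if dist(q, node.Center)-node.Radius >= best {
		return best
	}

	// Leaf: scan every stored point.
	if node.Points != nil {
		for _, p := range node.Points {
			if d := dist(q, p); d < best {
				best = d
			}
		}
		return best
	}

	// Internal node: search the child with the closer center first.
	// A missing child is nil and returns immediately, so single-child nodes still work.
	first, second := node.Left, node.Right
	if first == nil || (second != nil && dist(q, second.Center) < dist(q, first.Center)) {
		first, second = second, first
	}
	best = BallTreeSearch(first, q, best)
	best = BallTreeSearch(second, q, best)

	return best
}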