Distributed Sample Sort

Sort data across machines by sampling keys, choosing splitters, redistributing records into ordered buckets, and sorting buckets locally.

Distributed sample sort sorts a dataset spread across multiple machines. It uses sampled keys to choose global splitters. These splitters divide the key space into ordered buckets. Each machine sends records to the bucket owner, each bucket is sorted locally, and the sorted buckets form the final output.

The main purpose of sampling is load balance. Good splitters make every worker receive roughly the same amount of data.

Problem

Given n records distributed across p workers, sort all records by key in nondecreasing order.

The output may remain distributed as p sorted partitions.

Algorithm

distributed_sample_sort(workers):
    each worker samples local keys
    gather all samples
    sort samples
    choose p - 1 splitters

    broadcast splitters to all workers

    each worker partitions local records into p buckets

    all_to_all exchange buckets

    each worker sorts received bucket

    return workers in splitter order

The splitters satisfy:

s_1 \le s_2 \le \cdots \le s_{p-1}

Worker 0 receives the smallest key range. Worker p - 1 receives the largest key range.

Bucket Assignment

Each key is assigned using binary search over the splitters.

bucket_id(key, splitters):
    return upper_bound(splitters, key)

This gives a value from 0 to p - 1.
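
In Python, the `upper_bound` step corresponds to `bisect.bisect_right`. A minimal sketch with illustrative splitter values:

```python
from bisect import bisect_right

def bucket_id(key, splitters):
    # upper_bound: counts splitters <= key, which is exactly the bucket index
    return bisect_right(splitters, key)

splitters = [10, 20, 30]          # p - 1 = 3 splitters for p = 4 workers
assert bucket_id(5, splitters) == 0    # below the first splitter
assert bucket_id(10, splitters) == 1   # a key equal to a splitter goes right
assert bucket_id(25, splitters) == 2
assert bucket_id(99, splitters) == 3   # above the last splitter
```

Using `bisect_right` (rather than `bisect_left`) means keys equal to a splitter land in the higher bucket, so ties are resolved consistently across all workers.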

Communication

After local partitioning, workers perform an all-to-all exchange.

for each source worker u:
    for each target worker v:
        send bucket[u][v] to worker v

Each target worker receives records from all workers for one key interval.
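
A minimal single-process model of the exchange, with `bucket[u][v]` as a nested list (the data layout here is illustrative):

```python
def all_to_all(bucket):
    # bucket[u][v] holds the records worker u routes to worker v
    p = len(bucket)
    # received[v] gathers the contributions of every source worker u
    received = [[] for _ in range(p)]
    for u in range(p):
        for v in range(p):
            received[v].extend(bucket[u][v])
    return received

bucket = [
    [[1], [7]],   # worker 0's buckets for workers 0 and 1
    [[2], [9]],   # worker 1's buckets for workers 0 and 1
]
assert all_to_all(bucket) == [[1, 2], [7, 9]]
```

In a real deployment this loop is a collective operation (for example `MPI_Alltoallv` or a shuffle in a dataflow engine) rather than point-to-point sends.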

Complexity

measure               value
local sampling        O(n/p) per worker
local partitioning    O((n/p) log p)
communication volume  O(n) records
local sorting         expected O((n/p) log(n/p))
output partitions     p sorted ranges

The wall clock cost is often dominated by network exchange and skew.

Correctness

Splitters divide the global key space into ordered intervals. All records assigned to worker i are less than or equal to all records assigned to worker j when i < j. Each worker sorts its own received records. Therefore, the workers hold sorted partitions in global key order.

Reading worker outputs from 0 to p - 1 gives the complete sorted dataset.
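
The inter-bucket ordering claim can be checked directly on a small model (the splitter and key values are illustrative):

```python
from bisect import bisect_right

splitters = [10, 20]              # p = 3 buckets
keys = [25, 3, 10, 17, 9, 30, 20]

buckets = [[] for _ in range(3)]
for k in keys:
    buckets[bisect_right(splitters, k)].append(k)

# every key in bucket i is <= every key in bucket j for i < j,
# so sorting each bucket independently yields a globally sorted sequence
assert max(buckets[0]) <= min(buckets[1])
assert max(buckets[1]) <= min(buckets[2])
```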

Practical Considerations

  • Oversampling improves load balance.
  • Skewed or duplicate-heavy data may overload one worker.
  • All-to-all exchange can stress the network.
  • Compression may reduce transfer cost.
  • Local sort can use radix sort, quicksort, merge sort, or external sort.
  • Output is usually partitioned, not physically concatenated.
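
One common way to oversample, shown here as a sketch not tied to any particular framework: each worker contributes several samples per splitter, the samples are pooled and sorted, and splitters are read off at evenly spaced positions.

```python
import random

def choose_splitters(all_samples, p):
    # pick p - 1 evenly spaced splitters from the pooled sorted samples
    s = sorted(all_samples)
    return [s[i * len(s) // p] for i in range(1, p)]

random.seed(0)
p, oversample = 4, 8
# oversampling: p * oversample samples instead of just p - 1,
# which makes the chosen splitters track the key distribution more closely
samples = [random.randrange(1000) for _ in range(p * oversample)]
splitters = choose_splitters(samples, p)
assert len(splitters) == p - 1
assert splitters == sorted(splitters)
```

A larger oversampling factor tightens the bound on the largest bucket at the cost of a bigger sample-gathering step.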

When to Use

Use distributed sample sort when:

  • data is too large for one machine
  • many workers are available
  • sorted distributed partitions are acceptable
  • sampling can approximate the key distribution

Avoid it when communication cost dominates or the key distribution causes severe bucket skew.

Implementation Sketch

local_sample(records, sample_count):
    return evenly_spaced_sample(records, sample_count)

choose_splitters(samples, p):
    sort samples
    splitters = []

    for i from 1 to p - 1:
        splitters.append(samples[i * length(samples) / p])

    return splitters

worker_sort(local_records, splitters):
    buckets = array of p empty lists

    for record in local_records:
        b = upper_bound(splitters, key(record))
        buckets[b].append(record)

    for each worker b: send buckets[b] to worker b

    received = receive buckets from all workers
    sort received by key

    return received

Simplified Python Model

from bisect import bisect_right

def distributed_sample_sort(records_by_worker, splitters):
    p = len(records_by_worker)
    inbox = [list() for _ in range(p)]

    # route every record to the inbox of the worker that owns its bucket
    for local_records in records_by_worker:
        for record in local_records:
            key = record[0]
            b = bisect_right(splitters, key)
            inbox[b].append(record)

    # each worker sorts the records it received
    for b in range(p):
        inbox[b].sort(key=lambda x: x[0])

    return inbox
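
Putting the pieces together: a self-contained end-to-end run that samples keys, chooses splitters, and applies the same routing-and-sort model. The helper names and sampling parameters are illustrative.

```python
from bisect import bisect_right

def choose_splitters(samples, p):
    # pick p - 1 evenly spaced splitters from the pooled sorted samples
    s = sorted(samples)
    return [s[i * len(s) // p] for i in range(1, p)]

def sample_sort(records_by_worker, samples_per_worker=4):
    p = len(records_by_worker)

    # sampling step: each worker contributes evenly spaced local keys
    samples = []
    for local in records_by_worker:
        keys = sorted(r[0] for r in local)
        step = max(1, len(keys) // samples_per_worker)
        samples.extend(keys[::step])
    splitters = choose_splitters(samples, p)

    # routing and local sort, as in the model above
    inbox = [[] for _ in range(p)]
    for local in records_by_worker:
        for record in local:
            inbox[bisect_right(splitters, record[0])].append(record)
    for part in inbox:
        part.sort(key=lambda r: r[0])
    return inbox

workers = [[(k, None) for k in (8, 3, 21)],
           [(k, None) for k in (15, 1, 11)],
           [(k, None) for k in (9, 20, 5)]]
parts = sample_sort(workers)

# each partition is sorted, and the partitions are in global key order
flat = [r[0] for part in parts for r in part]
assert flat == sorted(flat)
```

In a real system each loop over workers runs on a different machine and the routing step becomes the all-to-all exchange; the single-process version keeps the data flow identical while staying easy to test.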