# Distributed Sample Sort

Distributed sample sort sorts a dataset spread across multiple machines. It uses sampled keys to choose global splitters. These splitters divide the key space into ordered buckets. Each machine sends records to the bucket owner, each bucket is sorted locally, and the sorted buckets form the final output.

The main purpose of sampling is load balancing: good splitters make every worker receive roughly the same amount of data.

## Problem

Given $n$ records distributed across $p$ workers, sort all records by key in nondecreasing order.

The output may remain distributed as $p$ sorted partitions.

## Algorithm

```text id="t93b7k"
distributed_sample_sort(workers):
    each worker samples local keys
    gather all samples
    sort samples
    choose p - 1 splitters

    broadcast splitters to all workers

    each worker partitions local records into p buckets

    all_to_all exchange buckets

    each worker sorts received bucket

    return workers in splitter order
```

The splitters satisfy:

$$
s_1 \le s_2 \le \cdots \le s_{p-1}
$$

Worker $0$ receives the smallest key range. Worker $p - 1$ receives the largest key range.

## Bucket Assignment

Each key is assigned using binary search over the splitters.

```text id="is34vq"
bucket_id(key, splitters):
    return upper_bound(splitters, key)
```

This gives a value from $0$ to $p - 1$.
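In Python, `bisect_right` plays the role of `upper_bound`; a minimal sketch:

```python
from bisect import bisect_right

def bucket_id(key, splitters):
    # Index of the first splitter greater than key, in 0 .. p - 1.
    return bisect_right(splitters, key)

splitters = [10, 20, 30]          # p = 4 workers
print(bucket_id(5, splitters))    # 0
print(bucket_id(10, splitters))   # 1  (keys equal to a splitter go right)
print(bucket_id(25, splitters))   # 2
print(bucket_id(99, splitters))   # 3
```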

## Communication

After local partitioning, workers perform an all-to-all exchange.

```text id="j41mfi"
for each source worker u:
    for each target worker v:
        send bucket[u][v] to worker v
```

Each target worker receives records from all workers for one key interval.
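As an illustration, the exchange can be modeled with nested lists, where `buckets[u][v]` holds the records worker `u` routes to worker `v` (a serial sketch of the parallel step):

```python
def all_to_all(buckets):
    # buckets[u][v]: records that source worker u sends to target worker v.
    p = len(buckets)
    # Target v receives the concatenation of buckets[u][v] over all sources u.
    return [[rec for u in range(p) for rec in buckets[u][v]] for v in range(p)]

# Two workers, each holding two outgoing buckets of keys.
outgoing = [
    [[1], [7]],  # worker 0
    [[2], [9]],  # worker 1
]
print(all_to_all(outgoing))  # [[1, 2], [7, 9]]
```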

## Complexity

| measure              | value                        |
| -------------------- | ---------------------------- |
| local sampling       | $O(n/p)$ per worker          |
| local partitioning   | $O((n/p)\log p)$             |
| communication volume | $O(n)$ records               |
| local sorting        | expected $O((n/p)\log(n/p))$ |
| output partitions    | $p$ sorted ranges            |

The wall clock cost is often dominated by network exchange and skew.

## Correctness

Splitters divide the global key space into ordered intervals. All records assigned to worker $i$ are less than or equal to all records assigned to worker $j$ when $i < j$. Each worker sorts its own received records. Therefore, the workers hold sorted partitions in global key order.

Reading worker outputs from $0$ to $p - 1$ gives the complete sorted dataset.
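This invariant is easy to check mechanically: concatenating the per-worker outputs in worker order must equal a global sort. A small sanity check, assuming a list-of-lists output:

```python
def is_globally_sorted(partitions):
    # Flatten in worker order; the result must already be sorted.
    merged = [k for part in partitions for k in part]
    return merged == sorted(merged)

print(is_globally_sorted([[1, 3], [3, 5], [8]]))  # True
print(is_globally_sorted([[1, 9], [3, 5]]))       # False
```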

## Practical Considerations

* Oversampling improves load balance.
* Skewed or duplicate heavy data may overload one worker.
* All-to-all exchange can stress the network.
* Compression may reduce transfer cost.
* Local sort can use radix sort, quicksort, merge sort, or external sort.
* Output is usually partitioned, not physically concatenated.
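The skew point above can be made concrete with a toy diagnostic (not part of the algorithm itself): count how records fall into buckets and compare the largest bucket to the ideal even share:

```python
from bisect import bisect_right

def max_imbalance(keys, splitters):
    # Ratio of the largest bucket to the ideal share len(keys) / p.
    p = len(splitters) + 1
    counts = [0] * p
    for k in keys:
        counts[bisect_right(splitters, k)] += 1
    return max(counts) / (len(keys) / p)

# Duplicate-heavy data: every copy of the hot key lands in one bucket.
keys = [7] * 90 + list(range(10))
print(max_imbalance(keys, [3, 7, 8]))  # 3.64, far above the ideal 1.0
```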

## When to Use

Use distributed sample sort when:

* data is too large for one machine
* many workers are available
* sorted distributed partitions are acceptable
* sampling can approximate the key distribution

Avoid it when communication cost dominates or the key distribution causes severe bucket skew.

## Implementation Sketch

```text id="e8rvd9"
local_sample(records, sample_count):
    return evenly_spaced_sample(records, sample_count)
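One possible Python reading of `evenly_spaced_sample`, picking records at evenly spaced positions (a sketch; random sampling works equally well):

```python
def evenly_spaced_sample(records, sample_count):
    # Pick sample_count records at evenly spaced positions.
    n = len(records)
    if n == 0 or sample_count <= 0:
        return []
    step = max(1, n // sample_count)
    return [records[i] for i in range(0, n, step)][:sample_count]

print(evenly_spaced_sample(list(range(100)), 4))  # [0, 25, 50, 75]
```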
```

```text id="6k4lpi"
choose_splitters(samples, p):
    sort samples
    splitters = []

    for i from 1 to p - 1:
        splitters.append(samples[floor(i * length(samples) / p)])

    return splitters
```
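The same selection as a runnable Python sketch, using floor division for the index:

```python
def choose_splitters(samples, p):
    samples = sorted(samples)
    n = len(samples)
    # p - 1 evenly spaced elements of the sorted sample pool.
    return [samples[i * n // p] for i in range(1, p)]

print(choose_splitters(list(range(12)), p=4))  # [3, 6, 9]
```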

```text id="fm9akq"
worker_sort(local_records, splitters):
    buckets = array of p empty lists

    for record in local_records:
        b = upper_bound(splitters, key(record))
        buckets[b].append(record)

    for b from 0 to p - 1:
        send buckets[b] to worker b

    received = receive buckets from all workers
    sort received by key

    return received
```

## Simplified Python Model

```python id="k2eymk"
from bisect import bisect_right

def distributed_sample_sort(records_by_worker, splitters):
    p = len(records_by_worker)
    inbox = [[] for _ in range(p)]

    # Route every record to the worker that owns its key interval.
    for local_records in records_by_worker:
        for record in local_records:
            key = record[0]
            b = bisect_right(splitters, key)
            inbox[b].append(record)

    # Each worker sorts what it received.
    for b in range(p):
        inbox[b].sort(key=lambda x: x[0])

    return inbox
```
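A concrete run of the same routing logic, inlined so the snippet stands alone, with two workers and a hand-picked splitter:

```python
from bisect import bisect_right

records_by_worker = [
    [(5, "a"), (1, "b"), (9, "c")],  # worker 0's local records
    [(3, "d"), (7, "e")],            # worker 1's local records
]
splitters = [5]  # one splitter for p = 2 workers
p = len(records_by_worker)

# Route every record to the owner of its key interval, then sort locally.
inbox = [[] for _ in range(p)]
for local_records in records_by_worker:
    for record in local_records:
        inbox[bisect_right(splitters, record[0])].append(record)
for bucket in inbox:
    bucket.sort(key=lambda r: r[0])

print(inbox[0])  # [(1, 'b'), (3, 'd')]
print(inbox[1])  # [(5, 'a'), (7, 'e'), (9, 'c')]
```

Note that with `bisect_right`, a key equal to a splitter is routed to the higher bucket, which is why `(5, "a")` lands on worker 1.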

