# Parallel Radix Sort

Parallel radix sort sorts fixed-width keys by processing one digit group at a time. A digit group may be a byte, a few bits, or another small radix unit. Each pass groups elements by the current digit while preserving the order the variant relies on; for the least significant digit variant, every pass must be stable.

For integer keys, parallel radix sort is often faster than comparison sorting because it places keys into buckets indexed by their digits instead of comparing keys with each other. Its main costs are memory bandwidth, histogram construction, prefix sums, and scattering elements into their output positions.

## Problem

Given an array $A$ of $n$ fixed-width integer keys, sort the keys in nondecreasing order using parallel workers.

## Algorithm

For least significant digit (LSD) radix sort, process digit groups from low to high. Each pass builds digit counts, computes output offsets, and stably scatters keys into their positions for the next pass.

```text id="h5k3tw"
parallel_radix_sort(A, bits_per_pass):
    word_bits = bit width of a key
    radix = 2^bits_per_pass
    mask = radix - 1
    B = new array of length(A)

    for shift from 0 to word_bits - 1 step bits_per_pass:
        parallel build local histograms for digit(A[i], shift, mask)
        reduce local histograms into global counts
        compute exclusive prefix sums over global counts
        compute per-worker offsets from local histograms

        parallel scatter each A[i] into B using its worker's digit offset

        swap A and B

    return A
```
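A single-worker version of the same pass structure may help make the pseudocode concrete. The sketch below is sequential, assumes 32-bit unsigned keys, and uses the hypothetical name `lsdRadixSort`; its `bitsPerPass` parameter plays the role of `bits_per_pass` and is assumed to lie between 1 and 16. When 32 is not a multiple of `bitsPerPass`, the last pass simply sees fewer live bits.

```go
package main

import "fmt"

// lsdRadixSort is a sequential sketch of the pass structure: one histogram,
// one exclusive prefix sum, and one stable scatter per digit group.
// bitsPerPass is assumed to be in [1, 16]; the input slice is not modified.
func lsdRadixSort(a []uint32, bitsPerPass uint) []uint32 {
	const wordBits = 32
	radix := 1 << bitsPerPass
	mask := uint32(radix - 1)

	src := append([]uint32(nil), a...)
	dst := make([]uint32, len(src))

	for shift := uint(0); shift < wordBits; shift += bitsPerPass {
		// Histogram of the current digit.
		count := make([]int, radix)
		for _, x := range src {
			count[(x>>shift)&mask]++
		}
		// Exclusive prefix sum: count[d] becomes the first slot for digit d.
		sum := 0
		for d, c := range count {
			count[d] = sum
			sum += c
		}
		// Stable scatter: scanning src in order keeps equal digits in order.
		for _, x := range src {
			d := (x >> shift) & mask
			dst[count[d]] = x
			count[d]++
		}
		src, dst = dst, src
	}
	return src
}

func main() {
	fmt.Println(lsdRadixSort([]uint32{170, 45, 75, 90, 802, 24, 2, 66}, 3))
	// prints [2 24 45 66 75 90 170 802]
}
```

With `bitsPerPass = 3` the loop runs $\lceil 32/3 \rceil = 11$ passes; the parallel version below fixes the digit width at 8 bits.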

The digit extracted at a pass is usually:

```text id="nfa9i3"
digit(x, shift, mask):
    return (x >> shift) & mask
```

For example, with $8$ bits per pass, the radix is:

$$
2^8 = 256
$$
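The following sketch, assuming 32-bit unsigned keys and $r = 8$, shows which digit each pass extracts. The function `digit` mirrors the pseudocode helper; the sample key `0x12345678` is arbitrary.

```go
package main

import "fmt"

// digit extracts the digit group that starts at bit position shift,
// mirroring the digit(x, shift, mask) helper above.
func digit(x uint32, shift uint, mask uint32) uint32 {
	return (x >> shift) & mask
}

func main() {
	const bitsPerPass = 8
	const radix = 1 << bitsPerPass // 256 buckets
	const mask = radix - 1         // 0xFF

	x := uint32(0x12345678)
	for shift := uint(0); shift < 32; shift += bitsPerPass {
		fmt.Printf("shift %2d: digit 0x%02X\n", shift, digit(x, shift, mask))
	}
}
```

The passes at shifts 0, 8, 16, and 24 see the digits `0x78`, `0x56`, `0x34`, and `0x12`, so the least significant byte is processed first.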

## Local Histograms

Each worker counts digits for its own slice.

```text id="cb1zdi"
build_local_histogram(A, lo, hi, shift, mask):
    count = array of zeros with length radix

    for i from lo to hi - 1:
        d = digit(A[i], shift, mask)
        count[d] += 1

    return count
```

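Because each worker writes only its own histogram, counting needs no atomic increments or locks, and workers do not contend on shared counters. After all workers finish, the local counts are reduced into the global counts.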
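A Go sketch of the counting and reduction phases, assuming 32-bit keys, 8-bit digits, and one contiguous slice per worker. The names `buildLocalHistogram` and `globalHistogram` are illustrative, and `sync.WaitGroup` stands in for whatever barrier an implementation uses. Each goroutine writes only `local[w]`, so no synchronization is needed until the reduction.

```go
package main

import (
	"fmt"
	"sync"
)

const (
	bits  = 8
	radix = 1 << bits
	mask  = radix - 1
)

// buildLocalHistogram counts the digits of a[lo:hi] at the given shift.
func buildLocalHistogram(a []uint32, lo, hi int, shift uint) [radix]int {
	var count [radix]int
	for i := lo; i < hi; i++ {
		count[(a[i]>>shift)&mask]++
	}
	return count
}

// globalHistogram runs one counting goroutine per slice, waits for all of
// them, and then sums the per-worker histograms into global counts.
func globalHistogram(a []uint32, workers int, shift uint) (local [][radix]int, global [radix]int) {
	local = make([][radix]int, workers)
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		lo, hi := w*len(a)/workers, (w+1)*len(a)/workers
		wg.Add(1)
		go func(w, lo, hi int) {
			defer wg.Done()
			local[w] = buildLocalHistogram(a, lo, hi, shift) // private slot
		}(w, lo, hi)
	}
	wg.Wait()

	// Reduction: runs only after every worker has finished counting.
	for w := range local {
		for d := range global {
			global[d] += local[w][d]
		}
	}
	return local, global
}

func main() {
	a := []uint32{3, 1, 3, 0, 2, 3, 1, 0}
	_, global := globalHistogram(a, 3, 0)
	fmt.Println(global[:4]) // prints [2 2 1 3]
}
```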

## Prefix Sums

Prefix sums convert digit counts into starting positions.

If $count[d]$ is the number of keys with digit $d$, then the exclusive prefix sum

$$
offset[d] = \sum_{j < d} count[j]
$$

gives the first output position for digit $d$.
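The per-worker offsets used by the algorithm extend the same idea. Within the output range of digit $d$, worker $w$ starts after the keys with digit $d$ that belong to workers $0, \dots, w-1$, where $count_v[d]$ denotes worker $v$'s local count:

$$
offset_w[d] = offset[d] + \sum_{v < w} count_v[d]
$$

The sketch below computes both arrays for a small hypothetical example with two workers and radix 4; `exclusivePrefixSum` and `workerOffsets` are illustrative names.

```go
package main

import "fmt"

// exclusivePrefixSum turns counts into starting positions:
// offset[d] = count[0] + ... + count[d-1].
func exclusivePrefixSum(count []int) []int {
	offset := make([]int, len(count))
	sum := 0
	for d, c := range count {
		offset[d] = sum
		sum += c
	}
	return offset
}

// workerOffsets splits each digit's output range among workers in worker
// order, so earlier slices of the input receive earlier output positions.
func workerOffsets(local [][]int, offset []int) [][]int {
	out := make([][]int, len(local))
	for w := range local {
		out[w] = make([]int, len(offset))
	}
	for d := range offset {
		pos := offset[d]
		for w := range local {
			out[w][d] = pos
			pos += local[w][d]
		}
	}
	return out
}

func main() {
	local := [][]int{{1, 0, 2, 1}, {1, 2, 0, 1}} // hypothetical local histograms
	global := []int{2, 2, 2, 2}                  // their element-wise sum
	offset := exclusivePrefixSum(global)
	fmt.Println(offset)                       // prints [0 2 4 6]
	fmt.Println(workerOffsets(local, offset)) // prints [[0 2 4 6] [1 2 6 7]]
}
```

Assigning ranges in worker order is what lets the parallel scatter remain stable.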

## Complexity

Let $w$ be the number of key bits and $r$ be the number of bits processed per pass.

| measure          | value                                  |
| ---------------- | -------------------------------------- |
| number of passes | $\lceil w / r \rceil$                  |
| work per pass    | $O(n + p2^r)$                          |
| total work       | $O(\lceil w/r \rceil \, (n + p2^r))$   |
| extra space      | $O(n + p2^r)$                          |

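Here $p$ is the number of workers. The $p2^r$ terms come from the local histograms and per-worker offsets, which are built, reduced, and scanned once per pass; with $p = 1$ the work per pass reduces to the sequential bound $O(n + 2^r)$.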
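For 32-bit keys ($w = 32$), a few choices of $r$ make the trade-off between passes and histogram size concrete:

$$
\begin{aligned}
r = 4 &: \quad \lceil 32/4 \rceil = 8 \text{ passes}, \quad p \cdot 2^{4} = 16p \text{ histogram counters} \\
r = 8 &: \quad \lceil 32/8 \rceil = 4 \text{ passes}, \quad p \cdot 2^{8} = 256p \text{ histogram counters} \\
r = 16 &: \quad \lceil 32/16 \rceil = 2 \text{ passes}, \quad p \cdot 2^{16} = 65536p \text{ histogram counters}
\end{aligned}
$$

Going from $r = 8$ to $r = 16$ halves the number of passes over the data but multiplies the histogram state by 256; a 65,536-entry histogram per worker is far larger than a typical L1 data cache, which is likely one reason byte radix is common.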

## Correctness

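Each pass stably groups keys by one digit. In the parallel scatter, stability follows from two facts: each worker scans its slice from left to right, and the per-worker offsets give earlier slices earlier positions within each digit's output range. By induction, after pass $k$ the array is sorted by its $k$ lowest digit groups: keys with different current digits are ordered by that digit, and keys with equal current digits keep the order from the previous pass, which already sorts them by the lower digits. After the final pass, keys are ordered lexicographically by their digits with the most significant digit as the primary key, which for unsigned keys is numeric order, so the array is sorted.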
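Stability is the property the induction relies on. The sketch below uses a hypothetical `record` type that tags each key with its input position, runs one stable counting pass per 4-bit digit on two-digit keys, and prints the result; the tags show that keys with equal digits keep the order from the previous pass.

```go
package main

import "fmt"

// record pairs a key with a tag holding its original input position.
type record struct {
	key uint32
	tag int
}

// stablePass is one counting pass on the digit (key >> shift) & mask.
// Scanning the input left to right and advancing each digit's cursor keeps
// records with equal digits in their input order.
func stablePass(in []record, shift uint, mask uint32) []record {
	count := make([]int, mask+1)
	for _, r := range in {
		count[(r.key>>shift)&mask]++
	}
	sum := 0
	for d, c := range count {
		count[d] = sum
		sum += c
	}
	out := make([]record, len(in))
	for _, r := range in {
		d := (r.key >> shift) & mask
		out[count[d]] = r
		count[d]++
	}
	return out
}

func main() {
	in := []record{{0x21, 0}, {0x12, 1}, {0x11, 2}, {0x22, 3}}
	p1 := stablePass(in, 0, 0xF) // by low digit: 0x21, 0x11, 0x12, 0x22
	p2 := stablePass(p1, 4, 0xF) // by high digit: 0x11, 0x12, 0x21, 0x22
	for _, r := range p2 {
		fmt.Printf("%#x (tag %d)\n", r.key, r.tag)
	}
}
```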

## Practical Considerations

* Use local histograms to avoid atomic increments.
* Use prefix sums to assign disjoint output ranges.
* Choose the radix size to balance fewer passes against larger histograms.
* Byte radix, with $r = 8$, is a common practical choice.
* Memory bandwidth often dominates runtime, because every pass reads and writes the whole array.
* Signed integers need a transform so negative keys order before nonnegative keys.
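For two's-complement keys, one common transform flips the sign bit, so negative values (sign bit 1) map below nonnegative values (sign bit 0) in unsigned order. A minimal sketch for `int32`, with the illustrative names `signedKey` and `fromSignedKey`:

```go
package main

import "fmt"

// signedKey maps an int32 to a uint32 whose unsigned order matches the signed
// order: flipping the sign bit moves negatives below nonnegatives.
func signedKey(x int32) uint32 {
	return uint32(x) ^ 0x80000000
}

// fromSignedKey inverts signedKey after sorting.
func fromSignedKey(k uint32) int32 {
	return int32(k ^ 0x80000000)
}

func main() {
	for _, x := range []int32{-2147483648, -1, 0, 1, 2147483647} {
		fmt.Printf("%11d -> 0x%08X\n", x, signedKey(x))
	}
}
```

The transformed `uint32` keys can be sorted with an unsigned radix sort and mapped back with `fromSignedKey` afterward.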

## When to Use

Use parallel radix sort when:

* keys are integers or fixed-width byte strings
* the input is large
* memory bandwidth is high
* stable sorting by key is useful
* comparison cost should be avoided

Avoid it when keys are complex comparison objects, memory is tight, or the input is too small to amortize parallel overhead.

## Implementation (Go, simplified)

```go id="u7g1nq"
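package radix

// ParallelRadixSort returns a sorted copy of a. It performs an LSD radix sort
// with 8-bit digits (four passes over 32-bit keys) and splits each pass across
// the given number of goroutines. The input slice is not modified.
func ParallelRadixSort(a []uint32, workers int) []uint32 {
	if len(a) <= 1 {
		return append([]uint32(nil), a...)
	}
	if workers < 1 {
		workers = 1
	}

	out := append([]uint32(nil), a...) // keys in the order produced by the last pass
	buf := make([]uint32, len(out))    // scatter target for the current pass

	const bits = 8
	const radix = 1 << bits
	const mask = radix - 1

	// Every worker sends one value on done when it finishes a phase. The main
	// goroutine waits for all of them before continuing, so out, buf, and shift
	// never change while a worker is reading them.
	done := make(chan struct{}, workers)

	for shift := 0; shift < 32; shift += bits {
		// Phase 1: each worker counts digits in its own contiguous slice.
		// Workers write disjoint histograms, so no atomics are needed.
		local := make([][radix]int, workers)

		for w := 0; w < workers; w++ {
			lo := w * len(out) / workers
			hi := (w + 1) * len(out) / workers

			go func(w, lo, hi int) {
				for i := lo; i < hi; i++ {
					d := (out[i] >> shift) & mask
					local[w][d]++
				}
				done <- struct{}{}
			}(w, lo, hi)
		}

		for w := 0; w < workers; w++ {
			<-done
		}

		// Phase 2: reduce the local histograms into global digit counts.
		var global [radix]int
		for d := 0; d < radix; d++ {
			for w := 0; w < workers; w++ {
				global[d] += local[w][d]
			}
		}

		// Phase 3: exclusive prefix sum gives the first output slot of each digit.
		var start [radix]int
		sum := 0
		for d := 0; d < radix; d++ {
			start[d] = sum
			sum += global[d]
		}

		// Phase 4: split each digit's output range among workers in worker
		// order. Earlier slices get earlier slots, which keeps the pass stable.
		offsets := make([][radix]int, workers)
		for d := 0; d < radix; d++ {
			pos := start[d]
			for w := 0; w < workers; w++ {
				offsets[w][d] = pos
				pos += local[w][d]
			}
		}

		// Phase 5: each worker scans its slice left to right and writes every
		// key to the next free slot for its digit. The slot ranges are
		// disjoint, so the writes need no synchronization.
		for w := 0; w < workers; w++ {
			lo := w * len(out) / workers
			hi := (w + 1) * len(out) / workers

			go func(w, lo, hi int) {
				pos := offsets[w] // arrays are values: this is a private copy
				for i := lo; i < hi; i++ {
					d := (out[i] >> shift) & mask
					buf[pos[d]] = out[i]
					pos[d]++
				}
				done <- struct{}{}
			}(w, lo, hi)
		}

		for w := 0; w < workers; w++ {
			<-done
		}

		// The scattered keys become the input of the next pass.
		out, buf = buf, out
	}

	return out
}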
```

