
Parallel Radix Sort

Sort integer keys by processing fixed-width digit groups in parallel using counting and prefix sums.

Parallel radix sort sorts fixed-width keys by processing one digit group at a time. A digit group may be a byte, a few bits, or another small radix unit. Each pass groups elements by the current digit. The least significant digit (LSD) variant requires every pass to be stable, so keys with equal digits keep their relative order.

For integer keys, parallel radix sort is often faster than comparison sorting because it uses key digits directly. Its main costs are memory bandwidth, histogram construction, prefix sums, and scattering elements into output positions.

Problem

Given an array $A$ of $n$ fixed-width integer keys, sort the keys in nondecreasing order using parallel workers.

Algorithm

For least significant digit radix sort, process digit groups from low to high. Each pass builds digit counts, computes output offsets, and scatters keys into their next positions.

parallel_radix_sort(A, bits_per_pass):
    B = new array with the same length as A
    mask = (1 << bits_per_pass) - 1

    for shift from 0 to word_bits - 1 step bits_per_pass:
        parallel build local histograms for digit(A[i], shift, mask)
        reduce local histograms into global counts
        compute exclusive prefix sums over global counts
        compute per-worker offsets from the prefix sums and local histograms

        parallel scatter each A[i] into B using its worker's digit offset

        swap A and B

    return A
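
The per-worker offset step is the part of the pass that is easiest to get wrong, so here is a minimal Go sketch of it on made-up counts. The function name `workerOffsets` and the sample histograms are illustrative assumptions, not part of the algorithm description above. It assumes every worker owns one contiguous slice of the input and that workers are numbered in slice order.

```go
package main

import "fmt"

// workerOffsets turns per-worker digit counts into per-worker write cursors.
// local[w][d] is the number of keys with digit d in worker w's slice.
// offsets[w][d] is the first output index worker w may use for digit d,
// so every (worker, digit) pair owns a disjoint output range.
func workerOffsets(local [][]int) [][]int {
	if len(local) == 0 {
		return nil
	}
	workers, radix := len(local), len(local[0])
	offsets := make([][]int, workers)
	for w := range offsets {
		offsets[w] = make([]int, radix)
	}
	pos := 0
	for d := 0; d < radix; d++ { // digits in increasing order
		for w := 0; w < workers; w++ { // workers in slice order keeps the pass stable
			offsets[w][d] = pos
			pos += local[w][d]
		}
	}
	return offsets
}

func main() {
	// Two workers, radix 4 (2 bits per pass), hypothetical counts.
	local := [][]int{
		{2, 0, 1, 1}, // worker 0
		{1, 2, 0, 1}, // worker 1
	}
	fmt.Println(workerOffsets(local)) // [[0 3 5 6] [2 3 6 7]]
}
```

Worker 1 starts writing digit 0 at index 2 because worker 0 holds two keys with digit 0 that must come first.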

The digit extracted at a pass is usually:

digit(x, shift, mask):
    return (x >> shift) & mask

For example, with $8$ bits per pass, the radix is:

$$2^8 = 256$$
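
As a quick illustration of the digit extraction, the following Go sketch pulls the four byte-sized digits out of one 32-bit key. The sample key `0x12345678` is an arbitrary choice for the example.

```go
package main

import "fmt"

// digit returns the digit group of x that starts at bit position shift.
// mask must be radix-1, where radix is a power of two.
func digit(x uint32, shift uint, mask uint32) uint32 {
	return (x >> shift) & mask
}

func main() {
	const bits = 8
	const mask = 1<<bits - 1 // 0xFF, that is radix 256 minus one
	x := uint32(0x12345678)
	for shift := uint(0); shift < 32; shift += bits {
		fmt.Printf("shift=%2d digit=0x%02X\n", shift, digit(x, shift, mask))
	}
	// Prints digits 0x78, 0x56, 0x34, 0x12 for shifts 0, 8, 16, 24.
}
```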

Local Histograms

Each worker counts digits for its own slice.

build_local_histogram(A, lo, hi, shift, mask):
    count = array of zeros with length mask + 1

    for i from lo to hi - 1:
        d = digit(A[i], shift, mask)
        count[d] += 1

    return count

Local histograms avoid contention. After all workers finish, the local counts are reduced into global counts.
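
One way to sketch the counting and reduction steps in Go is shown below. It uses `sync.WaitGroup` to wait for the workers, a tiny radix of 4, and the low digit `x % radix` so the output stays readable. The names `localHistograms` and `reduce` and the sample input are assumptions made for this example.

```go
package main

import (
	"fmt"
	"sync"
)

const radix = 4 // hypothetical small radix for a readable example

// localHistograms splits a into one contiguous slice per worker and counts
// the low digit (x % radix) of each slice in that worker's own histogram.
func localHistograms(a []uint32, workers int) [][radix]int {
	local := make([][radix]int, workers)
	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		lo := w * len(a) / workers
		hi := (w + 1) * len(a) / workers
		wg.Add(1)
		go func(w, lo, hi int) {
			defer wg.Done()
			for _, x := range a[lo:hi] {
				local[w][x%radix]++ // only worker w writes local[w]
			}
		}(w, lo, hi)
	}
	wg.Wait()
	return local
}

// reduce sums the local histograms into global digit counts.
func reduce(local [][radix]int) [radix]int {
	var global [radix]int
	for _, h := range local {
		for d, c := range h {
			global[d] += c
		}
	}
	return global
}

func main() {
	a := []uint32{3, 0, 2, 0, 1, 3, 1, 0}
	local := localHistograms(a, 2)
	fmt.Println(local, reduce(local)) // [[2 0 1 1] [1 2 0 1]] [3 2 1 2]
}
```

No atomic operations are needed because each goroutine increments only its own histogram, and the reduction runs after `wg.Wait()` returns.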

Prefix Sums

Prefix sums convert digit counts into starting positions.

If $count[d]$ is the number of keys with digit $d$, then

$$offset[d] = \sum_{j < d} count[j]$$

gives the first output position for digit $d$.
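
This is an exclusive prefix sum: each entry excludes its own count. A minimal sequential Go sketch, with a hypothetical helper name `exclusivePrefixSum` and made-up counts, looks like this:

```go
package main

import "fmt"

// exclusivePrefixSum returns offset, where offset[d] is the sum of
// count[j] over all j < d.
func exclusivePrefixSum(count []int) []int {
	offset := make([]int, len(count))
	sum := 0
	for d, c := range count {
		offset[d] = sum // store before adding count[d]
		sum += c
	}
	return offset
}

func main() {
	count := []int{3, 2, 1, 2}
	fmt.Println(exclusivePrefixSum(count)) // [0 3 5 6]
}
```

With counts 3, 2, 1, 2, digit 0 occupies positions 0 to 2, digit 1 positions 3 to 4, digit 2 position 5, and digit 3 positions 6 to 7.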

Complexity

Let $w$ be the number of key bits and $r$ be the number of bits processed per pass.

| measure | value |
| --- | --- |
| number of passes | $w / r$ |
| work per pass | $O(n + 2^r)$ |
| total work | $O((w/r)(n + 2^r))$ |
| extra space | $O(n + p2^r)$ |

Here $p$ is the number of workers. The term $p2^r$ comes from local histograms.
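
To make the trade-off concrete, the Go sketch below tabulates the pass count and the number of histogram counters for a few digit widths. The key width of 32 bits and the 8 workers are assumed values for illustration. The pass count is rounded up, since a digit width that does not divide $w$ still needs a final partial pass.

```go
package main

import "fmt"

// passes returns ceil(w / r), the number of digit passes for w-bit keys.
func passes(w, r int) int { return (w + r - 1) / r }

// histogramEntries returns p * 2^r, the number of counters held by
// p local histograms of radix 2^r.
func histogramEntries(p, r int) int { return p << r }

func main() {
	const w, p = 32, 8 // assumed: 32-bit keys, 8 workers
	for _, r := range []int{4, 8, 11, 16} {
		fmt.Printf("r=%2d passes=%d counters=%d\n", r, passes(w, r), histogramEntries(p, r))
	}
	// r= 4 passes=8 counters=128
	// r= 8 passes=4 counters=2048
	// r=11 passes=3 counters=16384
	// r=16 passes=2 counters=524288
}
```

Doubling $r$ from 8 to 16 halves the number of passes but multiplies the histogram storage by 256.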

Correctness

Each pass stably groups keys by one digit. In least significant digit radix sort, stability preserves the ordering established by earlier, lower-digit passes: after pass $k$, the keys are sorted by their $k$ lowest digit groups. After all digit groups have been processed, keys are ordered by every digit from most significant to least significant, so the array is sorted.
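
A small worked example may help here. The Go sketch below runs two sequential, stable passes with 4-bit digits over four made-up keys. The function `stablePass` is a hypothetical single-threaded stand-in for one pass of the parallel algorithm.

```go
package main

import "fmt"

// stablePass returns a copy of a reordered by the 4-bit digit at shift.
// Keys with equal digits keep their relative order.
func stablePass(a []uint32, shift uint) []uint32 {
	var count [16]int
	for _, x := range a {
		count[(x>>shift)&0xF]++
	}
	var pos [16]int
	sum := 0
	for d, c := range count {
		pos[d] = sum
		sum += c
	}
	out := make([]uint32, len(a))
	for _, x := range a { // left-to-right scan keeps equal digits in order
		d := (x >> shift) & 0xF
		out[pos[d]] = x
		pos[d]++
	}
	return out
}

func main() {
	a := []uint32{0x2B, 0x1A, 0x2A, 0x1B}
	low := stablePass(a, 0)    // grouped by low digit
	high := stablePass(low, 4) // grouped by high digit, low order kept
	fmt.Printf("%02X\n%02X\n", low, high)
	// [1A 2A 2B 1B]
	// [1A 1B 2A 2B]
}
```

In the second pass, `1A` stays ahead of `1B` only because the pass is stable. An unstable pass could swap them and leave the result unsorted.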

Practical Considerations

  • Use local histograms to avoid atomic increments.
  • Use prefix sums to assign disjoint output ranges.
  • Choose radix size to balance fewer passes against larger histograms.
  • Byte radix, with $r = 8$, is a common practical choice.
  • Memory bandwidth often dominates runtime.
  • Signed integers need a transform so negative keys order before nonnegative keys.
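
For the signed-integer point, one common transform for two's complement keys is to flip the sign bit before sorting and flip it back afterwards. The Go sketch below assumes `int32` keys. The names `toSortable` and `fromSortable` are illustrative.

```go
package main

import "fmt"

// toSortable maps an int32 to a uint32 whose unsigned order matches
// the signed order of the input.
func toSortable(x int32) uint32 { return uint32(x) ^ 0x80000000 }

// fromSortable inverts toSortable.
func fromSortable(u uint32) int32 { return int32(u ^ 0x80000000) }

func main() {
	for _, x := range []int32{-2147483648, -1, 0, 1, 2147483647} {
		fmt.Printf("%11d -> 0x%08X\n", x, toSortable(x))
	}
	// -2147483648 -> 0x00000000
	//          -1 -> 0x7FFFFFFF
	//           0 -> 0x80000000
	//           1 -> 0x80000001
	//  2147483647 -> 0xFFFFFFFF
}
```

After the transform, an unsigned radix sort places negative keys before nonnegative ones, and `fromSortable` restores the original values.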

When to Use

Use parallel radix sort when:

  • keys are integers or fixed width byte strings
  • the input is large
  • the machine has enough memory bandwidth to keep all workers busy
  • stable sorting by key is useful
  • comparison cost should be avoided

Avoid it when keys are complex comparison objects, memory is tight, or the input is too small to amortize parallel overhead.

Implementation (Go, simplified)

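// ParallelRadixSort returns a sorted copy of a. It runs an LSD radix sort
// with 8-bit digits and splits each pass across the given number of workers.
// The input slice is not modified.
func ParallelRadixSort(a []uint32, workers int) []uint32 {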
	if len(a) <= 1 {
		return append([]uint32(nil), a...)
	}
	if workers < 1 {
		workers = 1
	}

	// out holds the current ordering; buf receives the result of the next pass.
	out := append([]uint32(nil), a...)
	buf := make([]uint32, len(out))

	const bits = 8
	const radix = 1 << bits
	const mask = radix - 1

	for shift := 0; shift < 32; shift += bits {
		// Phase 1: each worker counts the digits in its own contiguous slice.
		local := make([][radix]int, workers)

		done := make(chan struct{}, workers)

		for w := 0; w < workers; w++ {
			lo := w * len(out) / workers
			hi := (w + 1) * len(out) / workers

			go func(w, lo, hi int) {
				for i := lo; i < hi; i++ {
					d := (out[i] >> shift) & mask
					local[w][d]++
				}
				done <- struct{}{}
			}(w, lo, hi)
		}

		for w := 0; w < workers; w++ {
			<-done
		}

		// Phase 2: reduce the local histograms into global digit counts.
		var global [radix]int
		for d := 0; d < radix; d++ {
			for w := 0; w < workers; w++ {
				global[d] += local[w][d]
			}
		}

		// Phase 3: an exclusive prefix sum gives the first output index per digit.
		var start [radix]int
		sum := 0
		for d := 0; d < radix; d++ {
			start[d] = sum
			sum += global[d]
		}

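		// Phase 4: inside each digit's range, workers get sub-ranges in slice
		// order, which keeps the pass stable.
		offsets := make([][radix]int, workers)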
		for d := 0; d < radix; d++ {
			pos := start[d]
			for w := 0; w < workers; w++ {
				offsets[w][d] = pos
				pos += local[w][d]
			}
		}

		// Phase 5: each worker scatters its slice into its own output ranges.
		for w := 0; w < workers; w++ {
			lo := w * len(out) / workers
			hi := (w + 1) * len(out) / workers

			go func(w, lo, hi int) {
				pos := offsets[w] // array copy: write cursors private to this worker
				for i := lo; i < hi; i++ {
					d := (out[i] >> shift) & mask
					buf[pos[d]] = out[i]
					pos[d]++
				}
				done <- struct{}{}
			}(w, lo, hi)
		}

		for w := 0; w < workers; w++ {
			<-done
		}

		// The output of this pass becomes the input of the next one.
		out, buf = buf, out
	}

	return out
}