# GPU Radix Sort

GPU radix sort sorts fixed-width integer keys by processing a small number of bits at a time. Each pass stably groups keys by the current digit. On the GPU, the main work in each pass is building histograms, computing prefix sums, and scattering keys into their next positions.

For integer keys, GPU radix sort is often one of the fastest general sorting methods because it avoids comparisons and exposes large amounts of parallel work.

## Problem

Given an array $A$ of $n$ fixed-width integer keys stored on a GPU, sort the keys in nondecreasing order.

## Algorithm

For least significant digit radix sort, process the low bits first and move toward the high bits.

```text id="w19x6z"
gpu_radix_sort(A, bits_per_pass):
    B = new array with the same length as A
    mask = (1 << bits_per_pass) - 1

    for shift from 0 to word_bits - 1 step bits_per_pass:
        parallel build block histograms for digit(A[i], shift, mask)
        reduce block histograms into global counts
        compute prefix sums over global counts
        compute per-block digit offsets
        parallel scatter A into B by digit position
        swap A and B

    return A
```

The digit is usually extracted with bit operations.

```text id="m3dtqm"
digit(x, shift, mask):
    return (x >> shift) & mask
```

For example, using $4$ bits per pass gives radix $16$, while using $8$ bits per pass gives radix $256$.
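As a quick sanity check, the host-side sketch below (illustrative only; `digit` mirrors the pseudocode above) builds the mask for a $4$ bit pass and extracts each digit of `0xABCD`:

```cuda
#include <cstdio>

// Same digit extraction as the pseudocode above, run on the host.
static unsigned digit(unsigned x, int shift, unsigned mask) {
    return (x >> shift) & mask;
}

int main() {
    int bits_per_pass = 4;
    unsigned mask = (1u << bits_per_pass) - 1;  // 0xF, radix 16

    // Digits of 0xABCD from least to most significant: D, C, B, A.
    for (int shift = 0; shift < 16; shift += bits_per_pass) {
        printf("digit at shift %2d: 0x%X\n", shift, digit(0xABCDu, shift, mask));
    }
    return 0;
}
```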

## GPU Pass Structure

A practical GPU radix pass usually has three kernels.

| stage     | purpose                              |
| --------- | ------------------------------------ |
| histogram | count digit frequencies per block    |
| scan      | compute global and per-block offsets |
| scatter   | move keys into output positions      |

Small radices keep histograms small but require more passes. Large radices reduce the number of passes but increase shared memory use and scan cost.
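
To make the shared memory side of this tradeoff concrete, the small host-side sketch below (an illustration, not taken from any particular library) prints the per-block histogram footprint for several radix choices:

```cuda
#include <cstdio>

int main() {
    // One counter per digit value: 2^r counters of 4 bytes each per block.
    for (int r = 2; r <= 16; r += 2) {
        unsigned counters = 1u << r;
        printf("r = %2d  radix = %6u  histogram = %7u bytes per block\n",
               r, counters, counters * (unsigned)sizeof(unsigned));
    }
    return 0;
}
```

At $r = 8$ the histogram costs 1 KiB of shared memory per block, which is comfortable; at $r = 16$ it costs 256 KiB, which exceeds the per-block shared memory of current GPUs.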

## Prefix Sums

For each digit $d$, the prefix sum gives the beginning of that digit's output range.

$$
offset[d] = \sum_{j < d} count[j]
$$

Each block receives a private subrange inside the digit range, which avoids write conflicts during scatter.
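
A minimal host-side sketch of this computation (the helper name `compute_block_offsets` is illustrative; real implementations run this as a GPU scan, and the `block_counts` layout matches the kernels in the implementation sketch below):

```cuda
#include <vector>

// Host-side sketch: turn per-block digit counts into per-block scatter
// offsets. Layout matches the kernels below: block_counts[b * radix + d].
std::vector<unsigned> compute_block_offsets(
    const std::vector<unsigned> &block_counts, int num_blocks, int radix)
{
    // Global count per digit, reduced over all blocks.
    std::vector<unsigned> count(radix, 0);
    for (int b = 0; b < num_blocks; ++b)
        for (int d = 0; d < radix; ++d)
            count[d] += block_counts[b * radix + d];

    // Exclusive prefix sum: offset[d] = sum of count[j] for j < d.
    std::vector<unsigned> offset(radix, 0);
    for (int d = 1; d < radix; ++d)
        offset[d] = offset[d - 1] + count[d - 1];

    // Carve each digit's range into private per-block subranges.
    std::vector<unsigned> block_offsets(num_blocks * radix);
    std::vector<unsigned> running = offset;
    for (int b = 0; b < num_blocks; ++b)
        for (int d = 0; d < radix; ++d) {
            block_offsets[b * radix + d] = running[d];
            running[d] += block_counts[b * radix + d];
        }
    return block_offsets;
}
```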

## Complexity

Let $w$ be the number of key bits and $r$ be the number of bits processed per pass.

| measure       | value               |
| ------------- | ------------------- |
| passes        | $w / r$             |
| work per pass | $O(n + 2^r)$        |
| total work    | $O((w/r)(n + 2^r))$ |
| extra memory  | $O(n + b2^r)$       |

Here $b$ is the number of GPU blocks that maintain local histograms. For example, with $w = 32$ and $r = 8$ there are $4$ passes, each over $256$ possible digit values.

## Correctness

Each pass stably groups elements by one digit. In least significant digit radix sort, stability preserves the order produced by earlier, lower-digit passes. After all digit positions have been processed, the keys are ordered by their full binary representation, so the array is sorted.
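
A short trace with $2$-bit keys and one bit per pass illustrates the argument:

```text
keys (binary):      10  01  11  00

pass 1 (low bit):   10  00  01  11    keys with low bit 0 first, input order kept
pass 2 (high bit):  00  01  10  11    stability keeps 00 before 01 and 10 before 11
```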

## Practical Considerations

* Use shared memory for block histograms when possible.
* Avoid global atomic contention by using per-block counts.
* Choose radix size based on shared memory and occupancy.
* Scatter is memory bandwidth heavy and often dominates runtime.
* Signed integers need a key transformation so negative values sort before nonnegative values; see the sketch after this list.
* Key value pairs require moving the associated payload with each key.
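
For the signed-key transformation mentioned above, a standard trick is to flip the sign bit before sorting and flip it back afterward; a minimal sketch (the function names are illustrative):

```cuda
// Flipping the sign bit maps signed order onto unsigned order:
// INT_MIN -> 0x00000000, -1 -> 0x7FFFFFFF, 0 -> 0x80000000, INT_MAX -> 0xFFFFFFFF.
__host__ __device__ unsigned encode_signed(int x) {
    return (unsigned)x ^ 0x80000000u;
}

// Inverse transform, applied after the final pass.
__host__ __device__ int decode_signed(unsigned x) {
    return (int)(x ^ 0x80000000u);
}
```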

## When to Use

Use GPU radix sort when:

* keys are fixed width integers
* input is large
* data already resides on the GPU
* high throughput matters
* comparison based order is unnecessary

Avoid it when keys are complex objects, custom comparison logic is required, or transfer time between CPU and GPU dominates sorting time.

## Implementation Sketch

```cuda id="4etb2e"
__device__ unsigned digit32(unsigned x, int shift, unsigned mask) {
    return (x >> shift) & mask;
}

__global__ void histogram_kernel(
    const unsigned *in,
    int n,
    int shift,
    unsigned mask,
    unsigned *block_counts
) {
    // One counter per digit value, shared by the threads of this block.
    extern __shared__ unsigned local[];

    int tid = threadIdx.x;
    int block = blockIdx.x;
    int radix = mask + 1;

    // Zero the shared histogram cooperatively.
    for (int d = tid; d < radix; d += blockDim.x) {
        local[d] = 0;
    }
    __syncthreads();

    // Each thread counts at most one key; shared memory atomics keep
    // contention local to the block instead of hitting global memory.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        unsigned d = digit32(in[i], shift, mask);
        atomicAdd(&local[d], 1u);
    }
    __syncthreads();

    // Publish the block's counts, block-major: block_counts[block * radix + d].
    for (int d = tid; d < radix; d += blockDim.x) {
        block_counts[block * radix + d] = local[d];
    }
}
```

```cuda id="fwh3kc"
__global__ void scatter_kernel(
    const unsigned *in,
    unsigned *out,
    int n,
    int shift,
    unsigned mask,
    const unsigned *block_offsets
) {
    // Stage this block's keys in shared memory so each thread can compute
    // a stable rank within its digit group.
    extern __shared__ unsigned tile[];

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    tile[threadIdx.x] = (i < n) ? in[i] : 0;
    __syncthreads();

    if (i >= n) return;

    unsigned d = digit32(tile[threadIdx.x], shift, mask);

    // Stable rank: the number of earlier keys in this block with the same
    // digit. This linear loop is simple but slow; a production
    // implementation computes the rank with a per-block scan instead.
    unsigned rank = 0;
    for (int j = 0; j < threadIdx.x; ++j) {
        if (digit32(tile[j], shift, mask) == d) {
            ++rank;
        }
    }

    // Block offset for this digit plus the stable rank gives a unique,
    // order-preserving output position.
    out[block_offsets[blockIdx.x * (mask + 1) + d] + rank] = tile[threadIdx.x];
}
```
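
To show how the pieces fit together, here is a host-side driver sketch for a full $32$-bit sort. It reuses `compute_block_offsets` from the prefix sums section and keeps the scan on the host for clarity; a production implementation runs the scan on the device and avoids the round trip. All names here are illustrative.

```cuda
#include <cuda_runtime.h>
#include <utility>
#include <vector>

// Host driver sketch tying the kernels above to the host-side offset
// computation sketched in the prefix sums section. d_a and d_b are
// device buffers of n keys each.
void gpu_radix_sort(unsigned *d_a, unsigned *d_b, int n, int bits_per_pass) {
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    unsigned mask = (1u << bits_per_pass) - 1;
    int radix = (int)mask + 1;

    unsigned *d_counts;
    cudaMalloc(&d_counts, blocks * radix * sizeof(unsigned));
    std::vector<unsigned> h_counts(blocks * radix);

    for (int shift = 0; shift < 32; shift += bits_per_pass) {
        // 1. Per-block digit histograms.
        histogram_kernel<<<blocks, threads, radix * sizeof(unsigned)>>>(
            d_a, n, shift, mask, d_counts);

        // 2. Scan on the host for clarity; real code uses a device scan.
        cudaMemcpy(h_counts.data(), d_counts,
                   h_counts.size() * sizeof(unsigned), cudaMemcpyDeviceToHost);
        std::vector<unsigned> h_offsets =
            compute_block_offsets(h_counts, blocks, radix);
        cudaMemcpy(d_counts, h_offsets.data(),
                   h_offsets.size() * sizeof(unsigned), cudaMemcpyHostToDevice);

        // 3. Stable scatter into the other buffer, then swap roles.
        scatter_kernel<<<blocks, threads, threads * sizeof(unsigned)>>>(
            d_a, d_b, n, shift, mask, d_counts);
        std::swap(d_a, d_b);
    }
    // For bits_per_pass in {1, 2, 4, 8, 16}, the pass count 32 / bits_per_pass
    // is even, so the sorted keys end up back in the caller's d_a buffer.
    cudaFree(d_counts);
}
```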

