
In Place Radix Sort

Sort digit-based keys by redistributing elements inside the original array, keeping only small auxiliary bucket state.

In place radix sort sorts keys by digits while avoiding a full auxiliary output array. It keeps only small bucket metadata, such as counts, offsets, and next write positions.

The common form is MSD-based (most significant digit first). At each digit position, the algorithm partitions the current subarray into one bucket per digit value, then recursively sorts each bucket by the next digit.

Problem

Given an array $A$ of keys whose digits can be extracted in base $b$, sort $A$ in place.

Each key has digit positions:

$$d_0, d_1, \ldots, d_{m-1}$$

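Here $d_0$ is the most significant digit. For lexicographic or numeric MSD order, process $d_0$ first. For numeric order, every key must be padded with leading zeros to the same $m$ digits, so that 7 is treated as 007 when $m = 3$.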
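To make the digit positions concrete, here is a small sketch, assuming keys are non-negative integers that fit in exactly $m$ base-$b$ digits (leading zeros allowed). The helper name `msd_digits` is hypothetical. It lists $d_0, \ldots, d_{m-1}$ most significant first and checks that comparing these tuples lexicographically gives numeric order.

```python
def msd_digits(x: int, m: int, base: int = 10) -> tuple[int, ...]:
    """Return the digits d_0, ..., d_{m-1} of x, most significant first.

    Assumption: 0 <= x < base**m, so x fits in m digits with leading zeros.
    """
    if not 0 <= x < base ** m:
        raise ValueError("x must fit in m base-`base` digits")
    return tuple((x // base ** (m - 1 - d)) % base for d in range(m))


keys = [329, 457, 657, 839, 436, 720, 355, 7]
# Padding 7 to 007 is what makes tuple order agree with numeric order.
by_digits = sorted(keys, key=lambda k: msd_digits(k, 3))
print(msd_digits(329, 3))         # (3, 2, 9)
print(msd_digits(7, 3))           # (0, 0, 7)
print(by_digits == sorted(keys))  # True
```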

Idea

For each subarray:

  1. Count digit frequencies
  2. Convert counts into bucket ranges
  3. Move each element into its bucket by swapping
  4. Recursively sort each bucket on the next digit
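Steps 1 and 2 are plain counting followed by prefix sums. A minimal sketch, assuming the digit of every element in the subarray has already been extracted into a list; `bucket_ranges` is a hypothetical helper name. Empty buckets receive an empty interval $[s, s)$.

```python
def bucket_ranges(digits: list[int], lo: int, base: int) -> list[tuple[int, int]]:
    """Steps 1 and 2: count digit frequencies, then turn counts into [start, end) ranges."""
    count = [0] * base
    for r in digits:          # step 1: frequency of each digit value
        count[r] += 1
    ranges = []
    start = lo
    for r in range(base):     # step 2: running prefix sum gives each start offset
        ranges.append((start, start + count[r]))
        start += count[r]
    return ranges


# Hundreds digits of [329, 457, 657, 839, 436, 720, 355]
ranges = bucket_ranges([3, 4, 6, 8, 4, 7, 3], lo=0, base=10)
print(ranges[3], ranges[4], ranges[8])  # (0, 2) (2, 4) (6, 7)
```

Because each interval starts where the previous one ends, the buckets tile $[lo, hi)$ without gaps or overlap.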

The array itself stores the reordered records. The algorithm only keeps bucket state.

Algorithm

in_place_radix_sort(A, lo, hi, d, base):
    if hi - lo <= 1 or d >= m:
        return

    count = [0] * base

    for i from lo to hi - 1:
        count[digit(A[i], d)] += 1

    start = [0] * base
    start[0] = lo

    for r from 1 to base - 1:
        start[r] = start[r - 1] + count[r - 1]

    next = copy(start)
    end = [start[r] + count[r] for r in 0..base-1]

    i = lo
    while i < hi:
        r = digit(A[i], d)

        if start[r] <= i < end[r]:
            i += 1
        else:
            swap A[i] with A[next[r]]
            next[r] += 1

    for r from 0 to base - 1:
        if end[r] - start[r] > 1:
            in_place_radix_sort(A, start[r], end[r], d + 1, base)

Example

Sort:

[329, 457, 657, 839, 436, 720, 355]

Using MSD decimal digits, the first pass groups by hundreds digit:

| bucket | values |
|---|---|
| 3 | 329, 355 |
| 4 | 457, 436 |
| 6 | 657 |
| 7 | 720 |
| 8 | 839 |

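Each bucket with more than one element is then partitioned by the tens digit. In this example that already finishes the job: bucket 3 splits into 329 and 355, and bucket 4 splits into 436 and 457, so no bucket needs a ones-digit pass.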

Final result:

[329, 355, 436, 457, 657, 720, 839]
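The trace above can be reproduced with a sketch that follows the pseudocode and also records each partitioning call as `(lo, hi, d)`, with `d = 0` meaning the hundreds digit. The name `traced_radix_sort` and its `calls`/`snapshots` outputs are hypothetical, added only for illustration; keys are assumed to be non-negative integers with at most `m` digits.

```python
def traced_radix_sort(a: list[int], m: int, base: int = 10):
    """MSD in-place radix sort that also records every partitioning call.

    Returns (calls, snapshots): calls holds (lo, hi, d) for each partition,
    snapshots holds a copy of the whole array right after that partition.
    """
    calls, snapshots = [], []

    def digit(x: int, d: int) -> int:
        return (x // base ** (m - 1 - d)) % base

    def sort(lo: int, hi: int, d: int) -> None:
        if hi - lo <= 1 or d >= m:
            return
        calls.append((lo, hi, d))
        count = [0] * base
        for i in range(lo, hi):
            count[digit(a[i], d)] += 1
        start = [lo] * base
        for r in range(1, base):
            start[r] = start[r - 1] + count[r - 1]
        end = [start[r] + count[r] for r in range(base)]
        nxt = start.copy()
        i = lo
        while i < hi:
            r = digit(a[i], d)
            if start[r] <= i < end[r]:
                i += 1                    # a[i] already sits in its bucket
            else:
                j = nxt[r]
                a[i], a[j] = a[j], a[i]   # send a[i] to its bucket
                nxt[r] += 1
        snapshots.append(a.copy())
        for r in range(base):
            if end[r] - start[r] > 1:
                sort(start[r], end[r], d + 1)

    sort(0, len(a), 0)
    return calls, snapshots


a = [329, 457, 657, 839, 436, 720, 355]
calls, snapshots = traced_radix_sort(a, m=3)
print(snapshots[0])  # [329, 355, 457, 436, 657, 720, 839]
print(calls)         # [(0, 7, 0), (0, 2, 1), (2, 4, 1)]
print(a)             # [329, 355, 436, 457, 657, 720, 839]
```

The first snapshot matches the bucket table, and the call list shows only two tens-digit partitions and no ones-digit partition.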

Correctness

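The count phase computes the exact size of every digit bucket. The prefix phase maps each bucket to a disjoint interval in the array. The swapping phase moves every element into the interval that matches its current digit. The scan index i advances only past a position that already holds an element of that interval's digit, and every swap writes into a bucket that lies entirely ahead of i: an element whose bucket lies behind i cannot exist, because that bucket has already been filled with exactly its count of elements. Positions behind i therefore never change again, and each swap places one element permanently, so a pass makes at most hi − lo swaps.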

After partitioning, all elements in lower digit buckets precede all elements in higher digit buckets. Recursion sorts elements inside each bucket by the remaining digits. Therefore the final array is sorted by the whole key.
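The argument can be checked mechanically. This sketch runs one partition pass from the pseudocode with the invariants written as assertions: every swap target lies strictly ahead of i and inside bucket r, the pass makes at most hi − lo swaps, and each interval ends up holding only its own digit. The function name `partition_checked` and the random test setup are assumptions for illustration.

```python
import random


def partition_checked(a: list[int], lo: int, hi: int, d: int, m: int, base: int) -> None:
    """One partition pass on digit d, asserting the invariants from the proof."""
    def digit(x: int) -> int:
        return (x // base ** (m - 1 - d)) % base

    count = [0] * base
    for k in range(lo, hi):
        count[digit(a[k])] += 1
    start = [lo] * base
    for r in range(1, base):
        start[r] = start[r - 1] + count[r - 1]
    end = [start[r] + count[r] for r in range(base)]
    nxt = start.copy()

    i, swaps = lo, 0
    while i < hi:
        r = digit(a[i])
        if start[r] <= i < end[r]:
            i += 1
        else:
            j = nxt[r]
            assert i < j < end[r]   # swaps only write ahead of i, inside bucket r
            a[i], a[j] = a[j], a[i]
            nxt[r] += 1
            swaps += 1
    assert swaps <= hi - lo         # each swap places one element for good
    for r in range(base):           # every interval holds only its own digit
        assert all(digit(a[k]) == r for k in range(start[r], end[r]))


rng = random.Random(0)
for _ in range(1000):
    m, base = 3, rng.choice([2, 3, 10])
    a = [rng.randrange(base ** m) for _ in range(rng.randrange(30))]
    before = sorted(a)
    partition_checked(a, 0, len(a), 0, m, base)
    assert sorted(a) == before      # the pass only permutes the input
print("all checks passed")
```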

Complexity

Let:

  • $n$ be the number of keys
  • $m$ be the number of digit positions
  • $b$ be the radix base

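At each digit level, every key takes part in at most one call. That call scans its subarray a constant number of times (once to count, once to partition, with at most one swap per key) and allocates $O(b)$ bucket metadata.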

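Worst case time, where $C$ is the number of calls that actually partition a subarray:

$$O(nm + bC)$$

Calls at the same depth cover disjoint subarrays of at least two keys, so $C \le m \lfloor n/2 \rfloor$. For a fixed base, such as $b = 256$ for byte keys, this is $O(nm)$.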

Auxiliary space per recursion level:

$$O(b)$$

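Recursion depth is at most $m$, and each active call keeps its bucket boundaries alive while it recurses, so total extra space is $O(bm)$ plus the call stack. With careful reuse of scratch arrays (for example, one shared count array), the practical extra space stays small.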

Properties

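| property | value |
|---|---|
| stable | no |
| in place | yes, apart from $O(b)$ bucket metadata per recursion level |
| comparison based | no |
| adaptive | partly: recursion stops early on buckets with at most one key |
| typical direction | MSD |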
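Stability fails because a swap can carry an element past others with the same key. A minimal sketch, assuming single-digit keys paired with a tag that records the input order; `partition_records` is a hypothetical helper that runs one partition pass on the key.

```python
def partition_records(recs: list[tuple[int, str]], base: int = 10) -> None:
    """One in-place partition pass keyed on the single digit recs[k][0]."""
    count = [0] * base
    for key, _ in recs:
        count[key] += 1
    start = [0] * base
    for r in range(1, base):
        start[r] = start[r - 1] + count[r - 1]
    end = [start[r] + count[r] for r in range(base)]
    nxt = start.copy()
    i = 0
    while i < len(recs):
        r = recs[i][0]
        if start[r] <= i < end[r]:
            i += 1
        else:
            j = nxt[r]
            recs[i], recs[j] = recs[j], recs[i]
            nxt[r] += 1


recs = [(1, "a"), (0, "b"), (1, "c"), (0, "d")]
partition_records(recs)
print(recs)  # [(0, 'd'), (0, 'b'), (1, 'a'), (1, 'c')]
```

The keys are grouped correctly, but "d" now precedes "b" even though both have key 0 and "b" came first in the input.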

When to Use

Use in place radix sort when keys have a compact digit representation and memory overhead matters. It is useful for large integer arrays, fixed width byte keys, and records where allocating another array of size nn is expensive.
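For fixed-width byte keys, each byte can serve as one digit in base 256, so digit extraction is plain indexing. A sketch under that assumption; `radix_sort_bytes` is a hypothetical name, and every key must have the same length `width`.

```python
def radix_sort_bytes(a: list[bytes], width: int) -> None:
    """MSD in-place radix sort for byte strings that all have length `width`.

    Digit d of a key is its byte at index d (an int in 0..255 in Python 3).
    """
    def sort(lo: int, hi: int, d: int) -> None:
        if hi - lo <= 1 or d >= width:
            return
        count = [0] * 256
        for k in range(lo, hi):
            count[a[k][d]] += 1
        start = [lo] * 256
        for r in range(1, 256):
            start[r] = start[r - 1] + count[r - 1]
        end = [start[r] + count[r] for r in range(256)]
        nxt = start.copy()
        i = lo
        while i < hi:
            r = a[i][d]
            if start[r] <= i < end[r]:
                i += 1
            else:
                j = nxt[r]
                a[i], a[j] = a[j], a[i]
                nxt[r] += 1
        for r in range(256):
            if end[r] - start[r] > 1:
                sort(start[r], end[r], d + 1)

    if any(len(x) != width for x in a):
        raise ValueError("all keys must have length `width`")
    sort(0, len(a), 0)


keys = [b"dog", b"cat", b"cab", b"ant", b"cow"]
radix_sort_bytes(keys, 3)
print(keys)  # [b'ant', b'cab', b'cat', b'cow', b'dog']
```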

Avoid it when stability is required, when digit extraction is expensive, or when a simpler stable radix implementation is easier to maintain.

Implementation

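def in_place_radix_sort(a, digits, base=10):
    """Sort a list of non-negative integers in place, most significant digit first.

    Assumes every key has at most `digits` base-`base` digits.
    """
    def digit(x, d):
        # d = 0 selects the most significant digit
        shift = digits - 1 - d
        return (x // (base ** shift)) % base

    def sort(lo, hi, d):
        # Nothing to do for 0 or 1 keys, or once every digit has been used
        if hi - lo <= 1 or d >= digits:
            return

        # Step 1: count how many keys have each digit value
        count = [0] * base

        for i in range(lo, hi):
            count[digit(a[i], d)] += 1

        # Step 2: prefix sums give each bucket its interval [start[r], end[r])
        start = [0] * base
        start[0] = lo

        for r in range(1, base):
            start[r] = start[r - 1] + count[r - 1]

        end = [start[r] + count[r] for r in range(base)]
        next_pos = start.copy()  # next write position in each bucket

        # Step 3: swap each key into its bucket
        i = lo
        while i < hi:
            r = digit(a[i], d)

            if start[r] <= i < end[r]:
                i += 1  # a[i] already sits in its own bucket
            else:
                j = next_pos[r]
                a[i], a[j] = a[j], a[i]
                next_pos[r] += 1  # i stays put, so the key swapped into a[i] is examined next

        # Step 4: recurse into buckets that hold more than one key
        for r in range(base):
            if end[r] - start[r] > 1:
                sort(start[r], end[r], d + 1)

    sort(0, len(a), 0)
    return a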