Multiway merge sort generalizes merge sort by splitting the input into $k$ parts instead of two and merging $k$ sorted sequences at each step.
It is especially useful when merging cost dominates, such as in external memory or database systems.
Problem
Given an array $A$ of length $n$, reorder it such that:

$$
A[0] \le A[1] \le \cdots \le A[n-1]
$$
Idea
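Instead of binary splitting:

$$
n \to \frac{n}{2} + \frac{n}{2}
$$

multiway merge sort splits into $k$ parts:

$$
n \to \underbrace{\frac{n}{k} + \frac{n}{k} + \cdots + \frac{n}{k}}_{k \text{ parts}}
$$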
Each part is sorted recursively, then merged using a $k$-way merge.
The merge step uses a priority queue or tournament tree to efficiently select the smallest element among the $k$ sequences.
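As a minimal illustration of heap-based selection, Python's standard library provides a $k$-way merge as `heapq.merge`, which lazily yields the smallest remaining element across its sorted inputs. The input lists below are chosen only for illustration.

```python
import heapq

# Three already-sorted sequences (illustrative values).
parts = [[2, 7, 9], [1, 4, 6], [3, 5, 8]]

# heapq.merge keeps one candidate per input in a heap and
# yields the overall minimum at each step.
merged = list(heapq.merge(*parts))
print(merged)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```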
Algorithm
```
multiway_merge_sort(A, k):
    if length(A) <= 1:
        return A
    split A into k parts
    for each part:
        recursively sort part
    return k_way_merge(parts)
```

Merge step:

```
k_way_merge(parts):
    create min heap
    insert first element of each part into heap
    result = []
    while heap not empty:
        extract minimum element
        append to result
        if that part has more elements:
            insert next element into heap
    return result
```

Example
Let:

$$
A = [9, 2, 7, 4, 6, 1, 8, 3, 5], \quad k = 3
$$

Split into $3$ parts:
- [9, 2, 7]
- [4, 6, 1]
- [8, 3, 5]
Sort each:
- [2, 7, 9]
- [1, 4, 6]
- [3, 5, 8]
Merge:

$$
[1, 2, 3, 4, 5, 6, 7, 8, 9]
$$
Correctness
Each recursive call sorts its subarray. The $k$-way merge always extracts the smallest remaining element among all parts, ensuring that the result is sorted. By induction on the recursion depth, the final array is sorted.
Complexity
| metric | value |
|---|---|
| recursion depth | $O(\log_k n)$ |
| merge cost (per level, heap-based) | $O(n \log k)$ |
| total time | $O(n \log_k n \cdot \log k) = O(n \log n)$ |
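A sketch of how the total running time can be derived, assuming a balanced split into $k$ parts at every level and a heap-based merge that costs $O(\log k)$ per element:

$$
\begin{aligned}
\text{levels} &= \log_k n = \frac{\log n}{\log k}, \\
\text{cost per level} &= O(n \log k), \\
T(n) &= O\!\left(\frac{\log n}{\log k} \cdot n \log k\right) = O(n \log n).
\end{aligned}
$$

Under these assumptions the factor $\log k$ cancels, so the choice of $k$ does not change the asymptotic comparison count; it changes the number of levels, and therefore the number of passes over the data.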
Space: $O(n)$ auxiliary for the merged output, plus $O(k)$ for the heap.
Properties
| property | value |
|---|---|
| stable | yes |
| in-place | no |
| merge structure | k-way |
| parallelism | moderate |
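Stability depends on how the merge breaks ties. The following is a minimal sketch, assuming the parts are contiguous slices in their original order and that equal keys are resolved by part index. The function name `stable_k_way_merge` and its `key` parameter are illustrative, not part of the algorithm as stated.

```python
import heapq

def stable_k_way_merge(parts, key):
    """Merge sorted lists of records; ties go to the lower part index."""
    heap = []
    for part_idx, part in enumerate(parts):
        if part:
            # (key, part index, element index) is unique per record, so
            # tuple comparison never needs to compare the records themselves.
            heap.append((key(part[0]), part_idx, 0))
    heapq.heapify(heap)

    result = []
    while heap:
        _, part_idx, elem_idx = heapq.heappop(heap)
        result.append(parts[part_idx][elem_idx])
        nxt = elem_idx + 1
        if nxt < len(parts[part_idx]):
            heapq.heappush(heap, (key(parts[part_idx][nxt]), part_idx, nxt))
    return result

# Records are (key, label); equal keys appear in several parts.
parts = [[(1, "a"), (2, "b")], [(1, "c"), (2, "d")], [(2, "e")]]
merged = stable_k_way_merge(parts, key=lambda r: r[0])
print(merged)
```

Records with equal keys come out in the order of the parts they came from, which matches their original order when the parts are consecutive slices of the input.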
Notes
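Multiway merge sort reduces recursion depth compared to binary merge sort, but each merge step becomes more expensive: selecting the minimum among $k$ candidates costs $O(\log k)$ with a heap instead of a single comparison.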
It is particularly effective in:
- external memory sorting
- database systems
- distributed systems
In these settings, reducing the number of merge passes can significantly improve performance.
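To make the effect of the fan-in concrete, the sketch below counts the merge passes needed to combine a number of sorted runs, assuming each pass merges groups of up to $k$ runs into one. The function name `merge_passes` and the run count of 1000 are hypothetical.

```python
def merge_passes(num_runs: int, k: int) -> int:
    """Passes needed to reduce num_runs sorted runs to one using k-way merges."""
    if num_runs < 1 or k < 2:
        raise ValueError("need num_runs >= 1 and k >= 2")
    passes = 0
    while num_runs > 1:
        # Each pass merges groups of up to k runs into a single run.
        num_runs = (num_runs + k - 1) // k
        passes += 1
    return passes

for k in (2, 8, 64):
    print(k, merge_passes(1000, k))
```

In this model, 1000 initial runs need 10 passes with $k = 2$, 4 passes with $k = 8$, and 2 passes with $k = 64$. Each pass reads and writes the whole data set, which is why a larger $k$ pays off when I/O dominates.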
Implementation
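```python
import heapq


def multiway_merge_sort(a, k=2):
    """Return a new sorted list using k-way merge sort."""
    if k < 2:
        # k = 1 would never shrink the input and would recurse forever.
        raise ValueError("k must be at least 2")
    if len(a) <= 1:
        return list(a)

    size = len(a)
    step = (size + k - 1) // k  # ceil(size / k) elements per part
    parts = [
        multiway_merge_sort(a[i:i + step], k)
        for i in range(0, size, step)
    ]
    return k_way_merge(parts)


def k_way_merge(parts):
    """Merge sorted lists with a min-heap of (value, part index, element index)."""
    # Seed the heap with the first element of every non-empty part.
    heap = [(part[0], i, 0) for i, part in enumerate(parts) if part]
    heapq.heapify(heap)

    result = []
    while heap:
        val, part_idx, elem_idx = heapq.heappop(heap)
        result.append(val)
        # Refill from the part that just supplied the minimum.
        if elem_idx + 1 < len(parts[part_idx]):
            next_val = parts[part_idx][elem_idx + 1]
            heapq.heappush(heap, (next_val, part_idx, elem_idx + 1))
    return result
```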