A clear guide to sorting an array of 0s, 1s, and 2s in place using the Dutch National Flag algorithm.
Problem Restatement
We are given an array nums containing only three possible values:
- 0
- 1
- 2

These represent three colors:
| Value | Color |
|---|---|
| 0 | Red |
| 1 | White |
| 2 | Blue |
We need to sort the array in place so that all 0s come first, then all 1s, then all 2s.
We cannot use the library sort function.
The official constraints are 1 <= nums.length <= 300, and every value is either 0, 1, or 2. The follow-up asks for a one-pass algorithm using constant extra space.
Input and Output
| Item | Meaning |
|---|---|
| Input | An array nums containing only 0, 1, and 2 |
| Output | No return value |
| Mutation | Modify nums in place |
| Order | All 0s, then all 1s, then all 2s |
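In Python, "modify nums in place" means changing the contents of the list object the caller passed in. Rebinding the parameter name inside the function only creates a new local binding and leaves the caller's list untouched. The two helpers below are hypothetical and exist only to illustrate the difference:

```python
def rebind_only(nums: list[int], values: list[int]) -> None:
    # Rebinds the local name only; the caller's list is unchanged.
    nums = values


def overwrite_in_place(nums: list[int], values: list[int]) -> None:
    # Slice assignment replaces the contents of the caller's list object.
    nums[:] = values


colors = [2, 0, 1]
rebind_only(colors, [0, 1, 2])
print(colors)  # [2, 0, 1]
overwrite_in_place(colors, [0, 1, 2])
print(colors)  # [0, 1, 2]
```

The solutions below mutate nums through index assignment, which is why they return None.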
Function shape:
```python
def sortColors(nums: list[int]) -> None:
    ...
```

Examples
For:
```python
nums = [2, 0, 2, 1, 1, 0]
```
After sorting in place:
```python
nums = [0, 0, 1, 1, 2, 2]
```
For:
```python
nums = [2, 0, 1]
```
After sorting in place:
```python
nums = [0, 1, 2]
```

First Thought: Count Each Color
Since there are only three values, we can count how many 0s, 1s, and 2s appear.
Then overwrite the array.
```python
class Solution:
    def sortColors(self, nums: list[int]) -> None:
        # Each count() call scans the whole array once.
        count0 = nums.count(0)
        count1 = nums.count(1)
        count2 = nums.count(2)

        # Overwrite the array: all 0s, then all 1s, then all 2s.
        i = 0
        for _ in range(count0):
            nums[i] = 0
            i += 1
        for _ in range(count1):
            nums[i] = 1
            i += 1
        for _ in range(count2):
            nums[i] = 2
            i += 1
```

This is correct and uses constant extra space.
But it makes multiple passes over the array: each nums.count call scans the whole array, and the write loops make one more pass.
The follow-up asks for a single pass.
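A slightly tighter version of the counting idea tallies all three values in one loop and then rewrites the array, so it makes two passes instead of four while still using constant extra space. This is a sketch: the name sort_colors_counting is hypothetical, and it assumes every value is 0, 1, or 2. It still does not satisfy the one-pass follow-up.

```python
def sort_colors_counting(nums: list[int]) -> None:
    """Counting approach: one pass to tally, one pass to rewrite (hypothetical helper)."""
    counts = [0, 0, 0]  # counts[v] = number of occurrences of value v
    for value in nums:
        counts[value] += 1

    # Second pass: write each value back as many times as it was counted.
    i = 0
    for value in (0, 1, 2):
        for _ in range(counts[value]):
            nums[i] = value
            i += 1


nums = [2, 0, 2, 1, 1, 0]
sort_colors_counting(nums)
print(nums)  # [0, 0, 1, 1, 2, 2]
```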
Key Insight
We can partition the array into three regions while scanning.
Maintain three pointers:
| Pointer | Meaning |
|---|---|
| low | Next position where a 0 should go |
| mid | Current element being inspected |
| high | Next position where a 2 should go |
At any time, the array is divided into regions:
```
[0 ... low - 1]      are all 0
[low ... mid - 1]    are all 1
[mid ... high]       unknown
[high + 1 ... end]   are all 2
```

Now inspect nums[mid].
There are three cases.
If nums[mid] == 0, swap it with nums[low]. Then both low and mid move right.
If nums[mid] == 1, it already belongs in the middle region. Only mid moves right.
If nums[mid] == 2, swap it with nums[high]. Then high moves left. Do not move mid yet, because the value swapped in from the right side has not been inspected.
This is the Dutch National Flag algorithm.
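The three cases map directly onto a single update step. The sketch below isolates that step in a hypothetical helper, dnf_step, which mutates nums and returns the new (low, mid, high), so each case can be tried on its own before looking at the full loop.

```python
def dnf_step(nums: list[int], low: int, mid: int, high: int) -> tuple[int, int, int]:
    """Apply one Dutch National Flag step and return the updated (low, mid, high)."""
    if nums[mid] == 0:
        # Move the 0 into the left region; the value swapped to mid is already placed.
        nums[low], nums[mid] = nums[mid], nums[low]
        return low + 1, mid + 1, high
    if nums[mid] == 1:
        # A 1 already belongs in the middle region.
        return low, mid + 1, high
    # A 2 goes to the right region; the value swapped to mid is still unknown.
    nums[mid], nums[high] = nums[high], nums[mid]
    return low, mid, high - 1


nums = [2, 0, 1]
low, mid, high = 0, 0, len(nums) - 1
while mid <= high:
    low, mid, high = dnf_step(nums, low, mid, high)
print(nums)  # [0, 1, 2]
```

On this input the first step swaps the 2 to the end and leaves mid at 0, exactly as the third case describes.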
Algorithm
Initialize:
```python
low = 0
mid = 0
high = len(nums) - 1
```

While mid <= high, process nums[mid]:
- If it is 0, swap with nums[low], then increment low and mid.
- If it is 1, increment mid.
- If it is 2, swap with nums[high], then decrement high.
The array is sorted when mid passes high.
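Nothing in the loop depends on the values being exactly 0, 1, and 2. Comparing against a pivot instead gives a general three-way partition (less than, equal to, greater than the pivot), which is how the same idea is commonly used inside quicksort. A sketch, with partition3 as a hypothetical name:

```python
def partition3(items: list[int], pivot: int) -> tuple[int, int]:
    """Rearrange items in place into < pivot, == pivot, > pivot.

    Returns (low, mid) such that items[:low] < pivot,
    items[low:mid] == pivot, and items[mid:] > pivot.
    """
    low, mid, high = 0, 0, len(items) - 1
    while mid <= high:
        if items[mid] < pivot:
            items[low], items[mid] = items[mid], items[low]
            low += 1
            mid += 1
        elif items[mid] == pivot:
            mid += 1
        else:
            items[mid], items[high] = items[high], items[mid]
            high -= 1
    return low, mid  # on exit, mid == high + 1


colors = [2, 0, 2, 1, 1, 0]
partition3(colors, 1)  # pivot 1 reproduces the 0/1/2 sort
print(colors)  # [0, 0, 1, 1, 2, 2]
```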
Walkthrough
Use:
```python
nums = [2, 0, 2, 1, 1, 0]
```
Initial state:
```python
low = 0
mid = 0
high = 5
```
nums[mid] = 2, so swap with nums[high]:
```python
[0, 0, 2, 1, 1, 2]
```
Move high left:
```python
low = 0
mid = 0
high = 4
```
Now nums[mid] = 0, so swap with nums[low] and move both. Since low == mid, the swap leaves the array unchanged:
```python
[0, 0, 2, 1, 1, 2]
low = 1
mid = 1
high = 4
```
Again nums[mid] = 0, so swap with nums[low] (again a self-swap) and move both:
```python
low = 2
mid = 2
high = 4
```
Now nums[mid] = 2, so swap with nums[high]:
```python
[0, 0, 1, 1, 2, 2]
```
Move high left:
```python
low = 2
mid = 2
high = 3
```
Now nums[mid] = 1, so move mid to 3.
Then nums[3] is also 1, so move mid to 4.
Now mid = 4 > high = 3, the unknown region is empty, and the array is sorted.
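To reproduce the walkthrough mechanically, the loop can record a snapshot after every step. trace_sort_colors is a hypothetical helper for illustration; it returns the list of (nums, low, mid, high) states, starting with the initial one.

```python
def trace_sort_colors(nums: list[int]) -> list[tuple[list[int], int, int, int]]:
    """Run the Dutch National Flag loop and record the state after each step."""
    low, mid, high = 0, 0, len(nums) - 1
    states = [(nums.copy(), low, mid, high)]  # initial state
    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
        states.append((nums.copy(), low, mid, high))  # copy so later swaps don't alter it
    return states


for arr, low, mid, high in trace_sort_colors([2, 0, 2, 1, 1, 0]):
    print(arr, "low =", low, "mid =", mid, "high =", high)
```

The seven printed rows follow the walkthrough step by step, ending with [0, 0, 1, 1, 2, 2], low = 2, mid = 4, high = 3.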
Correctness
The algorithm maintains four regions.
All elements before low are 0. All elements between low and mid - 1 are 1. All elements after high are 2. The region between mid and high is still unknown.
When nums[mid] is 0, swapping it with nums[low] puts a 0 into the left region. The value swapped into position mid is safe to skip: if low < mid, it is a 1 taken from the middle region; if low == mid, the swap does nothing and the 0 stays where it is. So both pointers move right.
When nums[mid] is 1, it already belongs in the middle region, so only mid moves right.
When nums[mid] is 2, swapping it with nums[high] puts a 2 into the right region. The new value at mid came from the unknown region, so mid must stay in place and inspect it next.
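Initially low = mid = 0 and high = n - 1, so the three known regions are empty and the invariant holds trivially. Each step preserves the invariant and shrinks the unknown region by exactly one element. When the unknown region becomes empty (mid > high), every element belongs to its correct region, so the array is sorted.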
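The invariant can also be checked mechanically. In the sketch below (check_invariant and sort_colors_checked are hypothetical names), the invariant is asserted before every step and once more after the loop, and the algorithm is run on every array of 0s, 1s, and 2s up to length 7 and compared with sorted(). This is a brute-force sanity check on small inputs, not a substitute for the argument above.

```python
from itertools import product


def check_invariant(nums: list[int], low: int, mid: int, high: int) -> None:
    """Assert the four-region invariant from the correctness argument."""
    assert all(x == 0 for x in nums[:low])       # [0 ... low - 1] are 0
    assert all(x == 1 for x in nums[low:mid])    # [low ... mid - 1] are 1
    assert all(x == 2 for x in nums[high + 1:])  # [high + 1 ... end] are 2


def sort_colors_checked(nums: list[int]) -> None:
    low, mid, high = 0, 0, len(nums) - 1
    while mid <= high:
        check_invariant(nums, low, mid, high)
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
    check_invariant(nums, low, mid, high)  # still holds once the unknown region is empty


for n in range(8):
    for values in product((0, 1, 2), repeat=n):
        nums = list(values)
        sort_colors_checked(nums)
        assert nums == sorted(values)
print("invariant held on all inputs up to length 7")
```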
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each iteration either increments mid or decrements high, so the loop runs exactly n times |
| Space | O(1) | Only three pointers are used |
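The time bound can be made concrete. The unknown region starts with n elements, and every iteration either increments mid or decrements high, so the region shrinks by exactly one per iteration. A small instrumented sketch (count_iterations is a hypothetical name) confirms that the loop body runs exactly len(nums) times on random inputs within the official length bounds:

```python
import random


def count_iterations(nums: list[int]) -> int:
    """Run the Dutch National Flag loop on nums and return how many times the body ran."""
    low, mid, high = 0, 0, len(nums) - 1
    iterations = 0
    while mid <= high:
        iterations += 1
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
    return iterations


for _ in range(1000):
    nums = [random.randint(0, 2) for _ in range(random.randint(1, 300))]
    assert count_iterations(nums) == len(nums)
```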
Implementation
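```python
class Solution:
    def sortColors(self, nums: list[int]) -> None:
        low = 0               # next slot for a 0
        mid = 0               # current element being inspected
        high = len(nums) - 1  # next slot for a 2

        while mid <= high:
            if nums[mid] == 0:
                nums[low], nums[mid] = nums[mid], nums[low]
                low += 1
                mid += 1
            elif nums[mid] == 1:
                mid += 1
            else:
                # nums[mid] == 2; the value swapped in from high is not yet inspected
                nums[mid], nums[high] = nums[high], nums[mid]
                high -= 1
```

Code Explanation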
Initialize the three pointers:
```python
low = 0
mid = 0
high = len(nums) - 1
```
The loop continues while there are unknown elements:
```python
while mid <= high:
```
If the current value is 0, move it to the left side:
```python
if nums[mid] == 0:
    nums[low], nums[mid] = nums[mid], nums[low]
    low += 1
    mid += 1
```
If the current value is 1, leave it in the middle:
```python
elif nums[mid] == 1:
    mid += 1
```
If the current value is 2, move it to the right side:
```python
else:
    nums[mid], nums[high] = nums[high], nums[mid]
    high -= 1
```
We do not increment mid in the 2 case because the new value at mid came from the unknown region and still needs to be checked.
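To see why this matters, consider a deliberately broken variant (the name sort_colors_buggy is hypothetical) that also increments mid in the 2 case. On [1, 2, 0], the 2 is swapped to the end, the 0 lands at index 1, and the buggy version steps past it without ever inspecting it:

```python
def sort_colors_buggy(nums: list[int]) -> None:
    """Incorrect on purpose: advances mid after swapping with high."""
    low, mid, high = 0, 0, len(nums) - 1
    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
            mid += 1  # BUG: the value just swapped in from high is never inspected


nums = [1, 2, 0]
sort_colors_buggy(nums)
print(nums)  # [1, 0, 2], not sorted
```

The correct version keeps mid at 1 after the swap, sees the 0, and swaps it to the front, giving [0, 1, 2].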
Testing
```python
def run_tests():
    s = Solution()

    nums = [2, 0, 2, 1, 1, 0]
    s.sortColors(nums)
    assert nums == [0, 0, 1, 1, 2, 2]

    nums = [2, 0, 1]
    s.sortColors(nums)
    assert nums == [0, 1, 2]

    nums = [0]
    s.sortColors(nums)
    assert nums == [0]

    nums = [1, 1, 1]
    s.sortColors(nums)
    assert nums == [1, 1, 1]

    nums = [2, 2, 0, 0]
    s.sortColors(nums)
    assert nums == [0, 0, 2, 2]

    nums = [1, 2, 0, 1, 2, 0, 1]
    s.sortColors(nums)
    assert nums == [0, 0, 1, 1, 1, 2, 2]

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [2,0,2,1,1,0] | Official example |
| [2,0,1] | Small mixed case |
| [0] | Single element |
| [1,1,1] | All values equal |
| [2,2,0,0] | Requires repeated swaps with high |
| [1,2,0,1,2,0,1] | Mixed order with all three values |
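The hand-picked cases above can be complemented with a randomized check. The sketch below compares the Solution against Python's built-in sorted(), used only as a reference oracle inside the test, on random arrays within the official constraints (length 1 to 300). It also checks that the method returns None. The harness name run_random_tests and its parameters are hypothetical; the Solution class is repeated so the block runs on its own.

```python
import random


class Solution:
    def sortColors(self, nums: list[int]) -> None:
        low, mid, high = 0, 0, len(nums) - 1
        while mid <= high:
            if nums[mid] == 0:
                nums[low], nums[mid] = nums[mid], nums[low]
                low += 1
                mid += 1
            elif nums[mid] == 1:
                mid += 1
            else:
                nums[mid], nums[high] = nums[high], nums[mid]
                high -= 1


def run_random_tests(trials: int = 500, seed: int = 0) -> None:
    rng = random.Random(seed)  # fixed seed so any failure is reproducible
    s = Solution()
    for _ in range(trials):
        n = rng.randint(1, 300)  # official length bounds
        nums = [rng.randint(0, 2) for _ in range(n)]
        expected = sorted(nums)  # built-in sort used only as a test oracle
        result = s.sortColors(nums)
        assert result is None    # no return value
        assert nums == expected  # sorted in place
    print("random tests passed")


run_random_tests()
```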