# LeetCode 75: Sort Colors

## Problem Restatement

We are given an array `nums` containing only three possible values: `0`, `1`, and `2`.

These represent three colors:

| Value | Color |
|---|---|
| `0` | Red |
| `1` | White |
| `2` | Blue |

We need to sort the array in place so that all `0`s come first, then all `1`s, then all `2`s.

We cannot use the library sort function.

The official constraints are `1 <= nums.length <= 300`, and every value is either `0`, `1`, or `2`. The follow-up asks for a one-pass algorithm using constant extra space.

## Input and Output

| Item | Meaning |
|---|---|
| Input | An array `nums` containing only `0`, `1`, and `2` |
| Output | No return value |
| Mutation | Modify `nums` in place |
| Order | All `0`s, then all `1`s, then all `2`s |

Function shape:

```python
def sortColors(nums: list[int]) -> None:
    ...
```

## Examples

For:

```python
nums = [2, 0, 2, 1, 1, 0]
```

After sorting in place:

```python
nums = [0, 0, 1, 1, 2, 2]
```

For:

```python
nums = [2, 0, 1]
```

After sorting in place:

```python
nums = [0, 1, 2]
```

## First Thought: Count Each Color

Since there are only three values, we can count how many `0`s, `1`s, and `2`s appear.

Then overwrite the array.

```python
class Solution:
    def sortColors(self, nums: list[int]) -> None:
        count0 = nums.count(0)
        count1 = nums.count(1)
        count2 = nums.count(2)

        i = 0

        for _ in range(count0):
            nums[i] = 0
            i += 1

        for _ in range(count1):
            nums[i] = 1
            i += 1

        for _ in range(count2):
            nums[i] = 2
            i += 1
```

This is correct and uses constant extra space.

But it makes four passes over the array: each of the three `count` calls scans the whole array, and the overwrite is one more pass.

The follow-up asks for one pass.
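If two passes are acceptable, the counting idea can be tightened to one counting pass and one writing pass. The sketch below assumes the values are exactly `0`, `1`, and `2`, as the constraints guarantee, and uses the hypothetical name `sort_colors_counting`. It is still not one-pass, which is why the follow-up pushes further.

```python
def sort_colors_counting(nums: list[int]) -> None:
    """Two-pass counting sort for values in {0, 1, 2}, done in place."""
    # Pass 1: counts[v] is how many times color v appears.
    counts = [0, 0, 0]
    for v in nums:
        counts[v] += 1

    # Pass 2: overwrite nums with counts[0] zeros, then ones, then twos.
    i = 0
    for color in range(3):
        for _ in range(counts[color]):
            nums[i] = color
            i += 1


nums = [2, 0, 2, 1, 1, 0]
sort_colors_counting(nums)
print(nums)  # [0, 0, 1, 1, 2, 2]
```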

## Key Insight

We can partition the array into three regions while scanning.

Maintain three pointers:

| Pointer | Meaning |
|---|---|
| `low` | Next position where a `0` should go |
| `mid` | Current element being inspected |
| `high` | Next position where a `2` should go |

At any time, the array is divided into regions:

```text
[0 ... low - 1]      are all 0
[low ... mid - 1]    are all 1
[mid ... high]       unknown
[high + 1 ... end]   are all 2
```
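The region picture can be written as a checkable predicate. This is a small illustrative helper with a hypothetical name, `regions_ok`; it also checks the pointer ordering `0 <= low <= mid <= high + 1 <= len(nums)` that the picture implies.

```python
def regions_ok(nums: list[int], low: int, mid: int, high: int) -> bool:
    """Return True if nums satisfies the four-region invariant for these pointers."""
    if not (0 <= low <= mid <= high + 1 <= len(nums)):
        return False
    left_zeros = all(v == 0 for v in nums[:low])        # [0 .. low - 1]
    middle_ones = all(v == 1 for v in nums[low:mid])    # [low .. mid - 1]
    right_twos = all(v == 2 for v in nums[high + 1:])   # [high + 1 .. end]
    # nums[mid:high + 1] is the unknown region and may hold anything.
    return left_zeros and middle_ones and right_twos


print(regions_ok([0, 0, 1, 1, 2, 2], 2, 2, 3))  # True
print(regions_ok([0, 0, 2, 1, 1, 2], 3, 3, 4))  # False: a 2 sits in the 0 region
```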

Now inspect `nums[mid]`.

There are three cases.

If `nums[mid] == 0`, swap it with `nums[low]`. Then both `low` and `mid` move right.

If `nums[mid] == 1`, it already belongs in the middle region. Only `mid` moves right.

If `nums[mid] == 2`, swap it with `nums[high]`. Then `high` moves left. Do not move `mid` yet, because the value swapped in from the right side has not been inspected.

This is the Dutch National Flag algorithm.
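The three cases can be isolated into a single step. The sketch below uses a hypothetical helper, `dnf_step`, which processes `nums[mid]` once and returns the updated `(low, mid, high)`; driving it in a loop until `mid > high` sorts the array.

```python
def dnf_step(nums: list[int], low: int, mid: int, high: int) -> tuple[int, int, int]:
    """Apply one Dutch National Flag step to nums[mid]; return the new (low, mid, high)."""
    if nums[mid] == 0:
        # Move the 0 into the left region. The value swapped back is a 1
        # (or the same 0 when low == mid), so mid can advance as well.
        nums[low], nums[mid] = nums[mid], nums[low]
        return low + 1, mid + 1, high
    if nums[mid] == 1:
        # Already in the middle region: just grow it.
        return low, mid + 1, high
    # nums[mid] == 2: send it to the right region. The value swapped in
    # from nums[high] has not been inspected yet, so mid stays put.
    nums[mid], nums[high] = nums[high], nums[mid]
    return low, mid, high - 1


nums = [2, 0, 2, 1, 1, 0]
low, mid, high = 0, 0, len(nums) - 1
while mid <= high:
    low, mid, high = dnf_step(nums, low, mid, high)
print(nums)  # [0, 0, 1, 1, 2, 2]
```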

## Algorithm

Initialize:

```python
low = 0
mid = 0
high = len(nums) - 1
```

While:

```python
mid <= high
```

Process `nums[mid]`:

1. If it is `0`, swap with `nums[low]`, then increment `low` and `mid`.
2. If it is `1`, increment `mid`.
3. If it is `2`, swap with `nums[high]`, then decrement `high`.

The array is sorted when `mid` passes `high`.

## Walkthrough

Use:

```python
nums = [2, 0, 2, 1, 1, 0]
```

Initial state:

```text
low = 0
mid = 0
high = 5
```

`nums[mid] = 2`, so swap with `nums[high]`:

```python
[0, 0, 2, 1, 1, 2]
```

Move `high` left:

```text
low = 0
mid = 0
high = 4
```

Now `nums[mid] = 0`. Because `low == mid`, the swap with `nums[low]` leaves the array unchanged; both pointers still move right:

```python
[0, 0, 2, 1, 1, 2]
```

```text
low = 1
mid = 1
high = 4
```

Again `nums[mid] = 0`, so swap with `nums[low]` and move both:

```text
low = 2
mid = 2
high = 4
```

Now `nums[mid] = 2`, so swap with `nums[high]`:

```python
[0, 0, 1, 1, 2, 2]
```

Move `high` left:

```text
low = 2
mid = 2
high = 3
```

Now `nums[mid] = 1`, so move `mid`.

Then another `1`, so move `mid` again.

Now `mid > high`, and the array is sorted.
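The walkthrough can be reproduced mechanically. This sketch, under the hypothetical name `sort_colors_traced`, runs the same loop as the implementation and records a snapshot `(array, low, mid, high)` after every step; the six printed states match the walkthrough above.

```python
def sort_colors_traced(nums: list[int]) -> list[tuple[list[int], int, int, int]]:
    """Sort nums in place; return a snapshot (array, low, mid, high) after each step."""
    low, mid, high = 0, 0, len(nums) - 1
    trace = []
    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
        trace.append((nums.copy(), low, mid, high))
    return trace


for state in sort_colors_traced([2, 0, 2, 1, 1, 0]):
    print(state)
# ([0, 0, 2, 1, 1, 2], 0, 0, 4)
# ([0, 0, 2, 1, 1, 2], 1, 1, 4)
# ([0, 0, 2, 1, 1, 2], 2, 2, 4)
# ([0, 0, 1, 1, 2, 2], 2, 2, 3)
# ([0, 0, 1, 1, 2, 2], 2, 3, 3)
# ([0, 0, 1, 1, 2, 2], 2, 4, 3)
```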

## Correctness

The algorithm maintains four regions.

All elements before `low` are `0`. All elements between `low` and `mid - 1` are `1`. All elements after `high` are `2`. The region between `mid` and `high` is still unknown.

When `nums[mid]` is `0`, swapping it with `nums[low]` puts a `0` into the left region. The value swapped into position `mid` is safe to skip. If `low == mid`, it is the same `0` and the swap changes nothing. Otherwise it came from `nums[low]`, the first element of the `1` region. Either way, both pointers move right.

When `nums[mid]` is `1`, it already belongs in the middle region, so only `mid` moves right.

When `nums[mid]` is `2`, swapping it with `nums[high]` puts a `2` into the right region. The new value at `mid` came from the unknown region, so `mid` must stay in place and inspect it next.

Each step preserves the region invariant and shrinks the unknown region by exactly one element, because every case either increments `mid` or decrements `high`. When the unknown region becomes empty, every element belongs to the correct region, so the array is sorted.
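The argument can be checked exhaustively for small inputs. This sketch, with hypothetical helper names `invariant_holds` and `check_all_arrays`, runs the loop on every array over `{0, 1, 2}` of length 1 to 6 and asserts the invariant before the loop and after every step, then confirms the result is sorted. It is a sanity check, not a proof.

```python
from itertools import product


def invariant_holds(nums: list[int], low: int, mid: int, high: int) -> bool:
    """True if the pointers are ordered and the three known regions are correct."""
    return (
        0 <= low <= mid <= high + 1 <= len(nums)
        and all(v == 0 for v in nums[:low])
        and all(v == 1 for v in nums[low:mid])
        and all(v == 2 for v in nums[high + 1:])
    )


def check_all_arrays(max_len: int = 6) -> int:
    """Sort every array over {0, 1, 2} of length 1..max_len, asserting the invariant each step."""
    checked = 0
    for n in range(1, max_len + 1):
        for combo in product((0, 1, 2), repeat=n):
            nums = list(combo)
            low, mid, high = 0, 0, n - 1
            assert invariant_holds(nums, low, mid, high)
            while mid <= high:
                if nums[mid] == 0:
                    nums[low], nums[mid] = nums[mid], nums[low]
                    low += 1
                    mid += 1
                elif nums[mid] == 1:
                    mid += 1
                else:
                    nums[mid], nums[high] = nums[high], nums[mid]
                    high -= 1
                assert invariant_holds(nums, low, mid, high)
            # The unknown region is empty, so the invariant implies sorted order.
            assert nums == sorted(combo)
            checked += 1
    return checked


print(check_all_arrays())  # 1092
```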

## Complexity

| Metric | Value | Why |
|---|---|---|
| Time | `O(n)` | Each iteration shrinks the unknown region `[mid..high]` by exactly one, so the loop runs exactly `n` times |
| Space | `O(1)` | Only three pointers are used |
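The "exactly `n` iterations" claim is easy to observe. This sketch instruments the loop with an iteration counter (the name `sort_colors_counted` is hypothetical) and checks it on random arrays within the stated constraints, using a fixed seed so the run is reproducible.

```python
import random


def sort_colors_counted(nums: list[int]) -> int:
    """Dutch National Flag loop on nums (in place); return the number of loop iterations."""
    low, mid, high = 0, 0, len(nums) - 1
    iterations = 0
    while mid <= high:
        iterations += 1
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
    return iterations


# Random arrays within the constraints 1 <= n <= 300.
rng = random.Random(0)
for _ in range(1000):
    n = rng.randint(1, 300)
    nums = [rng.randint(0, 2) for _ in range(n)]
    expected = sorted(nums)
    # [mid..high] starts with n elements and loses exactly one per iteration.
    assert sort_colors_counted(nums) == n
    assert nums == expected
```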

## Implementation

```python
class Solution:
    def sortColors(self, nums: list[int]) -> None:
        low = 0
        mid = 0
        high = len(nums) - 1

        while mid <= high:
            if nums[mid] == 0:
                nums[low], nums[mid] = nums[mid], nums[low]
                low += 1
                mid += 1

            elif nums[mid] == 1:
                mid += 1

            else:
                nums[mid], nums[high] = nums[high], nums[mid]
                high -= 1
```

## Code Explanation

Initialize the three pointers:

```python
low = 0
mid = 0
high = len(nums) - 1
```

The loop continues while there are unknown elements:

```python
while mid <= high:
```

If the current value is `0`, move it to the left side:

```python
if nums[mid] == 0:
    nums[low], nums[mid] = nums[mid], nums[low]
    low += 1
    mid += 1
```

If the current value is `1`, leave it in the middle:

```python
elif nums[mid] == 1:
    mid += 1
```

If the current value is `2`, move it to the right side:

```python
else:
    nums[mid], nums[high] = nums[high], nums[mid]
    high -= 1
```

We do not increment `mid` in the `2` case because the new value at `mid` still needs to be checked.

## Testing

```python
def run_tests():
    s = Solution()

    nums = [2, 0, 2, 1, 1, 0]
    s.sortColors(nums)
    assert nums == [0, 0, 1, 1, 2, 2]

    nums = [2, 0, 1]
    s.sortColors(nums)
    assert nums == [0, 1, 2]

    nums = [0]
    s.sortColors(nums)
    assert nums == [0]

    nums = [1, 1, 1]
    s.sortColors(nums)
    assert nums == [1, 1, 1]

    nums = [2, 2, 0, 0]
    s.sortColors(nums)
    assert nums == [0, 0, 2, 2]

    nums = [1, 2, 0, 1, 2, 0, 1]
    s.sortColors(nums)
    assert nums == [0, 0, 1, 1, 1, 2, 2]

    print("all tests passed")

run_tests()
```

| Test | Why |
|---|---|
| `[2,0,2,1,1,0]` | Official example |
| `[2,0,1]` | Small mixed case |
| `[0]` | Single element |
| `[1,1,1]` | All values equal |
| `[2,2,0,0]` | Requires repeated swaps with `high` |
| `[1,2,0,1,2,0,1]` | Mixed order with all three values |

