# LeetCode 338: Counting Bits

## Problem Restatement

We are given an integer `n`.

Return an array `ans` of length `n + 1`.

For every integer `i` from `0` to `n` inclusive, `ans[i]` should be the number of `1` bits in the binary representation of `i`.

For example:

```text
n = 5
```

The binary forms are:

```text
0 -> 0      has 0 ones
1 -> 1      has 1 one
2 -> 10     has 1 one
3 -> 11     has 2 ones
4 -> 100    has 1 one
5 -> 101    has 2 ones
```

So the answer is:

```text
[0, 1, 1, 2, 1, 2]
```
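To double-check a table like the one above, Python's built-in `bin` can act as a reference oracle. This is a minimal sketch, and the helper name `count_bits_builtin` is hypothetical. It is handy for testing but is not the intended solution, because it still examines every digit of every number.

```python
def count_bits_builtin(n: int) -> list[int]:
    # bin(i) returns a string such as "0b101"; counting "1" characters gives
    # the number of set bits. The "0b" prefix contains no "1", so it is harmless.
    return [bin(i).count("1") for i in range(n + 1)]

print(count_bits_builtin(5))  # [0, 1, 1, 2, 1, 2]
```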

## Input and Output

| Item | Meaning |
|---|---|
| Input | Integer `n` |
| Output | Array `ans` of length `n + 1` |
| `ans[i]` | Number of `1` bits in binary form of `i` |
| Range | All integers from `0` to `n` |

Function shape:

```python
def countBits(n: int) -> list[int]:
    ...
```

## Examples

Example 1:

```text
Input: n = 2
Output: [0, 1, 1]
```

Explanation:

```text
0 -> 0   has 0 ones
1 -> 1   has 1 one
2 -> 10  has 1 one
```

Example 2:

```text
Input: n = 5
Output: [0, 1, 1, 2, 1, 2]
```

Explanation:

```text
0 -> 0
1 -> 1
2 -> 10
3 -> 11
4 -> 100
5 -> 101
```

The counts are:

```text
0, 1, 1, 2, 1, 2
```

## First Thought: Count Bits One Number at a Time

A direct solution is to loop from `0` to `n`.

For each number, repeatedly check its last bit.

```python
class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = []

        for x in range(n + 1):
            count = 0
            value = x

            while value > 0:
                count += value & 1
                value >>= 1

            ans.append(count)

        return ans
```

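This works, but it repeats work.

For every number, we scan its binary digits from scratch. A number `x > 0` has `floor(log2(x)) + 1` binary digits, so the inner loop runs that many times for `x`.

Summed over all numbers up to `n`, this costs about: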

```text
O(n log n)
```
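To make the cost concrete, the sketch below counts how many times the inner `while` loop of the direct solution runs in total. The function name `inner_loop_steps` is hypothetical. The total equals the sum of `x.bit_length()` over all `x`, which grows like `n log n`.

```python
def inner_loop_steps(n: int) -> int:
    # Total number of `while value > 0` iterations performed by the direct solution.
    total = 0
    for x in range(n + 1):
        value = x
        while value > 0:
            total += 1
            value >>= 1
    return total

print(inner_loop_steps(5))  # 11, because the bit lengths of 0..5 are 0, 1, 2, 2, 3, 3
```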

We can do better by reusing previous answers.

## Key Insight

Every positive integer `i` can be related to a smaller number by removing its last binary bit.

Removing the last bit is the same as:

```text
i >> 1
```

The last bit is:

```text
i & 1
```

So the number of `1` bits in `i` is:

```text
bits(i) = bits(i >> 1) + (i & 1)
```

This works because `i >> 1` removes the least significant bit, and `i & 1` tells us whether the removed bit was `1`.
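The recurrence can be checked mechanically against a direct count of set bits. In this sketch, `popcount` and `recurrence_holds` are hypothetical helper names, and `bin(x).count("1")` serves as the reference count.

```python
def popcount(x: int) -> int:
    # Reference count of 1 bits, used only to check the recurrence.
    return bin(x).count("1")

def recurrence_holds(limit: int) -> bool:
    # Check bits(i) == bits(i >> 1) + (i & 1) for every 1 <= i <= limit.
    return all(
        popcount(i) == popcount(i >> 1) + (i & 1)
        for i in range(1, limit + 1)
    )

print(recurrence_holds(1 << 12))  # True
```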

Example:

```text
i = 5
binary = 101

i >> 1 = 2
binary of 2 = 10

i & 1 = 1
```

So:

```text
bits(5) = bits(2) + 1
        = 1 + 1
        = 2
```

## Algorithm

Create an array `ans` of length `n + 1`.

Set:

```python
ans[0] = 0
```

Then for every `i` from `1` to `n`:

```python
ans[i] = ans[i >> 1] + (i & 1)
```

Return `ans`.

## Walkthrough

Use:

```text
n = 5
```

Start:

```text
ans = [0, 0, 0, 0, 0, 0]
```

For `i = 1`:

```text
1 >> 1 = 0
1 & 1 = 1
ans[1] = ans[0] + 1 = 1
```

For `i = 2`:

```text
2 >> 1 = 1
2 & 1 = 0
ans[2] = ans[1] + 0 = 1
```

For `i = 3`:

```text
3 >> 1 = 1
3 & 1 = 1
ans[3] = ans[1] + 1 = 2
```

For `i = 4`:

```text
4 >> 1 = 2
4 & 1 = 0
ans[4] = ans[2] + 0 = 1
```

For `i = 5`:

```text
5 >> 1 = 2
5 & 1 = 1
ans[5] = ans[2] + 1 = 2
```

Final answer:

```text
[0, 1, 1, 2, 1, 2]
```
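The hand walkthrough can be reproduced by recording each step of the loop. This is an illustrative sketch; `trace_count_bits` is a hypothetical name, and each row is `(i, i >> 1, i & 1, ans[i])`.

```python
def trace_count_bits(n: int) -> list[tuple[int, int, int, int]]:
    # One row per loop step: (i, i >> 1, i & 1, ans[i]).
    ans = [0] * (n + 1)
    rows = []
    for i in range(1, n + 1):
        parent = i >> 1  # i with its last bit removed
        last = i & 1     # the removed bit
        ans[i] = ans[parent] + last
        rows.append((i, parent, last, ans[i]))
    return rows

for row in trace_count_bits(5):
    print(row)
```

The printed rows match the five steps above, ending with `(5, 2, 1, 2)`.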

## Correctness

For every integer `i > 0`, the expression `i >> 1` removes the last binary bit of `i`.

The expression `i & 1` gives the value of that removed bit:

| Last bit | `i & 1` |
|---|---:|
| `0` | `0` |
| `1` | `1` |

Therefore, the number of `1` bits in `i` equals the number of `1` bits in `i >> 1`, plus `1` if the removed bit was `1`.

The algorithm fills `ans` from smaller numbers to larger numbers. Since `i >> 1` is always smaller than `i` for every `i > 0`, `ans[i >> 1]` has already been computed when we compute `ans[i]`.

The base case `ans[0] = 0` is correct because the binary representation of `0` contains no `1` bits.

Thus, every `ans[i]` is computed correctly, and the returned array is correct for all values from `0` to `n`.

## Complexity

| Metric | Value | Why |
|---|---|---|
| Time | `O(n)` | We compute one value for each integer from `0` to `n`, each in constant time |
| Space | `O(n)` | The output array has length `n + 1` |

The extra working space besides the returned array is `O(1)`.

## Implementation

```python
class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)

        for i in range(1, n + 1):
            ans[i] = ans[i >> 1] + (i & 1)

        return ans
```
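The same recurrence can also be written with integer division and remainder instead of bitwise operators. For non-negative `i`, `i // 2` equals `i >> 1` and `i % 2` equals `i & 1`. The class name `SolutionArithmetic` is hypothetical; this sketch only shows an equivalent spelling, not a different algorithm.

```python
class SolutionArithmetic:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)
        for i in range(1, n + 1):
            # i // 2 drops the last binary digit; i % 2 is that digit.
            ans[i] = ans[i // 2] + i % 2
        return ans

print(SolutionArithmetic().countBits(8))  # [0, 1, 1, 2, 1, 2, 2, 3, 1]
```

The bitwise form states the intent ("drop the last bit") more directly, which is why the main implementation uses it.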

## Code Explanation

Create the result array:

```python
ans = [0] * (n + 1)
```

We need length `n + 1` because the result includes both `0` and `n`.

Then compute answers from `1` to `n`:

```python
for i in range(1, n + 1):
```

Use the recurrence:

```python
ans[i] = ans[i >> 1] + (i & 1)
```

Here:

| Expression | Meaning |
|---|---|
| `i >> 1` | `i` without its last binary bit |
| `i & 1` | The last binary bit of `i` |

Finally return the array:

```python
return ans
```
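To see both expressions from the table on a concrete value, take `i = 6` (binary `110`). This is a small illustrative snippet using only built-ins.

```python
i = 6
print(format(i, "b"))       # 110
print(format(i >> 1, "b"))  # 11  (last bit dropped)
print(i & 1)                # 0   (the dropped bit)
```

So `ans[6] = ans[3] + 0 = 2`.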

## Testing

```python
def run_tests():
    s = Solution()

    assert s.countBits(0) == [0]
    assert s.countBits(1) == [0, 1]
    assert s.countBits(2) == [0, 1, 1]
    assert s.countBits(5) == [0, 1, 1, 2, 1, 2]
    assert s.countBits(8) == [0, 1, 1, 2, 1, 2, 2, 3, 1]

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Why |
|---|---|
| `0` | Smallest input |
| `1` | First positive number |
| `2` | First number with two binary digits (`10`) |
| `5` | Standard sample |
| `8` | Includes `7` (`111`, three ones) and the power of two `8` |
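Beyond hand-picked cases, a cheap exhaustive check compares the solution against a brute-force count for every `n` up to a small bound. This sketch repeats the `Solution` class so that it runs on its own; `cross_check` and the bound `200` are arbitrary choices for illustration.

```python
class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)
        for i in range(1, n + 1):
            ans[i] = ans[i >> 1] + (i & 1)
        return ans

def cross_check(max_n: int) -> None:
    # Compare against bin(i).count("1") for every n from 0 to max_n.
    s = Solution()
    for n in range(max_n + 1):
        expected = [bin(i).count("1") for i in range(n + 1)]
        assert s.countBits(n) == expected, n

cross_check(200)
print("cross-check passed")
```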

