A clear explanation of Counting Bits using dynamic programming and bit manipulation.
Problem Restatement
We are given an integer n.
Return an array ans of length n + 1.
For every integer i with 0 <= i <= n, ans[i] must be the number of 1 bits in the binary representation of i.
For example:
```
n = 5
```
The binary forms are:
```
0 -> 0    has 0 ones
1 -> 1    has 1 one
2 -> 10   has 1 one
3 -> 11   has 2 ones
4 -> 100  has 1 one
5 -> 101  has 2 ones
```
So the answer is:
```
[0, 1, 1, 2, 1, 2]
```
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer n |
| Output | Array ans of length n + 1 |
| ans[i] | Number of 1 bits in the binary form of i |
| Range | All integers from 0 to n |
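The definition in the table can be written directly as a short reference version. This is a minimal sketch, assuming Python's built-in bin(), which returns strings such as '0b101'. The name count_bits_reference is hypothetical and is not the required signature; it is mainly useful as an oracle for checking faster solutions.

```python
def count_bits_reference(n: int) -> list[int]:
    """Hypothetical reference version: ans[i] is the number of '1' characters in bin(i)."""
    # bin(5) == '0b101'; the '0b' prefix contains no '1', so counting '1' is safe.
    return [bin(i).count("1") for i in range(n + 1)]


print(count_bits_reference(5))  # [0, 1, 1, 2, 1, 2]
```

On Python 3.10 and later, int.bit_count() returns the same count without building a string.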
Function shape:
```python
def countBits(n: int) -> list[int]:
    ...
```
Examples
Example 1:
```
Input: n = 2
Output: [0, 1, 1]
```
Explanation:
```
0 -> 0   has 0 ones
1 -> 1   has 1 one
2 -> 10  has 1 one
```
Example 2:
```
Input: n = 5
Output: [0, 1, 1, 2, 1, 2]
```
Explanation:
```
0 -> 0
1 -> 1
2 -> 10
3 -> 11
4 -> 100
5 -> 101
```
The counts are:
```
0, 1, 1, 2, 1, 2
```
First Thought: Count Bits One Number at a Time
A direct solution is to loop from 0 to n.
For each number, add its last bit (value & 1) to a counter and shift the number right by one, repeating until it becomes 0.
```python
class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = []
        for x in range(n + 1):
            count = 0
            value = x
            while value > 0:
                count += value & 1  # add the last bit
                value >>= 1         # drop the last bit
            ans.append(count)
        return ans
```
This works, but it repeats work.
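For every number, we inspect all of its binary digits from scratch, even though the bits of smaller numbers were already counted.
A number x has about log2(x) + 1 binary digits, so for large n the total cost is about:
```
O(n log n)
```
We can do better by reusing previous answers.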
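To see where the log n factor comes from, the sketch below counts how many times the inner while loop of the direct solution runs in total. The helper name brute_force_iterations is hypothetical. Each number x needs exactly x.bit_length() iterations, so the total grows like n log n.

```python
import math


def brute_force_iterations(n: int) -> int:
    """Hypothetical helper: total inner-loop iterations of the direct solution for 0..n."""
    total = 0
    for x in range(n + 1):
        value = x
        while value > 0:
            total += 1  # one iteration per binary digit of x
            value >>= 1
    return total


for n in (10, 1_000, 100_000):
    iterations = brute_force_iterations(n)
    # Each x contributes exactly x.bit_length() iterations.
    assert iterations == sum(x.bit_length() for x in range(n + 1))
    # Ratio of the measured count to n * log2(n).
    print(n, iterations, round(iterations / (n * math.log2(n)), 3))
```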
Key Insight
Every positive integer i can be related to a smaller number by removing its last binary bit.
Removing the last bit is the same as:
```
i >> 1
```
The last bit is:
```
i & 1
```
So the number of 1 bits in i is:
```
bits(i) = bits(i >> 1) + (i & 1)
```
This works because i >> 1 removes the least significant bit, and i & 1 tells us whether the removed bit was 1.
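The recurrence can be checked on its own, before any table is built. This is a minimal sketch, assuming non-negative inputs: bits is a hypothetical helper that applies the formula recursively, and its results are compared against Python's bin(i).count("1").

```python
def bits(i: int) -> int:
    """Hypothetical helper: number of 1 bits in a non-negative i, via the recurrence."""
    if i == 0:
        return 0  # base case: 0 has no 1 bits
    # i >> 1 removes the last bit; i & 1 is the bit that was removed.
    return bits(i >> 1) + (i & 1)


# The recurrence agrees with a string-based count for every i checked here.
for i in range(1_000):
    assert bits(i) == bin(i).count("1")

print(bits(5))  # 2
```

A single call only recurses about log2(i) levels deep, but calling bits(i) separately for every i from 0 to n still costs O(n log n) in total, because nothing is reused between calls. Storing each result in an array removes that repetition.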
Example:
```
i = 5
binary = 101

i >> 1 = 2
binary of 2 = 10

i & 1 = 1
```
So:
```
bits(5) = bits(2) + 1
        = 1 + 1
        = 2
```
Algorithm
Create an array ans of length n + 1.
Set:
```
ans[0] = 0
```
Then for every i from 1 to n:
```
ans[i] = ans[i >> 1] + (i & 1)
```
Return ans.
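The steps above can be traced mechanically. This sketch fills ans in the same order and records, for each i, the shift, the last bit, and the update. The function name count_bits_trace and its text format are hypothetical choices made for illustration.

```python
def count_bits_trace(n: int) -> tuple[list[int], list[str]]:
    """Hypothetical helper: run the recurrence and record each step as text."""
    ans = [0] * (n + 1)  # ans[0] = 0 is the base case
    steps: list[str] = []
    for i in range(1, n + 1):
        parent, last = i >> 1, i & 1
        ans[i] = ans[parent] + last
        steps.append(f"i = {i}: {i} >> 1 = {parent}, {i} & 1 = {last}, "
                     f"ans[{i}] = ans[{parent}] + {last} = {ans[i]}")
    return ans, steps


ans, steps = count_bits_trace(5)
for line in steps:
    print(line)
print(ans)  # [0, 1, 1, 2, 1, 2]
```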
Walkthrough
Use:
```
n = 5
```
Start:
```
ans = [0, 0, 0, 0, 0, 0]
```
For i = 1:
```
1 >> 1 = 0
1 & 1 = 1
ans[1] = ans[0] + 1 = 1
```
For i = 2:
```
2 >> 1 = 1
2 & 1 = 0
ans[2] = ans[1] + 0 = 1
```
For i = 3:
```
3 >> 1 = 1
3 & 1 = 1
ans[3] = ans[1] + 1 = 2
```
For i = 4:
```
4 >> 1 = 2
4 & 1 = 0
ans[4] = ans[2] + 0 = 1
```
For i = 5:
```
5 >> 1 = 2
5 & 1 = 1
ans[5] = ans[2] + 1 = 2
```
Final answer:
```
[0, 1, 1, 2, 1, 2]
```
Correctness
For every integer i > 0, the expression i >> 1 removes the last binary bit of i.
The expression i & 1 gives the value of that removed bit:
| Last bit | i & 1 |
|---|---|
| 0 | 0 |
| 1 | 1 |
Therefore, the number of 1 bits in i equals the number of 1 bits in i >> 1, plus 1 if the removed bit was 1.
The algorithm fills ans from smaller numbers to larger numbers. Since i >> 1 is always smaller than i for every i > 0, ans[i >> 1] has already been computed when we compute ans[i].
The base case ans[0] = 0 is correct because the binary representation of 0 contains no 1 bits.
Thus, every ans[i] is computed correctly, and the returned array is correct for all values from 0 to n.
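The ordering argument can also be checked at run time. In this sketch, every read of ans[i >> 1] asserts that the entry was already computed, and the final array is compared with a string-based count. The name count_bits_checked and the filled list are hypothetical, added only for illustration.

```python
def count_bits_checked(n: int) -> list[int]:
    """Hypothetical instrumented version: asserts the fill-order invariant."""
    ans = [0] * (n + 1)
    filled = [False] * (n + 1)
    filled[0] = True  # base case: ans[0] = 0
    for i in range(1, n + 1):
        parent = i >> 1
        # The entry we read is smaller than i and was finished earlier.
        assert parent < i and filled[parent]
        ans[i] = ans[parent] + (i & 1)
        filled[i] = True
    return ans


for n in range(200):
    assert count_bits_checked(n) == [bin(i).count("1") for i in range(n + 1)]
print("invariant and results hold for n < 200")
```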
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | We compute one value for each integer from 0 to n |
| Space | O(n) | The output array has length n + 1 |
The extra working space besides the returned array is O(1).
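A rough way to see the gap between the two approaches is to time them. This sketch uses the standard timeit module. The function names brute_force and dp are hypothetical stand-ins for the two solutions in this article, and the measured times depend on the machine and Python version, so no specific numbers are claimed here.

```python
import timeit


def brute_force(n: int) -> list[int]:
    """Direct approach: count the bits of each number from scratch."""
    ans = []
    for x in range(n + 1):
        count, value = 0, x
        while value > 0:
            count += value & 1
            value >>= 1
        ans.append(count)
    return ans


def dp(n: int) -> list[int]:
    """Recurrence approach: reuse ans[i >> 1]."""
    ans = [0] * (n + 1)
    for i in range(1, n + 1):
        ans[i] = ans[i >> 1] + (i & 1)
    return ans


n = 50_000
assert brute_force(n) == dp(n)  # both approaches must agree before timing
print("brute force:", timeit.timeit(lambda: brute_force(n), number=3))
print("dp:         ", timeit.timeit(lambda: dp(n), number=3))
```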
Implementation
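```python
class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)  # ans[0] = 0 is already the base case
        for i in range(1, n + 1):
            # Drop the last bit (i >> 1), then add it back if it was 1 (i & 1).
            ans[i] = ans[i >> 1] + (i & 1)
        return ans
```
Code Explanation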
Create the result array:
```python
ans = [0] * (n + 1)
```
We need length n + 1 because the result includes both 0 and n.
Then compute answers from 1 to n:
```python
for i in range(1, n + 1):
```
Use the recurrence:
```python
ans[i] = ans[i >> 1] + (i & 1)
```
Here:
| Expression | Meaning |
|---|---|
| i >> 1 | i without its last binary bit |
| i & 1 | The last binary bit of i |
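For non-negative integers, the two expressions in the table match ordinary integer arithmetic: i >> 1 equals i // 2, and i & 1 equals i % 2. This small sketch checks that and writes the same recurrence in arithmetic form; count_bits_arithmetic is a hypothetical name.

```python
def count_bits_arithmetic(n: int) -> list[int]:
    """Same recurrence as countBits, written with // and % instead of bit operators."""
    ans = [0] * (n + 1)
    for i in range(1, n + 1):
        # i // 2 drops the last binary digit; i % 2 is that digit.
        ans[i] = ans[i // 2] + i % 2
    return ans


# For non-negative i, the shift and mask match floor division and remainder by 2.
for i in range(1_000):
    assert i >> 1 == i // 2
    assert i & 1 == i % 2

print(count_bits_arithmetic(5))  # [0, 1, 1, 2, 1, 2]
```

The bitwise form states the intent (drop a bit, read a bit) more directly, which is why it is the usual way to write this recurrence.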
Finally return the array:
```python
return ans
```
Testing
```python
def run_tests():
    s = Solution()
    assert s.countBits(0) == [0]
    assert s.countBits(1) == [0, 1]
    assert s.countBits(2) == [0, 1, 1]
    assert s.countBits(5) == [0, 1, 1, 2, 1, 2]
    assert s.countBits(8) == [0, 1, 1, 2, 1, 2, 2, 3, 1]
    print("all tests passed")


run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| 0 | Smallest input |
| 1 | First positive number |
| 2 | First power of two above 1 |
| 5 | Standard sample |
| 8 | Checks another power of two |
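The fixed cases above can be complemented by a randomized check against a simple reference. This is a minimal sketch, assuming the Solution class from the Implementation section (repeated here so the block runs on its own) and Python's random module with a fixed seed so that any failure is reproducible. The helper name run_random_tests and its default parameters are hypothetical.

```python
import random


class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)
        for i in range(1, n + 1):
            ans[i] = ans[i >> 1] + (i & 1)
        return ans


def run_random_tests(trials: int = 100, max_n: int = 10_000, seed: int = 0) -> None:
    """Hypothetical helper: compare countBits with bin(i).count('1') on random n."""
    rng = random.Random(seed)  # fixed seed keeps any failure reproducible
    s = Solution()
    for _ in range(trials):
        n = rng.randint(0, max_n)
        expected = [bin(i).count("1") for i in range(n + 1)]
        assert s.countBits(n) == expected, f"mismatch for n = {n}"
    print("random tests passed")


run_random_tests()
```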