
LeetCode 338: Counting Bits

A clear explanation of Counting Bits using dynamic programming and bit manipulation.

Problem Restatement

We are given an integer n.

Return an array ans of length n + 1.

For every integer i with 0 <= i <= n, ans[i] must be the number of 1 bits in the binary representation of i.

For example:

n = 5

The binary forms are:

0 -> 0      has 0 ones
1 -> 1      has 1 one
2 -> 10     has 1 one
3 -> 11     has 2 ones
4 -> 100    has 1 one
5 -> 101    has 2 ones

So the answer is:

[0, 1, 1, 2, 1, 2]
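A quick way to double-check this expected output is Python's built-in bin(), which returns strings like '0b101'. The sketch below assumes only that built-in; it is a reference check, not the intended solution.

```python
# Reference answer for the example, using Python's built-in bin().
# bin(5) returns '0b101'; the '0b' prefix contains no '1' characters,
# so counting '1' in the whole string gives the number of set bits.
n = 5
expected = [bin(i).count("1") for i in range(n + 1)]
print(expected)  # [0, 1, 1, 2, 1, 2]
```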

Input and Output

Item     Meaning
Input    Integer n
Output   Array ans of length n + 1
ans[i]   Number of 1 bits in the binary form of i
Range    All integers from 0 to n

Function shape:

def countBits(n: int) -> list[int]:
    ...

Examples

Example 1:

Input: n = 2
Output: [0, 1, 1]

Explanation:

0 -> 0   has 0 ones
1 -> 1   has 1 one
2 -> 10  has 1 one

Example 2:

Input: n = 5
Output: [0, 1, 1, 2, 1, 2]

Explanation:

0 -> 0
1 -> 1
2 -> 10
3 -> 11
4 -> 100
5 -> 101

The counts are:

0, 1, 1, 2, 1, 2

First Thought: Count Bits One Number at a Time

A direct solution is to loop from 0 to n.

For each number, repeatedly check its last bit.

class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = []

        for x in range(n + 1):
            count = 0
            value = x

            # Strip bits off the right end, adding 1 whenever the last bit is set.
            while value > 0:
                count += value & 1
                value >>= 1

            ans.append(count)

        return ans

This works, but it repeats work.

For every number, we scan all of its binary digits again from scratch. A number x has about log2(x) + 1 binary digits, so for large n the total cost is about:

O(n log n)
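To make that cost concrete, the sketch below counts how many times the inner while loop runs. The helper name inner_loop_steps is hypothetical, introduced only for this illustration; it relies on the fact that the loop runs exactly x.bit_length() times for each x.

```python
def inner_loop_steps(n: int) -> int:
    """Count how many times the while-loop body runs for all x in 0..n (hypothetical helper)."""
    steps = 0
    for x in range(n + 1):
        value = x
        while value > 0:
            steps += 1
            value >>= 1
    return steps


# Each x > 0 needs exactly x.bit_length() iterations (about log2(x) + 1).
for n in (5, 100, 10_000):
    steps = inner_loop_steps(n)
    assert steps == sum(x.bit_length() for x in range(n + 1))
    print(n, steps)
```

For n = 5 the total is 0 + 1 + 2 + 2 + 3 + 3 = 11 iterations; the total grows roughly like n log n.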

We can do better by reusing previous answers.

Key Insight

Every positive integer i can be related to a smaller number: the one obtained by removing the last binary bit of i.

Removing the last bit is the same as:

i >> 1

The last bit is:

i & 1

So the number of 1 bits in i is:

bits(i) = bits(i >> 1) + (i & 1)

This works because i >> 1 removes the least significant bit, and i & 1 tells us whether the removed bit was 1.
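The two facts behind the recurrence can be checked directly for a range of values. This is a small sanity check, assuming a hypothetical popcount helper built on bin(); it tests the identity empirically rather than proving it.

```python
def popcount(x: int) -> int:
    """Hypothetical helper: number of 1 bits in x, via its binary string."""
    return bin(x).count("1")


# Check, for i = 1..1024:
#   1. i == 2 * (i >> 1) + (i & 1): the shift drops exactly the last bit.
#   2. popcount(i) == popcount(i >> 1) + (i & 1): the recurrence itself.
for i in range(1, 1025):
    assert i == 2 * (i >> 1) + (i & 1)
    assert popcount(i) == popcount(i >> 1) + (i & 1)

print("recurrence holds for 1..1024")
```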

Example:

i = 5
binary = 101

i >> 1 = 2
binary of 2 = 10

i & 1 = 1

So:

bits(5) = bits(2) + 1
        = 1 + 1
        = 2
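The same recurrence can also be evaluated top-down, recursing on i >> 1 until reaching the base case 0. The sketch below is a memoized variant using functools.lru_cache; the function name bits is chosen only to match the notation above, and the bottom-up array used later in this article avoids recursion entirely.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def bits(i: int) -> int:
    """Number of 1 bits in i, using bits(i) = bits(i >> 1) + (i & 1)."""
    if i == 0:
        return 0  # base case: 0 has no 1 bits
    return bits(i >> 1) + (i & 1)


print(bits(5))                      # 2
print([bits(i) for i in range(6)])  # [0, 1, 1, 2, 1, 2]
```

The recursion depth is only about log2(i), since each call halves i.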

Algorithm

Create an array ans of length n + 1.

Set:

ans[0] = 0

Then for every i from 1 to n:

ans[i] = ans[i >> 1] + (i & 1)

Return ans.

Walkthrough

Use:

n = 5

Start:

ans = [0, 0, 0, 0, 0, 0]

For i = 1:

1 >> 1 = 0
1 & 1 = 1
ans[1] = ans[0] + 1 = 1

For i = 2:

2 >> 1 = 1
2 & 1 = 0
ans[2] = ans[1] + 0 = 1

For i = 3:

3 >> 1 = 1
3 & 1 = 1
ans[3] = ans[1] + 1 = 2

For i = 4:

4 >> 1 = 2
4 & 1 = 0
ans[4] = ans[2] + 0 = 1

For i = 5:

5 >> 1 = 2
5 & 1 = 1
ans[5] = ans[2] + 1 = 2

Final answer:

[0, 1, 1, 2, 1, 2]
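The walkthrough above can be reproduced mechanically. This sketch uses a hypothetical count_bits_traced function that performs the same bottom-up fill and prints one line per step, so each printed line corresponds to one "For i = ..." block.

```python
def count_bits_traced(n: int) -> list[int]:
    """Bottom-up fill of ans, printing each step of the walkthrough (hypothetical helper)."""
    ans = [0] * (n + 1)
    for i in range(1, n + 1):
        parent, last = i >> 1, i & 1
        ans[i] = ans[parent] + last
        print(f"i={i}: {i} >> 1 = {parent}, {i} & 1 = {last}, "
              f"ans[{i}] = ans[{parent}] + {last} = {ans[i]}")
    return ans


print(count_bits_traced(5))
```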

Correctness

For every integer i > 0, the expression i >> 1 removes the last binary bit of i.

The expression i & 1 gives the value of that removed bit:

Last bit   i & 1
0          0
1          1

Therefore, the number of 1 bits in i equals the number of 1 bits in i >> 1, plus 1 if the removed bit was 1.

The algorithm fills ans from smaller numbers to larger numbers. Since i >> 1 is always smaller than i for every i > 0, ans[i >> 1] has already been computed when we compute ans[i].

The base case ans[0] = 0 is correct because the binary representation of 0 contains no 1 bits.

Thus, every ans[i] is computed correctly, and the returned array is correct for all values from 0 to n.

Complexity

Metric   Value   Why
Time     O(n)    We compute one value for each integer from 0 to n
Space    O(n)    The output array has length n + 1

The extra working space besides the returned array is O(1).

Implementation

class Solution:
    def countBits(self, n: int) -> list[int]:
        ans = [0] * (n + 1)

        for i in range(1, n + 1):
            ans[i] = ans[i >> 1] + (i & 1)

        return ans

Code Explanation

Create the result array:

ans = [0] * (n + 1)

We need length n + 1 because the result includes both 0 and n.

Then compute answers from 1 to n:

for i in range(1, n + 1):

Use the recurrence:

ans[i] = ans[i >> 1] + (i & 1)

Here:

Expression   Meaning
i >> 1       i without its last binary bit
i & 1        The last binary bit of i

Finally return the array:

return ans
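To see the two expressions from the table side by side, the sketch below prints, for i = 0..8, the binary form of i, the values of i >> 1 and i & 1, and the resulting ans[i]. It uses only f-string format specifiers (the b type renders an integer in binary).

```python
# For i = 0..8, show the binary form of i, the two pieces of the
# recurrence, and the resulting bit count.
n = 8
ans = [0] * (n + 1)

print(f"{'i':>2} {'binary':>6} {'i>>1':>5} {'i&1':>4} {'ans[i]':>7}")
for i in range(n + 1):
    if i > 0:
        ans[i] = ans[i >> 1] + (i & 1)
    print(f"{i:>2} {i:>6b} {i >> 1:>5} {i & 1:>4} {ans[i]:>7}")
```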

Testing

def run_tests():
    s = Solution()

    assert s.countBits(0) == [0]
    assert s.countBits(1) == [0, 1]
    assert s.countBits(2) == [0, 1, 1]
    assert s.countBits(5) == [0, 1, 1, 2, 1, 2]
    assert s.countBits(8) == [0, 1, 1, 2, 1, 2, 2, 3, 1]

    print("all tests passed")

run_tests()

Test meaning:

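Test   Why
0      Smallest input
1      First positive number
2      First power of two greater than 1 (binary 10, last bit 0)
5      Standard sample
8      Another power of two; the range also covers 7 = 111, the first value with three 1 bits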
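Beyond the fixed cases, one option is to cross-check the recurrence against an independent reference for many inputs. The sketch below assumes two hypothetical standalone functions: count_bits_dp repeats the recurrence from the Solution class, and count_bits_naive counts '1' characters in each bin() string.

```python
def count_bits_dp(n: int) -> list[int]:
    # Same recurrence as the Solution class, as a standalone function.
    ans = [0] * (n + 1)
    for i in range(1, n + 1):
        ans[i] = ans[i >> 1] + (i & 1)
    return ans


def count_bits_naive(n: int) -> list[int]:
    # Independent reference: count '1' characters in each binary string.
    return [bin(i).count("1") for i in range(n + 1)]


for n in range(513):
    assert count_bits_dp(n) == count_bits_naive(n), n

print("dp matches naive for n = 0..512")
```

Comparing whole arrays checks both the values and the length n + 1.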