LeetCode 279: Perfect Squares

A dynamic programming solution for finding the least number of perfect square numbers that sum to n.

Problem Restatement

We are given an integer n.

We need to return the least number of perfect square numbers whose sum is exactly n.

A perfect square is a number like:

1, 4, 9, 16, 25

because:

1 = 1 * 1
4 = 2 * 2
9 = 3 * 3
16 = 4 * 4
25 = 5 * 5
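The definition above can be checked mechanically. The following is a minimal sketch; the helper name is_perfect_square is hypothetical (not part of the problem), and it assumes only positive squares count, which matches the examples.

```python
import math


def is_perfect_square(value: int) -> bool:
    """Return True if value equals k * k for some integer k >= 1."""
    if value < 1:
        return False  # assumption: 0 and negatives are not usable squares
    root = math.isqrt(value)  # exact integer square root, no float rounding
    return root * root == value


print([v for v in range(1, 30) if is_perfect_square(v)])
# prints [1, 4, 9, 16, 25]
```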

For example, for n = 12, one valid sum is:

12 = 4 + 4 + 4

12 is not a perfect square, and no two perfect squares sum to 12, so the answer is 3.

The official statement gives the examples 12 = 4 + 4 + 4 and 13 = 4 + 9. The constraint commonly listed for this problem is 1 <= n <= 10^4.

Input and Output

| Item | Meaning |
| --- | --- |
| Input | An integer n |
| Output | The minimum count of perfect squares whose sum is n |
| Constraint | 1 <= n <= 10000 |
| Reuse | A square may be used more than once |

Function shape:

class Solution:
    def numSquares(self, n: int) -> int:
        ...

Examples

For:

n = 12

We can write:

12 = 4 + 4 + 4

So the answer is:

3

For:

n = 13

We can write:

13 = 4 + 9

So the answer is:

2

For:

n = 1

We can write:

1 = 1

So the answer is:

1

First Thought: Brute Force

The brute force idea is to try every combination of perfect squares.

For n = 12, the usable square numbers are:

1, 4, 9

We could try subtracting one square at a time:

12 - 1
12 - 4
12 - 9

Then recursively solve the remaining number.

For example:

numSquares(12)
= 1 + min(numSquares(11), numSquares(8), numSquares(3))

This gives the right recurrence, but plain recursion repeats the same subproblems many times.

For example, numSquares(8) may be reached from several different paths.

We need to store answers for smaller numbers.
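One way to store those answers is to keep the recursion and add a cache. This is a sketch of that idea, not the approach developed in the rest of this article; the function name num_squares_memo and the inner helper solve are hypothetical.

```python
from functools import lru_cache


def num_squares_memo(n: int) -> int:
    """Top-down version of the recurrence, caching each subproblem once."""

    @lru_cache(maxsize=None)
    def solve(remaining: int) -> int:
        if remaining == 0:
            return 0  # zero needs zero squares
        best = remaining  # upper bound: use `remaining` copies of 1
        root = 1
        while root * root <= remaining:
            # Use root * root as one square, then solve what is left.
            best = min(best, 1 + solve(remaining - root * root))
            root += 1
        return best

    return solve(n)


print(num_squares_memo(12))  # prints 3
print(num_squares_memo(13))  # prints 2
```

The recursion can go about n levels deep (by repeatedly subtracting 1), so for n near 10^4 this sketch would likely exceed Python's default recursion limit. A bottom-up table avoids that problem.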

Key Insight

Let:

dp[x]

mean:

the least number of perfect squares that sum to x

The base case is:

dp[0] = 0

Zero needs zero square numbers.

For every number x from 1 to n, we try every square s where s <= x.

If we use square s as the last square in the sum, then the remaining amount is:

x - s

So one candidate answer is:

dp[x - s] + 1

We choose the minimum over all valid squares.
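Written as a formula, with the square s expressed as k^2 (the index k is notation introduced here for illustration), the recurrence reads:

```latex
dp[0] = 0, \qquad
dp[x] = 1 + \min_{\substack{k \ge 1 \\ k^2 \le x}} dp[x - k^2]
\quad \text{for } 1 \le x \le n

% Worked instance for x = 12, where the usable squares are 1, 4, 9:
dp[12] = 1 + \min\bigl(dp[11],\ dp[8],\ dp[3]\bigr)
       = 1 + \min(3,\ 2,\ 3)
       = 3
```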

Algorithm

First, compute all perfect squares up to n.

squares = []
i = 1

while i * i <= n:
    squares.append(i * i)
    i += 1

Then create a DP array:

dp = [0] + [float("inf")] * n

Now fill it from left to right.

For each value x from 1 to n:

  1. Try each square s in increasing order.
  2. Stop if s > x, because every later square is larger still.
  3. Otherwise, update:
     dp[x] = min(dp[x], dp[x - s] + 1)

Finally, return:

dp[n]
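To see the steps above in action, the following sketch returns the whole table instead of only dp[n]. The function name build_dp_table is hypothetical; the logic follows the steps just described.

```python
def build_dp_table(n: int) -> list:
    """Return dp[0..n], where dp[x] is the least number of squares summing to x."""
    squares = []
    i = 1
    while i * i <= n:
        squares.append(i * i)
        i += 1

    dp = [0] + [float("inf")] * n
    for x in range(1, n + 1):
        for square in squares:
            if square > x:
                break  # squares are in increasing order, so none of the rest fit
            dp[x] = min(dp[x], dp[x - square] + 1)
    return dp


print(build_dp_table(12))
# prints [0, 1, 2, 3, 1, 2, 3, 4, 2, 1, 2, 3, 3]
```

Reading the output, dp[8] is 2 (4 + 4), so dp[12] can be reached as dp[8] + 1 = 3.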

Correctness

We prove that dp[x] stores the least number of perfect squares needed to sum to x.

The base case is dp[0] = 0, which is correct because zero needs no numbers.

Now consider any x > 0.

Every valid representation of x as a sum of perfect squares has some last square s. After removing that last square, the remaining sum is x - s.

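By the time we compute dp[x], the value dp[x - s] has already been computed because x - s < x. By the induction hypothesis, it is the least number of perfect squares that sum to x - s.

So the fewest squares that any representation ending in s can use is: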

dp[x - s] + 1

The algorithm tries every possible perfect square s <= x, so it considers every possible choice for the last square.

It takes the minimum among these choices. Therefore dp[x] is the least possible number of perfect squares that sum to x.

By induction, this holds for every value from 0 to n. So dp[n] is the correct answer.
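The "last square" argument can be made concrete by recording which square achieved the minimum for each x and then walking backwards from n. This is an illustrative sketch; the function name squares_decomposition and the array last are hypothetical and are not needed to solve the problem itself.

```python
def squares_decomposition(n: int) -> list:
    """Return one shortest list of perfect squares whose sum is n."""
    squares = []
    i = 1
    while i * i <= n:
        squares.append(i * i)
        i += 1

    dp = [0] + [float("inf")] * n
    last = [0] * (n + 1)  # last[x]: final square of one optimal sum for x

    for x in range(1, n + 1):
        for square in squares:
            if square > x:
                break
            if dp[x - square] + 1 < dp[x]:
                dp[x] = dp[x - square] + 1
                last[x] = square

    # Peel off the recorded last square until nothing remains.
    parts = []
    remaining = n
    while remaining > 0:
        parts.append(last[remaining])
        remaining -= last[remaining]
    return parts


print(squares_decomposition(12))  # prints [4, 4, 4]
```

Other optimal decompositions may exist for a given n; which one is returned depends on the order in which squares are tried.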

Complexity

| Metric | Value | Why |
| --- | --- | --- |
| Time | O(n sqrt(n)) | For each number up to n, we try all squares up to that number |
| Space | O(n) | The DP array stores one answer for each value from 0 to n |

There are about sqrt(n) perfect squares up to n.

So the nested loops cost:

n * sqrt(n)
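As a rough sanity check on that bound, the sketch below runs the same two loops but only counts how many times the inner update would execute. The function name count_inner_iterations is hypothetical.

```python
import math


def count_inner_iterations(n: int) -> int:
    """Count how many dp updates the nested loops perform for a given n."""
    squares = []
    i = 1
    while i * i <= n:
        squares.append(i * i)
        i += 1

    updates = 0
    for x in range(1, n + 1):
        for square in squares:
            if square > x:
                break
            updates += 1  # one dp[x] = min(...) update would happen here
    return updates


n = 10_000
print(count_inner_iterations(n))  # prints 661750
print(n * math.isqrt(n))          # prints 1000000
```

The actual count stays below n * sqrt(n) because small values of x have fewer squares that fit.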

Implementation

class Solution:
    def numSquares(self, n: int) -> int:
        squares = []
        i = 1

        while i * i <= n:
            squares.append(i * i)
            i += 1

        dp = [0] + [float("inf")] * n

        for x in range(1, n + 1):
            for square in squares:
                if square > x:
                    break

                dp[x] = min(dp[x], dp[x - square] + 1)

        return dp[n]

Code Explanation

We first build the list of usable perfect squares:

squares = []
i = 1

while i * i <= n:
    squares.append(i * i)
    i += 1

For n = 12, this gives:

[1, 4, 9]
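The same list can also be built in one expression. This is an alternative sketch, not the code used in the implementation above; it assumes Python 3.8 or later for math.isqrt, and the function name squares_up_to is hypothetical.

```python
import math


def squares_up_to(n: int) -> list:
    """Return all perfect squares k * k with 1 <= k * k <= n, in increasing order."""
    return [k * k for k in range(1, math.isqrt(n) + 1)]


print(squares_up_to(12))  # prints [1, 4, 9]
```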

Then we create the DP table:

dp = [0] + [float("inf")] * n

dp[0] is 0.

Every other value starts as infinity because we have not computed it yet.

Then we compute answers for all values from 1 to n:

for x in range(1, n + 1):

For each x, we try every square that can fit inside it:

for square in squares:
    if square > x:
        break

If we use square, the remaining amount is x - square.

So we update the best answer:

dp[x] = min(dp[x], dp[x - square] + 1)

At the end, dp[n] is the minimum number of perfect squares needed:

return dp[n]

Testing

def test_num_squares():
    s = Solution()

    assert s.numSquares(1) == 1
    assert s.numSquares(2) == 2
    assert s.numSquares(3) == 3
    assert s.numSquares(4) == 1
    assert s.numSquares(12) == 3
    assert s.numSquares(13) == 2
    assert s.numSquares(43) == 3
    assert s.numSquares(10000) == 1

    print("all tests passed")

test_num_squares()

Test meaning:

| Test | Why |
| --- | --- |
| n = 1 | Smallest input |
| n = 2 | Requires 1 + 1 |
| n = 3 | Requires 1 + 1 + 1 |
| n = 4 | Exact square |
| n = 12 | Standard example, 4 + 4 + 4 |
| n = 13 | Standard example, 4 + 9 |
| n = 43 | Uses several squares, for example 25 + 9 + 9 |
| n = 10000 | Large exact square |
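Beyond hand-picked cases, the DP can be cross-checked against an exhaustive search on small inputs. The sketch below is one possible way to do that; both function names are hypothetical, and num_squares_dp is a standalone copy of the same DP so the block runs on its own.

```python
from itertools import combinations_with_replacement


def num_squares_dp(n: int) -> int:
    """Bottom-up DP, same logic as the solution in this article."""
    squares = [k * k for k in range(1, n + 1) if k * k <= n]
    dp = [0] + [float("inf")] * n
    for x in range(1, n + 1):
        for square in squares:
            if square > x:
                break
            dp[x] = min(dp[x], dp[x - square] + 1)
    return dp[n]


def num_squares_brute(n: int) -> int:
    """Try every multiset of squares, smallest size first. Only for small n."""
    squares = [k * k for k in range(1, n + 1) if k * k <= n]
    count = 1
    while True:
        # Terminates by count == n at the latest, since n copies of 1 sum to n.
        for combo in combinations_with_replacement(squares, count):
            if sum(combo) == n:
                return count
        count += 1


for n in range(1, 61):
    assert num_squares_dp(n) == num_squares_brute(n), n

print("dp matches brute force for n = 1..60")
```

The brute force grows quickly with n, so the range is kept small on purpose.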