A dynamic programming solution for finding the least number of perfect square numbers that sum to n.
Problem Restatement
We are given an integer n.
We need to return the least number of perfect square numbers whose sum is exactly n.
A perfect square is a number like:

    1, 4, 9, 16, 25

because:

    1 = 1 * 1
    4 = 2 * 2
    9 = 3 * 3
    16 = 4 * 4
    25 = 5 * 5

For example, for n = 12, one valid sum is:

    12 = 4 + 4 + 4

No sum of fewer squares equals 12, so the answer is 3.
The official statement asks for the least number of perfect square numbers that sum to n. It gives examples such as 12 = 4 + 4 + 4 and 13 = 4 + 9. The constraint commonly listed for this problem is 1 <= n <= 10^4.
Input and Output
| Item | Meaning |
|---|---|
| Input | An integer n |
| Output | The minimum count of perfect squares whose sum is n |
| Constraint | 1 <= n <= 10000 |
| Reuse | A square may be used more than once |
Function shape:

    class Solution:
        def numSquares(self, n: int) -> int:
            ...

Examples
For:

    n = 12

We can write:

    12 = 4 + 4 + 4

So the answer is:

    3

For:

    n = 13

We can write:

    13 = 4 + 9

So the answer is:

    2

For:

    n = 1

We can write:

    1 = 1

So the answer is:

    1

First Thought: Brute Force
The brute force idea is to try every combination of perfect squares.
For n = 12, the usable square numbers are:

    1, 4, 9

We could try subtracting one square at a time:

    12 - 1
    12 - 4
    12 - 9

Then recursively solve the remaining number.
For example:

    numSquares(12)
    = 1 + min(numSquares(11), numSquares(8), numSquares(3))

This gives the right recurrence, but plain recursion repeats the same subproblems many times.
For example, numSquares(8) may be reached from several different paths.
We need to store answers for smaller numbers.
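One way to store those answers is to cache the recursive function. The sketch below is a minimal top-down version of the recurrence above. The name `num_squares_memo` and the use of `functools.lru_cache` are illustrative choices, not part of the original solution. It recurses roughly n levels deep, so under Python's default recursion limit it is only suitable for small inputs.

```python
from functools import lru_cache


def num_squares_memo(n: int) -> int:
    """Top-down version of the recurrence (hypothetical helper, small n only)."""

    @lru_cache(maxsize=None)
    def best(x: int) -> int:
        if x == 0:
            return 0  # zero needs zero squares
        result = x  # upper bound: x = 1 + 1 + ... + 1
        i = 1
        while i * i <= x:
            # Use i * i as one square, then solve the remainder.
            result = min(result, 1 + best(x - i * i))
            i += 1
        return result

    return best(n)


print(num_squares_memo(12))  # 3
print(num_squares_memo(13))  # 2
```

Each value of `x` is now solved once, which is the same idea the bottom-up table uses without the recursion depth.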
Key Insight
Let:

    dp[x]

mean:

    the least number of perfect squares that sum to x

The base case is:

    dp[0] = 0

Zero needs zero square numbers.
For every number x from 1 to n, we try every square s where s <= x.
If we use square s as the last square in the sum, then the remaining amount is:

    x - s

So one candidate answer is:

    dp[x - s] + 1

We choose the minimum over all valid squares.
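To see the recurrence on a concrete value, the sketch below fills a small table and then lists the candidates for x = 12. The helper names `build_dp` and `candidates` are made up for this illustration.

```python
def build_dp(limit: int) -> list:
    """Fill dp[0..limit] with the recurrence described above."""
    dp = [0] + [float("inf")] * limit
    for x in range(1, limit + 1):
        i = 1
        while i * i <= x:
            dp[x] = min(dp[x], dp[x - i * i] + 1)
            i += 1
    return dp


def candidates(dp: list, x: int) -> dict:
    """Map each square s <= x to the candidate value dp[x - s] + 1."""
    result = {}
    i = 1
    while i * i <= x:
        s = i * i
        result[s] = dp[x - s] + 1
        i += 1
    return result


dp = build_dp(12)
print(candidates(dp, 12))  # {1: 4, 4: 3, 9: 4}
print(min(candidates(dp, 12).values()))  # 3
```

The smallest candidate comes from s = 4, which matches 12 = 4 + 4 + 4.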
Algorithm
First, compute all perfect squares up to n.

    squares = []
    i = 1
    while i * i <= n:
        squares.append(i * i)
        i += 1

Then create a DP array:

    dp = [0] + [float("inf")] * n

Now fill it from left to right.
For each value x from 1 to n:
- Try each square s.
- Stop if s > x.
- Update:

      dp[x] = min(dp[x], dp[x - s] + 1)

Finally, return:

    dp[n]

Correctness
We prove that dp[x] stores the least number of perfect squares needed to sum to x.
The base case is dp[0] = 0, which is correct because zero needs no numbers.
Now consider any x > 0.
Every valid representation of x as a sum of perfect squares has some last square s. After removing that last square, the remaining sum is x - s.
By the time we compute dp[x], the value dp[x - s] has already been computed because x - s < x.
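By the induction hypothesis, dp[x - s] is the fewest squares that sum to x - s, so the best representation whose last square is s uses this many squares:

    dp[x - s] + 1

The algorithm tries every possible perfect square s <= x, so it considers every possible choice for the last square.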
It takes the minimum among these choices. Therefore dp[x] is the least possible number of perfect squares that sum to x.
By induction, this holds for every value from 0 to n. So dp[n] is the correct answer.
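The argument can also be sanity-checked empirically. The sketch below compares the DP against an exhaustive search over multisets of squares for small n. Both function names are hypothetical, and the exhaustive search is only practical for small inputs.

```python
from itertools import combinations_with_replacement


def num_squares_dp(n: int) -> int:
    """Bottom-up DP using the recurrence proved above."""
    dp = [0] + [float("inf")] * n
    for x in range(1, n + 1):
        i = 1
        while i * i <= x:
            dp[x] = min(dp[x], dp[x - i * i] + 1)
            i += 1
    return dp[n]


def num_squares_exhaustive(n: int) -> int:
    """Try multisets of squares of size 1, 2, 3, ... until one sums to n."""
    squares = [i * i for i in range(1, n + 1) if i * i <= n]
    for count in range(1, n + 1):
        for combo in combinations_with_replacement(squares, count):
            if sum(combo) == n:
                return count
    raise ValueError("unreachable for n >= 1: n ones always sum to n")


for n in range(1, 31):
    assert num_squares_dp(n) == num_squares_exhaustive(n)
print("dp matches exhaustive search for n = 1..30")
```

Agreement on small inputs is not a proof, but it would catch an off-by-one in the loop bounds or a wrong base case.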
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n sqrt(n)) | For each number up to n, we try all squares up to that number |
| Space | O(n) | The DP array stores one answer for each value from 0 to n |
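The Time row can be checked by counting how often the inner update would run. The sketch below does that and prints the count next to n * sqrt(n). `count_inner_iterations` is a hypothetical helper, and `math.isqrt` assumes Python 3.8 or later.

```python
import math


def count_inner_iterations(n: int) -> int:
    """Count how many times the dp update would run for a given n."""
    squares = [i * i for i in range(1, math.isqrt(n) + 1)]
    count = 0
    for x in range(1, n + 1):
        for square in squares:
            if square > x:
                break
            count += 1  # one dp[x] update would happen here
    return count


for n in (12, 100, 10000):
    # Columns: n, measured update count, n * floor(sqrt(n))
    print(n, count_inner_iterations(n), n * math.isqrt(n))
```

The measured count stays below n * floor(sqrt(n)) because small values of x have fewer usable squares than n does.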
There are about sqrt(n) perfect squares up to n.
So the nested loops cost:

    n * sqrt(n)

Implementation
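    class Solution:
        def numSquares(self, n: int) -> int:
            # Collect every perfect square that does not exceed n.
            squares = []
            i = 1
            while i * i <= n:
                squares.append(i * i)
                i += 1

            # dp[x] = least number of perfect squares that sum to x.
            dp = [0] + [float("inf")] * n

            for x in range(1, n + 1):
                for square in squares:
                    if square > x:
                        break  # squares are ascending, so the rest are too large
                    dp[x] = min(dp[x], dp[x - square] + 1)

            return dp[n]

Code Explanation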
We first build the list of usable perfect squares:

    squares = []
    i = 1
    while i * i <= n:
        squares.append(i * i)
        i += 1

For n = 12, this gives:

    [1, 4, 9]

Then we create the DP table:

    dp = [0] + [float("inf")] * n

dp[0] is 0.
Every other value starts as infinity because we have not computed it yet.
Then we compute answers for all values from 1 to n:

    for x in range(1, n + 1):

For each x, we try every square that can fit inside it:

    for square in squares:
        if square > x:
            break

If we use square, the remaining amount is x - square.
So we update the best answer:

    dp[x] = min(dp[x], dp[x - square] + 1)

At the end, dp[n] is the minimum number of perfect squares needed:

    return dp[n]

Testing
    def test_num_squares():
        s = Solution()
        assert s.numSquares(1) == 1
        assert s.numSquares(2) == 2
        assert s.numSquares(3) == 3
        assert s.numSquares(4) == 1
        assert s.numSquares(12) == 3
        assert s.numSquares(13) == 2
        assert s.numSquares(43) == 3
        assert s.numSquares(10000) == 1
        print("all tests passed")

    test_num_squares()

Test meaning:
| Test | Why |
|---|---|
| n = 1 | Smallest input |
| n = 2 | Requires 1 + 1 |
| n = 3 | Requires 1 + 1 + 1 |
| n = 4 | Exact square |
| n = 12 | Standard example, 4 + 4 + 4 |
| n = 13 | Standard example, 4 + 9 |
| n = 43 | Uses several squares, for example 25 + 9 + 9 |
| n = 10000 | Large exact square |