A clear explanation of counting unique permutations where every adjacent pair sums to a perfect square using backtracking.
Problem Restatement
We are given an integer array nums.
An array is squareful if the sum of every pair of adjacent elements is a perfect square.
We need to return the number of unique permutations of nums that are squareful.
For example, if two adjacent values are 1 and 8, then:
1 + 8 = 9
Since 9 is a perfect square, this adjacent pair is valid.
The array can contain duplicate values, so we must count unique permutations, not index-distinct permutations.
The official problem states that adjacent pairs must sum to a perfect square and asks for the count of distinct squareful permutations.
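The squareful condition comes down to one integer test per adjacent pair. Below is a minimal sketch, assuming non-negative integers. `is_perfect_square` and `is_squareful` are hypothetical helper names, not part of the required interface. `math.isqrt` keeps the check in exact integer arithmetic, so no floating-point rounding is involved.

```python
from math import isqrt


def is_perfect_square(value: int) -> bool:
    """Return True when value is a perfect square (0, 1, 4, 9, ...)."""
    if value < 0:
        return False  # isqrt raises ValueError on negative input
    root = isqrt(value)  # exact floor of the square root, no float rounding
    return root * root == value


def is_squareful(arr: list[int]) -> bool:
    """Return True when every adjacent pair in arr sums to a perfect square."""
    return all(is_perfect_square(a + b) for a, b in zip(arr, arr[1:]))


print(is_squareful([1, 8, 17]))  # True: 9 and 25 are squares
print(is_squareful([1, 17, 8]))  # False: 18 is not a square
```

A single-element array has no adjacent pairs, so `all` over an empty sequence returns True.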
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer array nums |
| Output | Number of unique squareful permutations |
| Valid adjacent pair | Sum is a perfect square |
| Duplicate handling | Count equal-value permutations once |
Function shape:
```python
def numSquarefulPerms(nums: list[int]) -> int:
    ...
```
Examples
Example 1:
nums = [1, 17, 8]
Valid permutations:
[1, 8, 17]
[17, 8, 1]
Why?
1 + 8 = 9
8 + 17 = 25
Both 9 and 25 are perfect squares. Since 1 + 17 = 18 is not a perfect square, 8 must sit between 1 and 17.
Answer:
2
Example 2:
nums = [2, 2, 2]
There is only one unique permutation:
[2, 2, 2]
Each adjacent pair sums to:
2 + 2 = 4
Since 4 is a perfect square, the permutation is valid.
Answer:
1
First Thought: Generate Every Permutation
A direct solution is to generate every permutation, then check whether it is squareful.
```python
from itertools import permutations
from math import isqrt

class Solution:
    def numSquarefulPerms(self, nums: list[int]) -> int:
        seen = set()
        answer = 0

        for perm in permutations(nums):
            if perm in seen:
                continue
            seen.add(perm)

            valid = True
            for i in range(1, len(perm)):
                total = perm[i - 1] + perm[i]
                root = isqrt(total)
                if root * root != total:
                    valid = False
                    break

            if valid:
                answer += 1

        return answer
```
This follows the definition, but it is too slow because the number of permutations grows factorially.
Problem With Plain Permutations
If n = len(nums), there can be up to n! permutations.
Also, duplicate values create repeated permutations that represent the same array.
We need to prune invalid paths early and avoid duplicate arrangements.
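The sketch below makes both problems concrete using only the standard library. `math.factorial` shows the growth in the number of orderings (n = 12 is chosen for illustration). `itertools.permutations` treats equal values at different indices as distinct, which produces the repeated arrangements.

```python
from itertools import permutations
from math import factorial

# Factorial growth: number of orderings of 12 distinct positions.
print(factorial(12))  # 479001600

# Duplicate values: index-distinct orderings collapse to far fewer arrays.
all_orders = list(permutations([2, 2, 2]))
print(len(all_orders))       # 6
print(len(set(all_orders)))  # 1

pair_orders = list(permutations([1, 1, 8, 8]))
print(len(pair_orders), len(set(pair_orders)))  # 24 6
```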
Key Insight
Think of each distinct number as a graph node.
There is an edge between two values x and y if x + y is a perfect square.
Then the problem becomes:
Count the walks in this graph that visit n vertices in total and use each value exactly as many times as it appears in nums.
Because duplicate values are handled by counts, we avoid counting the same value arrangement more than once.
Algorithm
Build a frequency map of values in nums.
Then build a graph among distinct values:
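- For every pair of distinct values x and y, add an edge if x + y is a perfect square.
- Also allow a self-edge from x to x if x + x is a perfect square and there are at least two copies of x.

The implementation below adds the self-edge whenever x + x is a perfect square and lets the remaining-count check enforce the two-copy requirement. Both approaches give the same result.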
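The bullets above can be followed literally: build adjacency lists over distinct values and add a self-edge only when at least two copies exist. This is a minimal sketch, assuming non-negative values (`math.isqrt` rejects negative arguments). `build_square_graph` is a hypothetical helper name. On Example 1 it shows that 8 + 8 = 16 is a square, yet 8 gets no self-edge because it appears only once.

```python
from collections import Counter
from math import isqrt


def is_perfect_square(value: int) -> bool:
    root = isqrt(value)
    return root * root == value


def build_square_graph(nums: list[int]) -> dict[int, list[int]]:
    """Adjacency lists over distinct values; self-edge only with >= 2 copies."""
    count = Counter(nums)
    graph: dict[int, list[int]] = {x: [] for x in count}
    for x in count:
        for y in count:
            if x == y and count[x] < 2:
                continue  # a single copy cannot sit next to itself
            if is_perfect_square(x + y):
                graph[x].append(y)
    return graph


print(build_square_graph([1, 17, 8]))    # {1: [8], 17: [8], 8: [1, 17]}
print(build_square_graph([1, 1, 8, 8]))  # {1: [8], 8: [1, 8]}
```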
Then use backtracking.
The DFS state is:
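dfs(prev, used)
Here prev is the last value placed (None before the first position) and used is the number of positions filled so far. The frequency map tracks how many copies of each value remain.
At each step: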
- If no values remain, count one valid permutation.
- Try each next value that still has positive frequency.
- If this is the first value, it can be anything.
- Otherwise, it must be adjacent to the previous value in the graph.
- Decrease its count, recurse, then restore its count.
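The same choose, recurse, restore loop can record arrangements instead of only counting them, which shows exactly what the DFS accepts. This sketch makes a few simplifications. `squareful_arrangements` is a hypothetical name. It checks the square-sum rule inline instead of through a prebuilt graph, and it sorts the distinct values only so the output order is stable.

```python
from collections import Counter
from math import isqrt


def squareful_arrangements(nums: list[int]) -> list[list[int]]:
    """Backtracking over distinct values; returns every unique squareful order."""
    count = Counter(nums)
    values = sorted(count)  # sorted only to make the output order stable

    def ok(a: int, b: int) -> bool:
        root = isqrt(a + b)
        return root * root == a + b

    results: list[list[int]] = []
    path: list[int] = []

    def dfs() -> None:
        if len(path) == len(nums):
            results.append(path.copy())  # every position is filled
            return
        for value in values:
            if count[value] == 0:
                continue  # no copies of this value remain
            if path and not ok(path[-1], value):
                continue  # would break the square-sum rule
            count[value] -= 1  # choose
            path.append(value)
            dfs()              # recurse
            path.pop()         # restore
            count[value] += 1

    dfs()
    return results


print(squareful_arrangements([1, 17, 8]))    # [[1, 8, 17], [17, 8, 1]]
print(squareful_arrangements([1, 1, 8, 8]))  # [[1, 8, 1, 8], [1, 8, 8, 1], [8, 1, 8, 1]]
```

The second call lists the three arrangements behind the expected answer of 3 for [1, 1, 8, 8]. Each appears once even though the input has duplicates.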
Correctness
The graph contains exactly the value pairs that can be adjacent in a squareful array, because an edge exists exactly when their sum is a perfect square.
The DFS builds permutations from left to right. It only appends a value if either the path is empty or the value forms a valid square-sum pair with the previous value. Therefore, every complete path counted by DFS is squareful.
The frequency map ensures that each value is used exactly as many times as it appears in nums. Since choices are made by value rather than by original index, duplicate values do not create duplicate permutations.
Conversely, any unique squareful permutation uses only values from nums with the correct frequencies. Each adjacent pair in that permutation sums to a perfect square, so every transition is present in the graph. DFS can choose those values in that order and will count that permutation.
Thus the algorithm counts exactly all unique squareful permutations.
Complexity
Let n = len(nums) and u be the number of distinct values.
| Metric | Value | Why |
|---|---|---|
| Time | O(u^2 + S·u), where S is the number of DFS calls | Build the graph in O(u^2), then each DFS call scans at most u candidates |
| Space | O(u^2 + n) | Graph storage plus recursion depth |
The search is exponential in the worst case. Pruning by square-sum adjacency and counting duplicates by value keep it practical for the official limit of at most 12 elements.
Implementation
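```python
from collections import Counter, defaultdict
from math import isqrt

class Solution:
    def numSquarefulPerms(self, nums: list[int]) -> int:
        count = Counter(nums)  # remaining copies of each value
        values = list(count)   # each distinct value once

        # graph[x] lists every value y such that x + y is a perfect square.
        # Self-edges are included; the count check in dfs prevents reusing
        # a single copy of x.
        graph = defaultdict(list)
        for x in values:
            for y in values:
                total = x + y
                root = isqrt(total)
                if root * root == total:
                    graph[x].append(y)

        n = len(nums)

        def dfs(prev: int | None, used: int) -> int:
            # All n positions are filled: one squareful permutation found.
            if used == n:
                return 1

            total = 0
            # The first position may hold any value; later positions must
            # be graph neighbors of the previous value.
            if prev is None:
                candidates = values
            else:
                candidates = graph[prev]

            for value in candidates:
                if count[value] == 0:
                    continue
                count[value] -= 1              # choose
                total += dfs(value, used + 1)  # recurse
                count[value] += 1              # restore

            return total

        return dfs(None, 0)
```
Code Explanation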
We count how many times each value appears:
```python
count = Counter(nums)
```
The list values contains each distinct value once:
```python
values = list(count)
```
Then we build adjacency between values:
```python
for x in values:
    for y in values:
```
A pair is valid if the sum is a perfect square:
```python
root = isqrt(total)
if root * root == total:
    graph[x].append(y)
```
The DFS builds the permutation one value at a time:
```python
def dfs(prev: int | None, used: int) -> int:
```
When all positions are filled, we found one valid permutation:
```python
if used == n:
    return 1
```
If this is the first position, any distinct value with remaining count may be chosen:
```python
if prev is None:
    candidates = values
```
Otherwise, the next value must be adjacent to prev in the graph:
```python
else:
    candidates = graph[prev]
```
We skip values already used up:
```python
if count[value] == 0:
    continue
```
Then choose the value, recurse, and restore it:
```python
count[value] -= 1
total += dfs(value, used + 1)
count[value] += 1
```
Testing
```python
def run_tests():
    s = Solution()

    assert s.numSquarefulPerms([1, 17, 8]) == 2
    assert s.numSquarefulPerms([2, 2, 2]) == 1
    assert s.numSquarefulPerms([1, 1, 8, 8]) == 3
    assert s.numSquarefulPerms([1]) == 1
    assert s.numSquarefulPerms([1, 2, 3]) == 0

    print("all tests passed")

run_tests()
```

| Test | Expected | Why |
|---|---|---|
| [1,17,8] | 2 | Two valid orders through 8 |
| [2,2,2] | 1 | Duplicate values count once |
| [1,1,8,8] | 3 | Duplicate-aware counting |
| [1] | 1 | Single element has no invalid adjacent pair |
| [1,2,3] | 0 | No full squareful permutation exists |