
LeetCode 996: Number of Squareful Arrays

A clear explanation of counting unique permutations where every adjacent pair sums to a perfect square using backtracking.

Problem Restatement

We are given an integer array nums.

An array is squareful if the sum of every pair of adjacent elements is a perfect square.

We need to return the number of unique permutations of nums that are squareful.

For example, if two adjacent values are 1 and 8, then:

1 + 8 = 9

Since 9 is a perfect square, this adjacent pair is valid.
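The definition can be checked directly. This is a minimal sketch that assumes non-negative integers, as in the problem (`math.isqrt` raises `ValueError` on negatives). The helper names `is_perfect_square` and `is_squareful` are hypothetical:

```python
from math import isqrt


def is_perfect_square(x: int) -> bool:
    # isqrt returns floor(sqrt(x)) exactly for non-negative integers,
    # so squaring it back avoids floating-point rounding problems.
    root = isqrt(x)
    return root * root == x


def is_squareful(arr: list[int]) -> bool:
    # Check every adjacent pair (arr[i], arr[i + 1]).
    return all(is_perfect_square(a + b) for a, b in zip(arr, arr[1:]))


print(is_squareful([1, 8, 17]))  # True: 9 and 25 are squares
print(is_squareful([1, 17, 8]))  # False: 18 is not a square
```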

The array can contain duplicate values, so we must count unique permutations, not index-distinct permutations.
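A quick way to see the difference is `itertools.permutations`, which works on positions. It therefore treats equal values at different indices as distinct:

```python
from itertools import permutations

nums = [2, 2, 2]

# permutations() reorders positions, so three equal values still yield 3! tuples.
index_distinct = list(permutations(nums))
# Collapsing into a set keeps only distinct value sequences.
value_distinct = set(index_distinct)

print(len(index_distinct))  # 6
print(len(value_distinct))  # 1
```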

The official statement makes the duplicate rule precise: two permutations perm1 and perm2 count as different only if perm1[i] != perm2[i] for some index i.

Input and Output

| Item | Meaning |
| --- | --- |
| Input | Integer array nums |
| Output | Number of unique squareful permutations |
| Valid adjacent pair | Sum is a perfect square |
| Duplicate handling | Count equal-value permutations once |

Function shape:

class Solution:
    def numSquarefulPerms(self, nums: list[int]) -> int:
        ...

Examples

Example 1:

nums = [1, 17, 8]

Valid permutations:

[1, 8, 17]
[17, 8, 1]

Why?

1 + 8 = 9
8 + 17 = 25

Both 9 and 25 are perfect squares.

Answer:

2

Example 2:

nums = [2, 2, 2]

There is only one unique permutation:

[2, 2, 2]

Each adjacent pair sums to:

2 + 2 = 4

Since 4 is a perfect square, the permutation is valid.

Answer:

1

First Thought: Generate Every Permutation

A direct solution is to generate every permutation, then check whether it is squareful.

from itertools import permutations
from math import isqrt

class Solution:
    def numSquarefulPerms(self, nums: list[int]) -> int:
        seen = set()
        answer = 0

        for perm in permutations(nums):
            if perm in seen:
                continue

            seen.add(perm)

            valid = True
            for i in range(1, len(perm)):
                total = perm[i - 1] + perm[i]
                root = isqrt(total)

                if root * root != total:
                    valid = False
                    break

            if valid:
                answer += 1

        return answer

This follows the definition, but it is too slow because permutations grow factorially.

Problem With Plain Permutations

If n = len(nums), there can be up to:

n!

permutations.

Also, duplicate values create repeated permutations that represent the same array.
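The arithmetic below shows how large these numbers get. It assumes an input of 12 elements, which is the length limit stated in the LeetCode constraints. Because of duplicates, the number of distinct value sequences is the multinomial coefficient n! / (c1! · c2! · …), where each ci is the count of one distinct value. The helper name `distinct_arrangements` is hypothetical:

```python
from collections import Counter
from math import factorial, prod


def distinct_arrangements(nums: list[int]) -> int:
    # Multinomial coefficient: n! divided by c! for each distinct value's count c.
    counts = Counter(nums)
    return factorial(len(nums)) // prod(factorial(c) for c in counts.values())


print(factorial(12))                        # 479001600 raw orderings for n = 12
print(distinct_arrangements([1, 1, 8, 8]))  # 24 // (2! * 2!) = 6
print(distinct_arrangements([2, 2, 2]))     # 1
```

Either number still has to be filtered for the square-sum condition. That filtering is what the search below prunes early.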

We need to prune invalid paths early and avoid duplicate arrangements.

Key Insight

Think of each distinct number as a graph node.

There is an edge between two values x and y if:

x + y

is a perfect square.

Then the problem becomes:

Count the walks through this graph that contain n vertices and use each value exactly as many times as it appears in nums.

Because duplicate values are handled by counts, we avoid counting the same value arrangement more than once.

Algorithm

Build a frequency map of values in nums.

Then build a graph among distinct values:

  1. For every pair of distinct values x and y, add an edge if x + y is a perfect square.
  2. Also allow a self-edge from x to x if x + x is a perfect square and there are at least two copies of x.
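The two graph-building steps can be sketched as follows. The function name `build_square_graph` is hypothetical. Unlike the implementation later in the article, this version applies the two-copies rule for self-edges while building the graph. The later code adds the self-edge whenever 2x is a square and relies on the count check in DFS to reject a second use of x.

```python
from collections import Counter
from math import isqrt


def build_square_graph(nums: list[int]) -> dict[int, list[int]]:
    count = Counter(nums)
    values = list(count)  # distinct values, in first-occurrence order
    graph: dict[int, list[int]] = {x: [] for x in values}

    for x in values:
        for y in values:
            total = x + y
            root = isqrt(total)
            if root * root != total:
                continue  # x and y can never be adjacent
            # A self-edge x -> x is only usable with at least two copies of x.
            if x == y and count[x] < 2:
                continue
            graph[x].append(y)

    return graph


print(build_square_graph([1, 17, 8]))
# {1: [8], 17: [8], 8: [1, 17]}
```

Here 8 + 8 = 16 is a square, but the single copy of 8 means no self-edge is added.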

Then use backtracking.

The DFS state is:

dfs(prev, used)

where prev is the last value placed (None before the first choice) and used is the number of positions already filled.

At each step:

  1. If no values remain, count one valid permutation.
  2. Try each next value that still has positive frequency.
  3. If this is the first value, it can be anything.
  4. Otherwise, it must be adjacent to the previous value in the graph.
  5. Decrease its count, recurse, then restore its count.
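To make these steps concrete, here is a variant that records each completed path instead of only counting it. The function name `list_squareful` is hypothetical. The sketch checks the square-sum condition inline rather than through a prebuilt graph, which gives the same result.

```python
from collections import Counter
from math import isqrt


def list_squareful(nums: list[int]) -> list[list[int]]:
    count = Counter(nums)
    n = len(nums)
    path: list[int] = []
    results: list[list[int]] = []

    def is_square(x: int) -> bool:
        root = isqrt(x)
        return root * root == x

    def dfs() -> None:
        if len(path) == n:              # step 1: every position filled
            results.append(path.copy())
            return
        for value in count:             # step 2: distinct values only
            if count[value] == 0:
                continue
            # steps 3-4: the first value is free, later ones need a square sum
            if path and not is_square(path[-1] + value):
                continue
            count[value] -= 1           # step 5: choose, recurse, restore
            path.append(value)
            dfs()
            path.pop()
            count[value] += 1

    dfs()
    return results


print(list_squareful([1, 1, 8, 8]))
# [[1, 8, 1, 8], [1, 8, 8, 1], [8, 1, 8, 1]]
```

Because the loop runs over distinct values, the two 1s are never swapped with each other. Each arrangement therefore appears once.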

Correctness

The graph contains exactly the value pairs that can be adjacent in a squareful array, because an edge exists exactly when their sum is a perfect square.

The DFS builds permutations from left to right. It only appends a value if either the path is empty or the value forms a valid square-sum pair with the previous value. Therefore, every complete path counted by DFS is squareful.

The frequency map ensures that each value is used exactly as many times as it appears in nums. Since choices are made by value rather than by original index, duplicate values do not create duplicate permutations.

Conversely, any unique squareful permutation uses only values from nums with the correct frequencies. Each adjacent pair in that permutation sums to a perfect square, so every transition is present in the graph. DFS can choose those values in that order and will count that permutation.

Thus the algorithm counts exactly all unique squareful permutations.

Complexity

Let n = len(nums) and u be the number of distinct values.

| Metric | Value | Why |
| --- | --- | --- |
| Time | O(u^2 + number of valid search states) | Build the graph, then backtrack over valid arrangements |
| Space | O(u^2 + n) | Graph storage plus recursion depth |

The search is exponential in the worst case. However, two things keep it practical for the problem constraints (at most 12 elements): pruning by square-sum adjacency, and choosing by value rather than by index.
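A rough way to observe this is to count DFS calls. The sketch below is an instrumented version of the same value-based backtracking. The name `count_with_calls` is hypothetical, and the square-sum test is done inline instead of through a graph. For twelve zeros (0 + 0 = 0 is a perfect square), a brute-force enumeration would visit 12! = 479,001,600 orderings. Value-based DFS instead follows a single path of 13 calls.

```python
from collections import Counter
from math import isqrt
from typing import Optional


def count_with_calls(nums: list[int]) -> tuple[int, int]:
    count = Counter(nums)
    n = len(nums)
    calls = 0

    def dfs(prev: Optional[int], used: int) -> int:
        nonlocal calls
        calls += 1
        if used == n:
            return 1
        total = 0
        for value in count:
            if count[value] == 0:
                continue
            if prev is not None:
                s = prev + value
                r = isqrt(s)
                if r * r != s:
                    continue  # prune: this pair can never be adjacent
            count[value] -= 1
            total += dfs(value, used + 1)
            count[value] += 1
        return total

    answer = dfs(None, 0)
    return answer, calls


print(count_with_calls([0] * 12))  # (1, 13)
```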

Implementation

from collections import Counter, defaultdict
from math import isqrt

class Solution:
    def numSquarefulPerms(self, nums: list[int]) -> int:
        count = Counter(nums)
        values = list(count)

        graph = defaultdict(list)

        for x in values:
            for y in values:
                total = x + y
                root = isqrt(total)

                if root * root == total:
                    graph[x].append(y)

        n = len(nums)

        def dfs(prev: int | None, used: int) -> int:
            if used == n:
                return 1

            total = 0

            if prev is None:
                candidates = values
            else:
                candidates = graph[prev]

            for value in candidates:
                if count[value] == 0:
                    continue

                count[value] -= 1
                total += dfs(value, used + 1)
                count[value] += 1

            return total

        return dfs(None, 0)

Code Explanation

We count how many times each value appears:

count = Counter(nums)

The list values contains each distinct value once:

values = list(count)

Then we build adjacency between values:

for x in values:
    for y in values:

A pair is valid if the sum is a perfect square:

root = isqrt(total)

if root * root == total:
    graph[x].append(y)

The DFS builds the permutation one value at a time:

def dfs(prev: int | None, used: int) -> int:

When all positions are filled, we found one valid permutation:

if used == n:
    return 1

If this is the first position, any distinct value with remaining count may be chosen:

if prev is None:
    candidates = values

Otherwise, the next value must be adjacent to prev in the graph:

else:
    candidates = graph[prev]

We skip values already used up:

if count[value] == 0:
    continue

Then choose the value, recurse, and restore it:

count[value] -= 1
total += dfs(value, used + 1)
count[value] += 1

Testing

def run_tests():
    s = Solution()

    assert s.numSquarefulPerms([1, 17, 8]) == 2
    assert s.numSquarefulPerms([2, 2, 2]) == 1
    assert s.numSquarefulPerms([1, 1, 8, 8]) == 3
    assert s.numSquarefulPerms([1]) == 1
    assert s.numSquarefulPerms([1, 2, 3]) == 0

    print("all tests passed")

run_tests()

| Test | Expected | Why |
| --- | --- | --- |
| [1, 17, 8] | 2 | Two valid orders through 8 |
| [2, 2, 2] | 1 | Duplicate values count once |
| [1, 1, 8, 8] | 3 | Duplicate-aware counting |
| [1] | 1 | Single element has no invalid adjacent pair |
| [1, 2, 3] | 0 | No full squareful permutation exists |
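Beyond fixed cases, a randomized cross-check against the brute-force definition can catch counting mistakes on small inputs. This is a sketch. The function names, the value range 0 to 20, and the length limit of 6 are arbitrary choices made so that brute force stays fast.

```python
import random
from collections import Counter
from itertools import permutations
from math import isqrt
from typing import Optional


def is_square(x: int) -> bool:
    root = isqrt(x)
    return root * root == x


def brute_force(nums: list[int]) -> int:
    # Definition: distinct value sequences whose adjacent sums are all squares.
    return sum(
        all(is_square(a + b) for a, b in zip(p, p[1:]))
        for p in set(permutations(nums))
    )


def count_by_dfs(nums: list[int]) -> int:
    # Frequency-map backtracking over distinct values, as in the article's DFS.
    count = Counter(nums)
    n = len(nums)

    def dfs(prev: Optional[int], used: int) -> int:
        if used == n:
            return 1
        total = 0
        for value in count:
            if count[value] == 0:
                continue
            if prev is not None and not is_square(prev + value):
                continue
            count[value] -= 1
            total += dfs(value, used + 1)
            count[value] += 1
        return total

    return dfs(None, 0)


random.seed(0)
for _ in range(300):
    nums = [random.randint(0, 20) for _ in range(random.randint(1, 6))]
    assert brute_force(nums) == count_by_dfs(nums), nums
print("brute force and DFS agree")
```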