
LeetCode 425: Word Squares

A clear explanation of building all word squares using backtracking with prefix pruning.

Problem Restatement

We are given an array of unique strings:

words

All words have the same length.

Return all word squares that can be built from these words.

The same word may be used multiple times.

A word square is a list of words where the kth row and the kth column read the same string.

The answer may be returned in any order.

The LeetCode statement guarantees that all input words are unique and have the same length, and it allows the same word to appear more than once within a single word square. Its constraints include 1 <= words.length <= 1000 and 1 <= words[i].length <= 4.

Input and Output

| Item | Meaning |
|---|---|
| Input | A list of unique equal-length words |
| Output | All valid word squares |
| Word length | All words have the same length |
| Reuse rule | The same word may be used multiple times |
| Order | Output order does not matter |

Example function shape:

def wordSquares(words: list[str]) -> list[list[str]]:
    ...

Examples

Example 1:

words = ["area", "lead", "wall", "lady", "ball"]

One valid output is:

[
    ["ball", "area", "lead", "lady"],
    ["wall", "area", "lead", "lady"],
]

The first square is:

b a l l
a r e a
l e a d
l a d y

Rows:

ball
area
lead
lady

Columns:

ball
area
lead
lady

They match.
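The definition (row k reads the same as column k) can be checked directly. This is a minimal sketch; `is_word_square` is a hypothetical helper name, not part of the LeetCode interface.

```python
def is_word_square(square: list[str]) -> bool:
    """Hypothetical helper: True if row k equals column k for every k."""
    n = len(square)
    # Every row must have exactly n characters for the grid to be n x n.
    if any(len(row) != n for row in square):
        return False
    for k in range(n):
        # Column k is the k-th character of each row, read top to bottom.
        column = "".join(row[k] for row in square)
        if square[k] != column:
            return False
    return True


print(is_word_square(["ball", "area", "lead", "lady"]))  # True
print(is_word_square(["ball", "area", "lead", "wall"]))  # False: column 0 is "balw"
```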

Example 2:

words = ["abat", "baba", "atan", "atal"]

One valid output is:

[
    ["baba", "abat", "baba", "atan"],
    ["baba", "abat", "baba", "atal"],
]

These are the standard examples for this problem.

First Thought: Try Every Arrangement

If each word has length n, a word square needs exactly n words.

A brute force method would try every sequence of n words and check whether it forms a square.

With m words, that can be:

m^n

candidates, because words may be reused.

This grows quickly: with m = 1000 words of length n = 4, there are 1000^4 = 10^12 sequences.

We need to prune bad choices early.
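For reference, the brute-force method can be sketched with `itertools.product`. It is only practical for tiny inputs, but it makes a handy oracle for testing faster versions. The function name `word_squares_brute_force` is hypothetical.

```python
from itertools import product


def word_squares_brute_force(words: list[str]) -> list[list[str]]:
    """Hypothetical reference: try every length-n sequence (with reuse)."""
    n = len(words[0])
    result = []
    # product(words, repeat=n) yields all m**n ordered sequences of words.
    for rows in product(words, repeat=n):
        # Keep the sequence only if row k equals column k for every k.
        if all(rows[k] == "".join(row[k] for row in rows) for k in range(n)):
            result.append(list(rows))
    return result


print(word_squares_brute_force(["area", "lead", "wall", "lady", "ball"]))
# [['wall', 'area', 'lead', 'lady'], ['ball', 'area', 'lead', 'lady']]
```

For Example 1 this inspects 5^4 = 625 sequences to find 2 squares.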

Key Insight

Build the square row by row.

Suppose we already chose these rows:

wall
area

Now we are choosing the third row, index 2.

The third row must agree with the third column, and the first two characters of that column are already fixed by the rows above.

Look at column 2 from previous rows:

wall[2] = l
area[2] = e

So the next word must start with:

"le"

Only words with prefix "le" can work.

This is the main pruning rule.

At row r, the next word must have prefix:

square[0][r] + square[1][r] + ... + square[r - 1][r]

If no word has that prefix, we stop exploring that branch immediately.
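The prefix rule translates almost directly into code. This sketch assumes `square` holds the rows chosen so far and that fewer than n rows have been placed; `required_prefix` is a hypothetical name.

```python
def required_prefix(square: list[str]) -> str:
    """Hypothetical helper: prefix forced on the next row by the rows so far.

    Assumes len(square) < n, so index r is valid in every existing row.
    """
    r = len(square)  # index of the row we are about to fill
    # Column r, read top to bottom over the rows already chosen.
    return "".join(row[r] for row in square)


print(required_prefix(["wall", "area"]))          # le
print(required_prefix(["wall", "area", "lead"]))  # lad
print(repr(required_prefix([])))                  # '' (row 0 is unconstrained)
```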

Prefix Lookup

We need to quickly find all words starting with a given prefix.

A simple way is to build a dictionary:

prefix -> list of words with that prefix

For each word, insert every prefix.

For example, for:

"wall"

we store:

""     -> "wall"
"w"    -> "wall"
"wa"   -> "wall"
"wal"  -> "wall"
"wall" -> "wall"

Then, during backtracking, fetching the candidate list for a prefix is a single dictionary lookup: average O(1), plus O(n) to hash the prefix string.
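A minimal sketch of building this map; `build_prefix_map` is a hypothetical helper. It returns a plain `dict`, so lookups with `.get(prefix, [])` do not insert empty entries for prefixes that never occur (indexing a `defaultdict` would).

```python
from collections import defaultdict


def build_prefix_map(words: list[str]) -> dict[str, list[str]]:
    """Hypothetical helper: map every prefix (including "") to its words."""
    prefix_map: dict[str, list[str]] = defaultdict(list)
    for word in words:
        # Prefixes of length 0..len(word), so the whole word is a key as well.
        for length in range(len(word) + 1):
            prefix_map[word[:length]].append(word)
    return dict(prefix_map)


prefix_map = build_prefix_map(["wall", "ball", "lead"])
print(prefix_map[""])            # ['wall', 'ball', 'lead']
print(prefix_map["wa"])          # ['wall']
print(prefix_map.get("le", []))  # ['lead']
print(prefix_map.get("x", []))   # []
```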

Algorithm

Let:

n = len(words[0])

Build a prefix map:

prefix_map[prefix] = all words starting with prefix

Then backtrack.

The recursive state is:

square

the list of rows chosen so far.

If:

len(square) == n

then the square is complete, so add a copy to the answer.

Otherwise:

  1. Let r = len(square).
  2. Build the required prefix for row r.
  3. Get all candidate words with that prefix.
  4. Try each candidate:
    • Append it.
    • Recurse.
    • Remove it.
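To watch the four steps at work, here is a small instrumented variant that records, for every unfinished partial square, the row index, the required prefix, and the candidate list. It is an illustrative sketch; `trace_search` is a hypothetical name.

```python
from collections import defaultdict


def trace_search(words: list[str]) -> list[tuple[int, str, list[str]]]:
    """Hypothetical tracer: record (row, required prefix, candidates)."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)

    trace: list[tuple[int, str, list[str]]] = []
    square: list[str] = []

    def backtrack() -> None:
        if len(square) == n:  # complete square: nothing left to choose
            return
        r = len(square)                              # step 1
        prefix = "".join(row[r] for row in square)   # step 2
        candidates = prefix_map.get(prefix, [])      # step 3 (.get adds no key)
        trace.append((r, prefix, list(candidates)))
        for candidate in candidates:                 # step 4
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return trace


for entry in trace_search(["area", "lead", "wall", "lady", "ball"]):
    print(entry)
```

On Example 1, the entry (1, 'r', []) shows that starting with "area" dies at once, and (2, 'de', []) prunes the branch that starts with "lady" one row later.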

Correctness

The algorithm only adds a word at row r if it matches the prefix forced by the first r rows.

That prefix condition ensures:

square[i][r] == candidate[i]

for every previous row i.

So after appending the candidate, all row-column equalities involving the new row and previous rows remain valid.

When the square reaches size n, every row has length n, and every row was added while satisfying all required prefix constraints. Therefore, for every pair of indices (i, j), the character at row i, column j, matches the character at row j, column i. The completed square is valid.

Conversely, consider any valid word square. Its first row is tried by the algorithm. At every later row, the valid square’s next word has exactly the prefix required by the previous rows, so it appears among the candidates and is tried. Therefore, the algorithm eventually generates every valid word square.

Thus the algorithm returns exactly all valid word squares.
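The argument above is the real justification, but a quick empirical cross-check against brute force on small inputs is a cheap way to catch implementation slips. A sketch, with hypothetical function names, that checks both "same set of squares" and "no duplicates":

```python
from collections import defaultdict
from itertools import product


def backtracking_squares(words: list[str]) -> list[tuple[str, ...]]:
    """Prefix-map backtracking (hypothetical name), returning tuples."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)
    answer: list[tuple[str, ...]] = []
    square: list[str] = []

    def backtrack() -> None:
        if len(square) == n:
            answer.append(tuple(square))
            return
        r = len(square)
        # Only words whose prefix matches column r so far are tried.
        for candidate in prefix_map.get("".join(w[r] for w in square), []):
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return answer


def brute_force_squares(words: list[str]) -> list[tuple[str, ...]]:
    """All m**n sequences, filtered by the row k == column k definition."""
    n = len(words[0])
    return [rows for rows in product(words, repeat=n)
            if all(rows[k] == "".join(w[k] for w in rows) for k in range(n))]


for words in (["area", "lead", "wall", "lady", "ball"],
              ["abat", "baba", "atan", "atal"],
              ["aa", "ab", "ba", "bb"]):
    found = backtracking_squares(words)
    assert len(found) == len(set(found))                        # no duplicates
    assert sorted(found) == sorted(brute_force_squares(words))  # same squares
print("backtracking matches brute force")
```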

Complexity

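| Metric | Value | Why |
|---|---|---|
| Preprocessing time | O(mn^2) | For each of m words, we build n + 1 prefixes of length up to n |
| Search time | Proportional to the number of prefix-compatible partial squares; at most O(n · m^n) | Backtracking expands only prefix-compatible branches, each doing O(n) work |
| Space | O(mn^2 + n) | Prefix map plus recursion path (output not counted) |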

Here:

| Symbol | Meaning |
|---|---|
| m | Number of words |
| n | Length of each word |

The search cost depends on how many prefix-compatible partial squares exist. The prefix map reduces branching heavily compared with trying every word at every row.
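One way to see this concretely is to count backtracking calls, since each call corresponds to one prefix-compatible partial square (including the empty one). A sketch with a hypothetical `count_search_nodes`:

```python
from collections import defaultdict


def count_search_nodes(words: list[str]) -> int:
    """Hypothetical counter: number of backtracking calls (partial squares)."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)

    calls = 0
    square: list[str] = []

    def backtrack() -> None:
        nonlocal calls
        calls += 1  # one visit per prefix-compatible partial square
        if len(square) == n:
            return
        r = len(square)
        for candidate in prefix_map.get("".join(w[r] for w in square), []):
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return calls


words = ["area", "lead", "wall", "lady", "ball"]
print(count_search_nodes(words))    # 13
print(len(words) ** len(words[0]))  # 625 sequences for brute force
```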

Implementation

from collections import defaultdict
from typing import List

class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])

        prefix_map = defaultdict(list)

        for word in words:
            for length in range(n + 1):
                prefix = word[:length]
                prefix_map[prefix].append(word)

        answer = []
        square = []

        def backtrack() -> None:
            if len(square) == n:
                answer.append(square[:])
                return

            row = len(square)

            prefix = []

            for word in square:
                prefix.append(word[row])

            prefix = "".join(prefix)

            for candidate in prefix_map[prefix]:
                square.append(candidate)
                backtrack()
                square.pop()

        backtrack()

        return answer

Code Explanation

We get the square size:

n = len(words[0])

Since all words have the same length, a complete word square must contain exactly n rows.

We build the prefix map:

prefix_map = defaultdict(list)

For every word, we store every prefix:

for length in range(n + 1):
    prefix = word[:length]
    prefix_map[prefix].append(word)

The empty prefix is included. This lets the first row be any word.

The backtracking state is:

square = []

When the square has n rows, it is complete:

if len(square) == n:
    answer.append(square[:])
    return

For the next row, we compute the required prefix.

If we are filling the row at index row, the next word must start with the characters already placed in column row:

for word in square:
    prefix.append(word[row])

After joining those characters into a string, we try every word with that prefix:

for candidate in prefix_map[prefix]:

Append, recurse, and backtrack:

square.append(candidate)
backtrack()
square.pop()

Finally, return all generated squares.

Trie Version

A trie can also support prefix lookup.

Each trie node stores every word whose prefix is the path from the root to that node.

from typing import List

class TrieNode:
    def __init__(self):
        self.children = {}
        self.words = []

class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])
        root = TrieNode()

        for word in words:
            node = root
            node.words.append(word)

            for ch in word:
                if ch not in node.children:
                    node.children[ch] = TrieNode()

                node = node.children[ch]
                node.words.append(word)

        def find_by_prefix(prefix: str) -> list[str]:
            node = root

            for ch in prefix:
                if ch not in node.children:
                    return []

                node = node.children[ch]

            return node.words

        answer = []
        square = []

        def backtrack() -> None:
            if len(square) == n:
                answer.append(square[:])
                return

            row = len(square)
            prefix = "".join(word[row] for word in square)

            for candidate in find_by_prefix(prefix):
                square.append(candidate)
                backtrack()
                square.pop()

        backtrack()

        return answer

The prefix-map version is usually simpler in Python, and since words have at most 4 characters, storing every prefix explicitly is cheap.

Testing

def normalize(squares):
    return sorted(tuple(square) for square in squares)

def test_word_squares():
    s = Solution()

    result1 = s.wordSquares(["area", "lead", "wall", "lady", "ball"])

    expected1 = [
        ["ball", "area", "lead", "lady"],
        ["wall", "area", "lead", "lady"],
    ]

    assert normalize(result1) == normalize(expected1)

    result2 = s.wordSquares(["abat", "baba", "atan", "atal"])

    expected2 = [
        ["baba", "abat", "baba", "atan"],
        ["baba", "abat", "baba", "atal"],
    ]

    assert normalize(result2) == normalize(expected2)

    assert normalize(s.wordSquares(["a"])) == normalize([["a"]])

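    result4 = s.wordSquares(["aa", "ab", "ba", "bb"])

    # Row 0 can be any of 4 words; row 1 must start with row 0's second letter (2 options).
    assert len(result4) == 8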

    for square in result4:
        for i in range(len(square)):
            row = square[i]
            col = "".join(square[r][i] for r in range(len(square)))
            assert row == col

    print("all tests passed")


test_word_squares()

Test Notes

| Test | Why |
|---|---|
| Standard example 1 | Checks multiple valid squares |
| Standard example 2 | Checks word reuse inside a square |
| Single-character word | Minimum input |
| All two-letter combinations | Validates every returned square by row-column equality |