A clear explanation of building all word squares using backtracking with prefix pruning.
Problem Restatement
We are given an array of unique strings:
words
All words have the same length.
Return all word squares that can be built from these words.
The same word may be used multiple times.
A word square is a list of words in which, for every index k, the kth row and the kth column read the same string.
The answer may be returned in any order.
The constraints in the current LeetCode statement are 1 <= words.length <= 1000, with each word at most 4 characters long.
Input and Output
| Item | Meaning |
|---|---|
| Input | A list of unique equal-length words |
| Output | All valid word squares |
| Word length | All words have the same length |
| Reuse rule | The same word may be used multiple times |
| Order | Output order does not matter |
Example function shape:
def wordSquares(words: list[str]) -> list[list[str]]:
    ...
Examples
Example 1:
words = ["area", "lead", "wall", "lady", "ball"]
One valid output is:
[
    ["ball", "area", "lead", "lady"],
    ["wall", "area", "lead", "lady"],
]
The first square is:
b a l l
a r e a
l e a d
l a d y
Rows:
ball
area
lead
lady
Columns:
ball
area
lead
lady
They match.
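The row-versus-column comparison above can be sketched in a few lines. The helper name `columns_of` is an assumption made for illustration; it is not part of the problem statement.

```python
def columns_of(square: list[str]) -> list[str]:
    """Read each column of the grid top to bottom as a string."""
    # zip(*square) groups the k-th character of every row, i.e. column k.
    return ["".join(column) for column in zip(*square)]


square = ["ball", "area", "lead", "lady"]
columns = columns_of(square)

# A word square reads the same along its rows and its columns.
is_square = columns == square
```

Here `columns` comes out equal to `square`, which is exactly the defining property.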
Example 2:
words = ["abat", "baba", "atan", "atal"]
One valid output is:
[
    ["baba", "abat", "baba", "atan"],
    ["baba", "abat", "baba", "atal"],
]
These are the standard examples for this problem.
First Thought: Try Every Arrangement
If each word has length n, a word square needs exactly n words.
A brute force method would try every sequence of n words and check whether it forms a square.
With m words, that can be:
m^n
candidates, because words may be reused.
This grows quickly: with m = 1000 and n = 4, there are 10^12 sequences to check.
We need to prune bad choices early.
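For comparison, the brute-force idea can be sketched as follows. This is only a baseline for small inputs, and the function name `brute_force_word_squares` is hypothetical.

```python
from itertools import product


def brute_force_word_squares(words: list[str]) -> list[list[str]]:
    """Try all m**n row sequences and keep the ones that are word squares."""
    n = len(words[0])
    result = []
    # product(..., repeat=n) lets the same word appear in several rows.
    for rows in product(words, repeat=n):
        columns = tuple("".join(column) for column in zip(*rows))
        if columns == rows:
            result.append(list(rows))
    return result


words = ["area", "lead", "wall", "lady", "ball"]
squares = brute_force_word_squares(words)
candidates_checked = len(words) ** len(words[0])  # 5 ** 4 = 625
```

Even on this five-word input the loop checks 625 sequences to find two squares, which is why it does not scale to 1000 words.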
Key Insight
Build the square row by row.
Suppose we already chose these rows:
wall
area
Now we are choosing the third row, index 2.
The third row must make the third column correct so far.
Look at column 2 from previous rows:
wall[2] = l
area[2] = e
So the next word must start with:
"le"
Only words with prefix "le" can work.
This is the main pruning rule.
At row r, the next word must have prefix:
square[0][r] + square[1][r] + ... + square[r - 1][r]
If no word has that prefix, we stop exploring that branch immediately.
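The prefix rule can be written as a small helper. The name `required_prefix` is an assumption for illustration.

```python
def required_prefix(square: list[str]) -> str:
    """Prefix the next row must have, given the rows chosen so far."""
    r = len(square)  # index of the row about to be filled
    # Column r of the existing rows becomes the start of row r.
    return "".join(row[r] for row in square)


prefix = required_prefix(["wall", "area"])
first_row_prefix = required_prefix([])
```

With `["wall", "area"]` the helper yields `"le"`, matching the walkthrough above. For an empty square it yields the empty string, so the first row is unconstrained.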
Prefix Lookup
We need to quickly find all words starting with a given prefix.
A simple way is to build a dictionary:
prefix -> list of words with that prefix
For each word, insert every prefix.
For example, for:
"wall"
we store:
"" -> "wall"
"w" -> "wall"
"wa" -> "wall"
"wal" -> "wall"
"wall" -> "wall"
Then during backtracking, we can get the candidates for a prefix with a single dictionary lookup, which takes O(1) time on average.
Algorithm
Let:
n = len(words[0])
Build a prefix map:
prefix_map[prefix] = all words starting with prefix
Then backtrack.
The recursive state is:
square
the list of rows chosen so far.
If:
len(square) == n
then the square is complete, so add a copy to the answer.
Otherwise:
- Let r = len(square).
- Build the required prefix for row r.
- Get all candidate words with that prefix.
- Try each candidate:
  - Append it.
  - Recurse.
  - Remove it.
Correctness
The algorithm only adds a word at row r if it matches the prefix forced by the first r rows.
That prefix condition ensures:
square[i][r] == candidate[i]
for every previous row i.
So after appending the candidate, all row-column equalities involving the new row and previous rows remain valid.
When the square reaches size n, every row has length n, and every row was added only after passing its prefix check. Take any pair of indices (i, j) with i < j: row j was accepted only because its character at position i equals square[i][j]. The case i = j is trivial. Therefore, for every pair (i, j), the character at row i, column j matches the character at row j, column i, and the completed square is valid.
Conversely, consider any valid word square. Its first row is tried by the algorithm. At every later row, the valid square’s next word has exactly the prefix required by the previous rows, so it appears among the candidates and is tried. Therefore, the algorithm eventually generates every valid word square.
Thus the algorithm returns exactly all valid word squares.
Complexity
| Metric | Value | Why |
|---|---|---|
| Preprocessing time | O(mn^2) | Each of m words contributes n + 1 prefixes, and slicing and hashing one prefix costs up to O(n) |
| Search time | Output-sensitive | Backtracking explores only prefix-compatible branches |
| Space | O(mn^2 + n) | Prefix map plus recursion path |
Here:
| Symbol | Meaning |
|---|---|
| m | Number of words |
| n | Length of each word |
The search cost depends on how many prefix-compatible partial squares exist. The prefix map reduces branching heavily compared with trying every word at every row.
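To get a feel for how much the pruning saves, the sketch below counts how many times the search places a word on Example 1. The function name `count_placements` is hypothetical, and the counter is added only for measurement.

```python
from collections import defaultdict


def count_placements(words: list[str]) -> int:
    """Count how many (partial square, candidate) placements the search makes."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)

    placements = 0
    square = []

    def backtrack() -> None:
        nonlocal placements
        if len(square) == n:
            return
        row = len(square)
        prefix = "".join(word[row] for word in square)
        # .get avoids inserting empty lists for prefixes no word has.
        for candidate in prefix_map.get(prefix, []):
            placements += 1
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return placements


words = ["area", "lead", "wall", "lady", "ball"]
pruned = count_placements(words)
brute_force = len(words) ** len(words[0])
```

On this input the pruned search makes 12 placements (5 first rows, then 3, 2, and 2 at the deeper levels), against 625 complete sequences for brute force.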
Implementation
from collections import defaultdict
from typing import List
class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])
        # Map every prefix (including "") to the words that start with it.
        prefix_map = defaultdict(list)
        for word in words:
            for length in range(n + 1):
                prefix = word[:length]
                prefix_map[prefix].append(word)
        answer = []
        square = []
        def backtrack() -> None:
            if len(square) == n:
                # Store a copy: square keeps changing as the search continues.
                answer.append(square[:])
                return
            row = len(square)
            # Column `row` of the chosen rows is the prefix of the next row.
            prefix = []
            for word in square:
                prefix.append(word[row])
            prefix = "".join(prefix)
            for candidate in prefix_map[prefix]:
                square.append(candidate)
                backtrack()
                square.pop()
        backtrack()
        return answer
Code Explanation
We get the square size:
n = len(words[0])
Since all words have the same length, a complete word square must contain exactly n rows.
We build the prefix map:
prefix_map = defaultdict(list)
For every word, we store every prefix:
for length in range(n + 1):
    prefix = word[:length]
    prefix_map[prefix].append(word)
The empty prefix is included. This lets the first row be any word.
The backtracking state is:
square = []
When the square has n rows, it is complete. We append a copy, because square is mutated as the search continues:
if len(square) == n:
    answer.append(square[:])
    return
For the next row, we compute the required prefix.
If we are filling the row at index row, the next word must match column row so far:
for word in square:
    prefix.append(word[row])
Then we try every word with that prefix:
for candidate in prefix_map[prefix]:
Append, recurse, and backtrack:
    square.append(candidate)
    backtrack()
    square.pop()
Finally, return all generated squares.
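One detail of the dictionary version may be worth knowing: indexing a `defaultdict` with a key it has not seen inserts that key, so `prefix_map[prefix]` inside the search adds an empty list for every dead-end prefix. With words of at most 4 characters this is harmless, but the sketch below shows the effect and how `.get` avoids it. The keys `"zz"` and `"qq"` are made-up examples.

```python
from collections import defaultdict

prefix_map = defaultdict(list)
prefix_map["le"].append("lead")
size_before = len(prefix_map)

# Indexing a missing key creates it with an empty list.
_ = prefix_map["zz"]
size_after_index = len(prefix_map)

# .get returns the fallback without touching the dictionary.
_ = prefix_map.get("qq", [])
size_after_get = len(prefix_map)
```

Switching the lookup to `prefix_map.get(prefix, [])` leaves the results unchanged and keeps the map at its preprocessed size.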
Trie Version
A trie can also support prefix lookup.
Each trie node stores all words that pass through that prefix.
from typing import List
class TrieNode:
    def __init__(self):
        self.children = {}
        self.words = []
class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])
        root = TrieNode()
        for word in words:
            node = root
            # The root represents the empty prefix, so it holds every word.
            node.words.append(word)
            for ch in word:
                if ch not in node.children:
                    node.children[ch] = TrieNode()
                node = node.children[ch]
                node.words.append(word)
        def find_by_prefix(prefix: str) -> List[str]:
            node = root
            for ch in prefix:
                if ch not in node.children:
                    return []
                node = node.children[ch]
            return node.words
        answer = []
        square = []
        def backtrack() -> None:
            if len(square) == n:
                answer.append(square[:])
                return
            row = len(square)
            prefix = "".join(word[row] for word in square)
            for candidate in find_by_prefix(prefix):
                square.append(candidate)
                backtrack()
                square.pop()
        backtrack()
        return answer
The prefix-map version is usually simpler in Python. Because the maximum word length is small, storing every prefix of every word costs little.
Testing
def normalize(squares):
    return sorted(tuple(square) for square in squares)
def test_word_squares():
    s = Solution()
    result1 = s.wordSquares(["area", "lead", "wall", "lady", "ball"])
    expected1 = [
        ["ball", "area", "lead", "lady"],
        ["wall", "area", "lead", "lady"],
    ]
    assert normalize(result1) == normalize(expected1)
    result2 = s.wordSquares(["abat", "baba", "atan", "atal"])
    expected2 = [
        ["baba", "abat", "baba", "atan"],
        ["baba", "abat", "baba", "atal"],
    ]
    assert normalize(result2) == normalize(expected2)
    assert normalize(s.wordSquares(["a"])) == normalize([["a"]])
    result4 = s.wordSquares(["aa", "ab", "ba", "bb"])
    for square in result4:
        for i in range(len(square)):
            row = square[i]
            col = "".join(square[r][i] for r in range(len(square)))
            assert row == col
    print("all tests passed")
Test Notes
| Test | Why |
|---|---|
| Standard example 1 | Checks multiple valid squares |
| Standard example 2 | Checks word reuse inside a square |
| Single-character word | Minimum input |
| All two-letter combinations | Validates every returned square by row-column equality |