# LeetCode 425: Word Squares

## Problem Restatement

We are given an array of unique strings:

```python
words
```

All words have the same length.

Return all word squares that can be built from these words.

The same word may be used multiple times.

A word square is a list of words where the `k`th row and the `k`th column read the same string.

The answer may be returned in any order.

The LeetCode statement guarantees that all input words are unique and have the same length. It also allows the same word to be used more than once, even within a single square (Example 2 uses `"baba"` twice). The constraints are `1 <= words.length <= 1000` and `1 <= words[i].length <= 4`, and words consist of lowercase English letters.

## Input and Output

| Item | Meaning |
|---|---|
| Input | A list of unique equal-length words |
| Output | All valid word squares |
| Word length | All words have the same length |
| Reuse rule | The same word may be used multiple times |
| Order | Output order does not matter |

Example function shape:

```python
def wordSquares(words: list[str]) -> list[list[str]]:
    ...
```

## Examples

Example 1:

```python
words = ["area", "lead", "wall", "lady", "ball"]
```

One valid output is:

```python
[
    ["ball", "area", "lead", "lady"],
    ["wall", "area", "lead", "lady"],
]
```

The first square is:

```text
b a l l
a r e a
l e a d
l a d y
```

Rows:

```text
ball
area
lead
lady
```

Columns:

```text
ball
area
lead
lady
```

They match.

Example 2:

```python
words = ["abat", "baba", "atan", "atal"]
```

One valid output is:

```python
[
    ["baba", "abat", "baba", "atan"],
    ["baba", "abat", "baba", "atal"],
]
```

These are the standard examples for this problem.
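The row-versus-column comparison done by hand for Example 1 can be written as a small checker. This is a sketch; `is_word_square` is a hypothetical helper name, not part of the LeetCode signature.

```python
def is_word_square(rows: list[str]) -> bool:
    """Return True if row k reads the same as column k for every k."""
    n = len(rows)
    # Every row must have exactly n characters to form an n x n grid.
    if any(len(row) != n for row in rows):
        return False
    # zip(*rows) yields the columns as tuples of characters.
    columns = ["".join(col) for col in zip(*rows)]
    return columns == rows


print(is_word_square(["ball", "area", "lead", "lady"]))  # True
print(is_word_square(["ball", "area", "lady", "lead"]))  # False
```

The second call fails because column `2` reads `"leda"`, not `"lady"`.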

## First Thought: Try Every Arrangement

If each word has length `n`, a word square needs exactly `n` words.

A brute force method would try every sequence of `n` words and check whether it forms a square.

With `m` words, that can be:

```text
m^n
```

candidates, because words may be reused.
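A minimal brute-force sketch, assuming small inputs, makes the baseline concrete. The name `brute_force_word_squares` is hypothetical. It enumerates all `m^n` row sequences with `itertools.product` and keeps the ones whose rows equal their columns, so each candidate costs `O(n^2)` to check.

```python
from itertools import product


def brute_force_word_squares(words: list[str]) -> list[list[str]]:
    """Try every length-n sequence of words (reuse allowed) and keep valid squares."""
    n = len(words[0])
    result = []
    # product(words, repeat=n) yields all m**n sequences of n rows.
    for rows in product(words, repeat=n):
        # Row k must equal column k for every k.
        if all(rows[k] == "".join(row[k] for row in rows) for k in range(n)):
            result.append(list(rows))
    return result


print(brute_force_word_squares(["area", "lead", "wall", "lady", "ball"]))
# [['wall', 'area', 'lead', 'lady'], ['ball', 'area', 'lead', 'lady']]
```

It is too slow for the full constraints, but it is handy as a reference answer when testing a faster solution on tiny inputs.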

This grows quickly.

We need to prune bad choices early.

## Key Insight

Build the square row by row.

Suppose we already chose these rows:

```text
wall
area
```

Now we are choosing the third row, index `2`.

The third row must agree with column `2` in every cell the previous rows have already fixed.

Look at column `2` from previous rows:

```text
wall[2] = l
area[2] = e
```

So the next word must start with:

```python
"le"
```

Only words with prefix `"le"` can work.

This is the main pruning rule.

At row `r`, the next word must have prefix:

```text
square[0][r] + square[1][r] + ... + square[r - 1][r]
```

If no word has that prefix, we stop exploring that branch immediately.
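The prefix rule translates directly into a one-line helper. This is a sketch; `required_prefix` is a hypothetical name. It reads column `r` top to bottom over the rows chosen so far, where `r` is the index of the row being filled.

```python
def required_prefix(square: list[str]) -> str:
    """Prefix that the next row must start with, given the rows chosen so far."""
    r = len(square)  # index of the row being filled
    # Column r, read top to bottom over the existing rows.
    return "".join(row[r] for row in square)


print(required_prefix(["wall", "area"]))  # le
print(required_prefix(["wall"]))          # a
print(repr(required_prefix([])))          # ''
```

For the first row the prefix is empty, so any word is allowed.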

## Prefix Lookup

We need to quickly find all words starting with a given prefix.

A simple way is to build a dictionary:

```text
prefix -> list of words with that prefix
```

For each word, insert every prefix.

For example, for:

```python
"wall"
```

we store:

```text
""     -> "wall"
"w"    -> "wall"
"wa"   -> "wall"
"wal"  -> "wall"
"wall" -> "wall"
```

Then, during backtracking, fetching the candidates for a prefix is a single dictionary lookup: average constant time, plus the cost of hashing a prefix shorter than `n` characters.
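The prefix map for the Example 1 words can be built and inspected directly. This is a sketch; `build_prefix_map` is a hypothetical helper name.

```python
from collections import defaultdict


def build_prefix_map(words: list[str]) -> dict[str, list[str]]:
    """Map every prefix of every word (including "") to the words having it."""
    prefix_map = defaultdict(list)
    for word in words:
        # len(word) + 1 prefixes: "", word[:1], ..., the whole word.
        for length in range(len(word) + 1):
            prefix_map[word[:length]].append(word)
    return prefix_map


pm = build_prefix_map(["area", "lead", "wall", "lady", "ball"])
print(pm["le"])     # ['lead']
print(pm["la"])     # ['lady']
print(len(pm[""]))  # 5
```

Indexing a `defaultdict` with a missing key inserts an empty list under that key. `pm.get(prefix, [])` returns the same empty result without growing the map.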

## Algorithm

Let:

```python
n = len(words[0])
```

Build a prefix map:

```text
prefix_map[prefix] = all words starting with prefix
```

Then backtrack.

The recursive state is:

```python
square
```

the list of rows chosen so far.

If:

```python
len(square) == n
```

then the square is complete, so add a copy to the answer.

Otherwise:

1. Let `r = len(square)`.
2. Build the required prefix for row `r`.
3. Get all candidate words with that prefix.
4. Try each candidate:
   - Append it.
   - Recurse.
   - Remove it.
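Because step 4 removes each row after the recursive call returns, one `square` list is reused for the whole search. That is why a complete square is saved as a copy rather than by reference. A small standalone illustration of the difference (the variable names are only for this demo):

```python
square = []
aliased = []  # stores a reference to the live list
copied = []   # stores a snapshot taken with square[:]

for word in ["ball", "area", "lead", "lady"]:
    square.append(word)

aliased.append(square)
copied.append(square[:])

# Backtracking later removes the rows again.
while square:
    square.pop()

print(aliased)  # [[]]
print(copied)   # [['ball', 'area', 'lead', 'lady']]
```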

## Correctness

The algorithm only adds a word at row `r` if it matches the prefix forced by the first `r` rows.

That prefix condition ensures:

```python
square[i][r] == candidate[i]
```

for every previous row `i`.

So after appending the candidate, all row-column equalities involving the new row and previous rows remain valid.

When the square reaches size `n`, take any pair of indices `i < j`. When row `j` was added, the prefix rule forced `square[j][i] == square[i][j]`, and cells on the diagonal (`i == j`) match trivially. Therefore, for every pair of indices `(i, j)`, the character at row `i`, column `j`, matches the character at row `j`, column `i`. The completed square is valid.

Conversely, consider any valid word square. Its first row is tried by the algorithm, because the empty prefix maps to every word. At every later row, the valid square’s next word has exactly the prefix required by the previous rows, so it appears among the candidates and is tried. Therefore, the algorithm eventually generates every valid word square.

Thus the algorithm returns exactly all valid word squares.
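Both halves of the argument rest on one equivalence: a full `n`-row grid is a word square exactly when every row starts with the prefix forced by the rows above it. A sketch that checks this rule on a single grid (`satisfies_prefix_rule` is a hypothetical name):

```python
def satisfies_prefix_rule(square: list[str]) -> bool:
    """Check that each row starts with the prefix forced by the rows above it."""
    for r, word in enumerate(square):
        # Column r over rows 0..r-1, i.e. the prefix row r must have.
        forced = "".join(square[i][r] for i in range(r))
        if not word.startswith(forced):
            return False
    return True


print(satisfies_prefix_rule(["ball", "area", "lead", "lady"]))  # True
print(satisfies_prefix_rule(["ball", "area", "lady", "lead"]))  # False
```

In the second grid, row `2` must start with `"le"`, but `"lady"` does not, so the search would never place it there.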

## Complexity

| Metric | Value | Why |
|---|---|---|
| Preprocessing time | `O(mn^2)` | For each of `m` words, we build prefixes up to length `n` |
| Search time | Input-dependent | Proportional to the prefix-compatible partial squares explored, including dead ends |
| Space | `O(mn^2 + n)` | Prefix map plus recursion path |

Here:

| Symbol | Meaning |
|---|---|
| `m` | Number of words |
| `n` | Length of each word |

The search cost depends on how many prefix-compatible partial squares exist. The prefix map reduces branching heavily compared with trying every word at every row.
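To see the pruning concretely, the search can be instrumented to count how many non-empty partial squares it visits and compare that with the `m^n` sequences the brute force would check. This is a sketch; `count_search_nodes` is a hypothetical helper that mirrors the prefix-map backtracking.

```python
from collections import defaultdict


def count_search_nodes(words: list[str]) -> tuple[int, int]:
    """Return (non-empty partial squares visited, complete squares found)."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)

    visited = 0
    found = 0
    square: list[str] = []

    def backtrack() -> None:
        nonlocal visited, found
        if len(square) == n:
            found += 1
            return
        row = len(square)
        prefix = "".join(word[row] for word in square)
        # .get avoids inserting empty lists for prefixes that match nothing.
        for candidate in prefix_map.get(prefix, []):
            visited += 1  # one more non-empty partial square
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return visited, found


words = ["area", "lead", "wall", "lady", "ball"]
visited, found = count_search_nodes(words)
print(visited, found, len(words) ** len(words[0]))  # 12 2 625
```

On Example 1 the search touches 12 partial squares, while the brute force would check 625 full sequences.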

## Implementation

```python
from collections import defaultdict
from typing import List

class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])

        prefix_map = defaultdict(list)

        for word in words:
            for length in range(n + 1):
                prefix = word[:length]
                prefix_map[prefix].append(word)

        answer = []
        square = []

        def backtrack() -> None:
            if len(square) == n:
                answer.append(square[:])
                return

            row = len(square)

            prefix = []

            for word in square:
                prefix.append(word[row])

            prefix = "".join(prefix)

            for candidate in prefix_map[prefix]:
                square.append(candidate)
                backtrack()
                square.pop()

        backtrack()

        return answer
```

## Code Explanation

We get the square size:

```python
n = len(words[0])
```

Since all words have the same length, a complete word square must contain exactly `n` rows.

We build the prefix map:

```python
prefix_map = defaultdict(list)
```

For every word, we store every prefix:

```python
for length in range(n + 1):
    prefix = word[:length]
    prefix_map[prefix].append(word)
```

The empty prefix is included. This lets the first row be any word.

The backtracking state is:

```python
square = []
```

When the square has `n` rows, it is complete:

```python
if len(square) == n:
    answer.append(square[:])
    return
```

For the next row, we compute the required prefix.

If we are filling row `row`, then the next word must match column `row` so far:

```python
prefix = []

for word in square:
    prefix.append(word[row])

prefix = "".join(prefix)
```

Then we try every word with that prefix:

```python
for candidate in prefix_map[prefix]:
```

Append, recurse, and backtrack:

```python
square.append(candidate)
backtrack()
square.pop()
```

Finally, return all generated squares.

## Trie Version

A trie can also support prefix lookup.

Each trie node stores every word that has that node's prefix, that is, every word whose path passes through the node.

```python
from typing import List

class TrieNode:
    def __init__(self):
        self.children = {}
        self.words = []

class Solution:
    def wordSquares(self, words: List[str]) -> List[List[str]]:
        n = len(words[0])
        root = TrieNode()

        for word in words:
            node = root
            node.words.append(word)

            for ch in word:
                if ch not in node.children:
                    node.children[ch] = TrieNode()

                node = node.children[ch]
                node.words.append(word)

        def find_by_prefix(prefix: str) -> list[str]:
            node = root

            for ch in prefix:
                if ch not in node.children:
                    return []

                node = node.children[ch]

            return node.words

        answer = []
        square = []

        def backtrack() -> None:
            if len(square) == n:
                answer.append(square[:])
                return

            row = len(square)
            prefix = "".join(word[row] for word in square)

            for candidate in find_by_prefix(prefix):
                square.append(candidate)
                backtrack()
                square.pop()

        backtrack()

        return answer
```

The prefix-map version is usually simpler in Python. Words have at most `4` letters, so each word contributes at most `5` prefixes, and a flat dictionary does the job without a node class.

## Testing

```python
def normalize(squares):
    return sorted(tuple(square) for square in squares)

def test_word_squares():
    s = Solution()

    result1 = s.wordSquares(["area", "lead", "wall", "lady", "ball"])

    expected1 = [
        ["ball", "area", "lead", "lady"],
        ["wall", "area", "lead", "lady"],
    ]

    assert normalize(result1) == normalize(expected1)

    result2 = s.wordSquares(["abat", "baba", "atan", "atal"])

    expected2 = [
        ["baba", "abat", "baba", "atan"],
        ["baba", "abat", "baba", "atal"],
    ]

    assert normalize(result2) == normalize(expected2)

    assert normalize(s.wordSquares(["a"])) == normalize([["a"]])

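    result4 = s.wordSquares(["aa", "ab", "ba", "bb"])

    # Each of the 4 first rows leaves exactly 2 choices for the second row.
    assert len(result4) == 8

    for square in result4:
        for i in range(len(square)):
            row = square[i]
            col = "".join(square[r][i] for r in range(len(square)))
            assert row == col

    print("all tests passed")


test_word_squares()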
```

## Test Notes

| Test | Why |
|---|---|
| Standard example 1 | Checks multiple valid squares |
| Standard example 2 | Checks word reuse inside a square |
| Single-character word | Minimum input |
| All two-letter combinations | Validates every returned square by row-column equality |
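The fixed cases above can be complemented by a randomized cross-check against the brute force described in "First Thought: Try Every Arrangement". This is a sketch under stated assumptions: `word_squares`, `brute_force`, and `random_case` are hypothetical names, inputs are kept tiny (length at most 3, alphabet `"ab"`) so the brute force stays fast, and the seed is fixed so any failure is reproducible.

```python
import random
from collections import defaultdict
from itertools import product


def word_squares(words: list[str]) -> list[list[str]]:
    """Prefix-map backtracking, the same approach as the Solution class."""
    n = len(words[0])
    prefix_map = defaultdict(list)
    for word in words:
        for length in range(n + 1):
            prefix_map[word[:length]].append(word)
    answer, square = [], []

    def backtrack() -> None:
        if len(square) == n:
            answer.append(square[:])
            return
        row = len(square)
        prefix = "".join(word[row] for word in square)
        for candidate in prefix_map.get(prefix, []):
            square.append(candidate)
            backtrack()
            square.pop()

    backtrack()
    return answer


def brute_force(words: list[str]) -> list[list[str]]:
    """Check all m**n row sequences; only viable for tiny inputs."""
    n = len(words[0])
    return [
        list(rows)
        for rows in product(words, repeat=n)
        if all(rows[k] == "".join(r[k] for r in rows) for k in range(n))
    ]


def random_case(rng: random.Random) -> list[str]:
    """Unique, equal-length words over the alphabet 'ab'."""
    n = rng.randint(1, 3)
    pool = {"".join(rng.choice("ab") for _ in range(n)) for _ in range(6)}
    return sorted(pool)


rng = random.Random(0)  # fixed seed so failures are reproducible
for _ in range(200):
    case = random_case(rng)
    assert sorted(word_squares(case)) == sorted(brute_force(case)), case
print("random cross-check passed")
```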

