# LeetCode 311: Sparse Matrix Multiplication

## Problem Restatement

We are given two sparse matrices `mat1` and `mat2`.

A sparse matrix contains many zero values.

We need to return the matrix product:

```text
mat1 × mat2
```

If `mat1` has size `m x k` and `mat2` has size `k x n`, the result has size `m x n`.

For each result cell:

```python
answer[i][j]
```

we compute the dot product of row `i` from `mat1` and column `j` from `mat2`.

Because most entries are zero, the goal is to compute the product efficiently by skipping work that involves zero entries.

## Input and Output

| Item | Meaning |
|---|---|
| Input | Two integer matrices `mat1` and `mat2` |
| `mat1` size | `m x k` |
| `mat2` size | `k x n` |
| Output | Product matrix of size `m x n` |
| Sparse property | Many values are zero |

Function shape:

```python
def multiply(
    mat1: list[list[int]],
    mat2: list[list[int]],
) -> list[list[int]]:
    ...
```

## Examples

Example:

```python
mat1 = [
    [1, 0, 0],
    [-1, 0, 3],
]

mat2 = [
    [7, 0, 0],
    [0, 0, 0],
    [0, 0, 1],
]
```

The result has size `2 x 3`.

For cell `(0, 0)`:

```python
1 * 7 + 0 * 0 + 0 * 0 = 7
```

For cell `(0, 1)`:

```python
1 * 0 + 0 * 0 + 0 * 0 = 0
```

For cell `(0, 2)`:

```python
1 * 0 + 0 * 0 + 0 * 1 = 0
```

For row `1`:

```python
[-1, 0, 3]
```

Multiplying against columns of `mat2` gives:

```python
[-7, 0, 3]
```
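The row-1 arithmetic can be spelled out in a few lines of Python. This sketch reads the columns of `mat2` with `zip(*mat2)`; the names `row1`, `columns`, and `row1_result` are illustrative only.

```python
mat2 = [
    [7, 0, 0],
    [0, 0, 0],
    [0, 0, 1],
]
row1 = [-1, 0, 3]

# zip(*mat2) yields the columns of mat2 as tuples: (7, 0, 0), (0, 0, 0), (0, 0, 1)
columns = list(zip(*mat2))

# Each output cell is the dot product of row1 with one column of mat2.
row1_result = [sum(a * b for a, b in zip(row1, col)) for col in columns]
print(row1_result)  # [-7, 0, 3]
```

Column 0 gives `-1 * 7 + 0 * 0 + 3 * 0 = -7`, column 1 is all zeros, and column 2 gives `3 * 1 = 3`.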

Output:

```python
[
    [7, 0, 0],
    [-7, 0, 3],
]
```

## First Thought: Standard Matrix Multiplication

The usual formula is:

```python
answer[i][j] = sum(mat1[i][t] * mat2[t][j] for t in range(k))
```

A direct implementation uses three loops:

```python
for i in range(m):
    for j in range(n):
        for t in range(k):
            answer[i][j] += mat1[i][t] * mat2[t][j]
```
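Wrapped into a self-contained function, the triple loop looks like the sketch below. The name `multiply_naive` is hypothetical; it serves only as a dense baseline that multiplies every pair, zeros included.

```python
def multiply_naive(mat1: list[list[int]], mat2: list[list[int]]) -> list[list[int]]:
    """Dense O(m*k*n) multiplication: every term of the formula is computed."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])  # k == len(mat1[0])
    answer = [[0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            for t in range(k):
                answer[i][j] += mat1[i][t] * mat2[t][j]
    return answer


print(multiply_naive([[1, 0, 0], [-1, 0, 3]], [[7, 0, 0], [0, 0, 0], [0, 0, 1]]))
# [[7, 0, 0], [-7, 0, 3]]
```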

This is correct.

But it performs work even when one of the multiplied values is zero.

For sparse matrices, most of these products have a zero factor and add nothing to the answer.

## Key Insight

Only non-zero products can affect the answer.

When we multiply:

```python
mat1[i][t] * mat2[t][j]
```

the product contributes nothing if:

```python
mat1[i][t] == 0
```

or:

```python
mat2[t][j] == 0
```

So instead of iterating over every possible triple `(i, t, j)`, we can skip zero entries.

A useful loop order is:

```text
row i in mat1
shared index t
column j in mat2
```

When `mat1[i][t]` is zero, every product with row `mat2[t]` is zero, so we skip that entire row at once.

When `mat2[t][j]` is zero, skip that single product.
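To see how much the `i`, `t`, `j` order saves, the sketch below counts work on the example. The helper `count_work` is hypothetical: it reports the dense product count `m * k * n`, the inner-loop steps that survive the `mat1[i][t]` check, and the products where both factors are non-zero.

```python
def count_work(mat1, mat2):
    """Return (dense_products, inner_iterations, useful_products) for the i-t-j loop."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    dense_products = m * k * n  # what the plain triple loop multiplies
    inner_iterations = 0        # j-loop steps reached after mat1[i][t] != 0
    useful_products = 0         # products where both factors are non-zero
    for i in range(m):
        for t in range(k):
            if mat1[i][t] == 0:
                continue  # skips all n columns of mat2[t] at once
            for j in range(n):
                inner_iterations += 1
                if mat2[t][j] != 0:
                    useful_products += 1
    return dense_products, inner_iterations, useful_products


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(count_work(mat1, mat2))  # (18, 9, 3)
```

Only 3 of the 18 dense products can change the answer; the row-level skip alone cuts the inner loop to 9 steps.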

## Algorithm

Let:

```python
m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])
```

Create:

```python
answer = [[0] * n for _ in range(m)]
```

Then:

1. Iterate through each row `i` of `mat1`.
2. Iterate through each shared index `t`.
3. If `mat1[i][t] == 0`, skip it.
4. Otherwise, iterate through row `t` of `mat2`.
5. If `mat2[t][j] == 0`, skip it.
6. Add the product to `answer[i][j]`.

The update is:

```python
answer[i][j] += mat1[i][t] * mat2[t][j]
```
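The six steps can be traced on the example by recording each update as it happens. This is a sketch for illustration; `trace_contributions` and its `(i, t, j, product)` log format are hypothetical.

```python
def trace_contributions(mat1, mat2):
    """Run the zero-skipping algorithm and log every (i, t, j, product) added."""
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    answer = [[0] * n for _ in range(m)]
    steps = []
    for i in range(m):                # step 1: each row of mat1
        for t in range(k):            # step 2: each shared index
            if mat1[i][t] == 0:       # step 3: skip zero in mat1
                continue
            for j in range(n):        # step 4: walk row t of mat2
                if mat2[t][j] == 0:   # step 5: skip zero in mat2
                    continue
                product = mat1[i][t] * mat2[t][j]
                answer[i][j] += product  # step 6: accumulate
                steps.append((i, t, j, product))
    return answer, steps


answer, steps = trace_contributions(
    [[1, 0, 0], [-1, 0, 3]],
    [[7, 0, 0], [0, 0, 0], [0, 0, 1]],
)
for step in steps:
    print(step)
# (0, 0, 0, 7)
# (1, 0, 0, -7)
# (1, 2, 2, 3)
```

Three updates are enough to fill the non-zero cells of the answer; every other cell stays at its initial `0`.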

## Correctness

For matrix multiplication, each output cell must equal:

```python
sum(mat1[i][t] * mat2[t][j] for t in range(k))
```

The algorithm considers every possible pair `(i, t)` from `mat1` and every possible column `j` in row `t` of `mat2`.

If either factor is zero, the product contributes zero to the sum, so skipping it does not change the result.

If both factors are non-zero, the algorithm adds exactly:

```python
mat1[i][t] * mat2[t][j]
```

to `answer[i][j]`.

Therefore, for every cell `(i, j)`, the algorithm adds exactly all non-zero contributions from the standard multiplication formula and omits only zero contributions. The final matrix is exactly the product of `mat1` and `mat2`.

## Complexity

Let:

| Symbol | Meaning |
|---|---|
| `m` | Number of rows in `mat1` |
| `k` | Number of columns in `mat1`, also rows in `mat2` |
| `n` | Number of columns in `mat2` |

Worst case:

| Metric | Value | Why |
|---|---:|---|
| Time | `O(mkn)` | If there are no zeros, we do normal multiplication |
| Space | `O(mn)` | Output matrix |

With sparse inputs, the loop skips many products.

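A sharper measure counts only the products where both factors are non-zero:

```python
# sum over non-zero mat1[i][t] of the number of non-zero values in mat2[t]
work = sum(sum(v != 0 for v in mat2[t]) for row in mat1 for t, a in enumerate(row) if a != 0)
```

That is the real amount of useful multiplication work. The first implementation below still scans all `n` columns for each non-zero `mat1[i][t]`, so its loops cost `O(mk + nnz(mat1) * n)`, where `nnz(mat1)` is the number of non-zero entries in `mat1`. The compressed version in "More Sparse-Friendly Implementation" pays `O(kn)` once to index `mat2` and then touches only the useful products.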

## Implementation

```python
class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:

        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])

        answer = [[0] * n for _ in range(m)]

        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue

                for j in range(n):
                    if mat2[t][j] == 0:
                        continue

                    answer[i][j] += mat1[i][t] * mat2[t][j]

        return answer
```

## Code Explanation

We first read the dimensions.

```python
m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])
```

The result matrix has `m` rows and `n` columns.

```python
answer = [[0] * n for _ in range(m)]
```

Then we scan every entry of `mat1`, looking for non-zero values.

```python
for i in range(m):
    for t in range(k):
```

If `mat1[i][t]` is zero, every product `mat1[i][t] * mat2[t][j]` is zero, so this entry contributes nothing to output row `i` and we skip the whole row `mat2[t]`.

```python
if mat1[i][t] == 0:
    continue
```

Otherwise, it can contribute to every column `j` where `mat2[t][j]` is non-zero.

```python
for j in range(n):
```

Again, skip zero values from `mat2`.

```python
if mat2[t][j] == 0:
    continue
```

When both values are non-zero, add their product.

```python
answer[i][j] += mat1[i][t] * mat2[t][j]
```

Finally, return the completed matrix.

```python
return answer
```

## More Sparse-Friendly Implementation

We can compress `mat2` by storing only non-zero values in each row.

This avoids scanning all `n` columns for each non-zero value in `mat1`.

```python
class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:

        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])

        rows2 = []

        for r in range(k):
            row = []

            for c in range(n):
                if mat2[r][c] != 0:
                    row.append((c, mat2[r][c]))

            rows2.append(row)

        answer = [[0] * n for _ in range(m)]

        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue

                for j, value2 in rows2[t]:
                    answer[i][j] += mat1[i][t] * value2

        return answer
```

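This version is usually faster when `mat2` is also sparse: building `rows2` costs `O(kn)` once, and afterwards each non-zero `mat1[i][t]` visits only the non-zero entries of `mat2[t]` instead of all `n` columns.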
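The `rows2` construction can also be written as a single comprehension with `enumerate`. The helper name `compress_rows` is hypothetical; the sketch shows what the compressed form looks like for the example `mat2`.

```python
def compress_rows(mat):
    """Keep only (column, value) pairs for the non-zero entries of each row."""
    return [[(c, v) for c, v in enumerate(row) if v != 0] for row in mat]


mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(compress_rows(mat2))  # [[(0, 7)], [], [(2, 1)]]
```

Row 1 of `mat2` compresses to an empty list, so any non-zero `mat1[i][1]` would cost no inner-loop work at all.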

## Testing

```python
def run_tests():
    s = Solution()

    assert s.multiply(
        [
            [1, 0, 0],
            [-1, 0, 3],
        ],
        [
            [7, 0, 0],
            [0, 0, 0],
            [0, 0, 1],
        ],
    ) == [
        [7, 0, 0],
        [-7, 0, 3],
    ]

    assert s.multiply(
        [[0]],
        [[5]],
    ) == [
        [0],
    ]

    assert s.multiply(
        [[2]],
        [[3]],
    ) == [
        [6],
    ]

    assert s.multiply(
        [
            [1, 2],
            [3, 4],
        ],
        [
            [5, 6],
            [7, 8],
        ],
    ) == [
        [19, 22],
        [43, 50],
    ]

    assert s.multiply(
        [
            [1, 0, 2],
        ],
        [
            [3],
            [4],
            [5],
        ],
    ) == [
        [13],
    ]

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Why |
|---|---|
| Sparse example | Checks zero skipping and normal multiplication |
| Zero single cell | Product remains zero |
| Non-zero single cell | Smallest non-zero multiplication |
| Dense matrix | Confirms correctness when no sparsity exists |
| Row times column | Checks rectangular dimensions |
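The hand-written cases can be complemented by a randomized cross-check against the dense formula. This is a sketch under stated assumptions: `multiply_naive`, `multiply_sparse`, `random_sparse`, and `run_random_tests` are hypothetical names, `multiply_sparse` mirrors the compressed implementation above, and the value range `[-100, 100]` is an assumed input range.

```python
import random


def multiply_naive(mat1, mat2):
    """Reference dense product: every term of the textbook formula."""
    m, k, n = len(mat1), len(mat2), len(mat2[0])
    return [
        [sum(mat1[i][t] * mat2[t][j] for t in range(k)) for j in range(n)]
        for i in range(m)
    ]


def multiply_sparse(mat1, mat2):
    """Zero-skipping product with mat2 rows compressed to (column, value) pairs."""
    m, n = len(mat1), len(mat2[0])
    rows2 = [[(c, v) for c, v in enumerate(row) if v != 0] for row in mat2]
    answer = [[0] * n for _ in range(m)]
    for i, row1 in enumerate(mat1):
        for t, value1 in enumerate(row1):
            if value1 == 0:
                continue
            for j, value2 in rows2[t]:
                answer[i][j] += value1 * value2
    return answer


def random_sparse(rows, cols, density, rng):
    """Each entry is drawn from [-100, 100] with probability `density`, else 0."""
    return [
        [rng.randint(-100, 100) if rng.random() < density else 0 for _ in range(cols)]
        for _ in range(rows)
    ]


def run_random_tests(trials=200, seed=0):
    rng = random.Random(seed)  # fixed seed keeps failures reproducible
    for _ in range(trials):
        m, k, n = rng.randint(1, 6), rng.randint(1, 6), rng.randint(1, 6)
        density = rng.choice([0.0, 0.1, 0.3, 1.0])
        mat1 = random_sparse(m, k, density, rng)
        mat2 = random_sparse(k, n, density, rng)
        assert multiply_sparse(mat1, mat2) == multiply_naive(mat1, mat2), (mat1, mat2)
    print("random tests passed")


run_random_tests()
```

Mixing densities from all-zero to fully dense exercises both the skipping paths and the ordinary multiplication path, and the small dimensions include rectangular shapes.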

