A clear explanation of Sparse Matrix Multiplication using non-zero entries to avoid wasted work.
Problem Restatement
We are given two sparse matrices mat1 and mat2.
A sparse matrix contains many zero values.
We need to return the matrix product:
mat1 * mat2
If mat1 has size m x k and mat2 has size k x n, then the result has size m x n.
For each result cell:
answer[i][j]
we compute the dot product of row i from mat1 and column j from mat2.
The problem asks us to multiply the matrices efficiently by taking advantage of sparsity. Public mirrors describe the task as returning the result of multiplying two sparse matrices.
Input and Output
| Item | Meaning |
|---|---|
| Input | Two integer matrices mat1 and mat2 |
| mat1 size | m x k |
| mat2 size | k x n |
| Output | Product matrix of size m x n |
| Sparse property | Many values are zero |
Function shape:
def multiply(
    mat1: list[list[int]],
    mat2: list[list[int]],
) -> list[list[int]]:
    ...
Examples
Example:
mat1 = [
    [1, 0, 0],
    [-1, 0, 3],
]
mat2 = [
    [7, 0, 0],
    [0, 0, 0],
    [0, 0, 1],
]
The result has size 2 x 3.
For cell (0, 0):
1 * 7 + 0 * 0 + 0 * 0 = 7
For cell (0, 1):
1 * 0 + 0 * 0 + 0 * 0 = 0
For cell (0, 2):
1 * 0 + 0 * 0 + 0 * 1 = 0
For row 1:
[-1, 0, 3]
Multiplying against columns of mat2 gives:
[-7, 0, 3]
Output:
[
    [7, 0, 0],
    [-7, 0, 3],
]
First Thought: Standard Matrix Multiplication
The usual formula is:
answer[i][j] = sum(mat1[i][t] * mat2[t][j] for t in range(k))
A direct implementation uses three loops:
for i in range(m):
    for j in range(n):
        for t in range(k):
            answer[i][j] += mat1[i][t] * mat2[t][j]
This is correct.
But it performs work even when one of the multiplied values is zero.
For sparse matrices, most of these products may be useless.
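To make the wasted work concrete, here is a minimal sketch of the three-loop version wrapped in a function that also counts products with a zero factor. The name `multiply_dense` and the `wasted` counter are illustrative additions, not part of the problem statement.

```python
def multiply_dense(mat1, mat2):
    """Standard triple-loop product; also counts products with a zero factor.

    Hypothetical helper for illustration only.
    """
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    answer = [[0] * n for _ in range(m)]
    wasted = 0
    for i in range(m):
        for j in range(n):
            for t in range(k):
                a, b = mat1[i][t], mat2[t][j]
                if a == 0 or b == 0:
                    wasted += 1  # this product cannot change answer[i][j]
                answer[i][j] += a * b
    return answer, wasted


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
product, wasted = multiply_dense(mat1, mat2)
print(product)
print(wasted, "of", 2 * 3 * 3, "products had a zero factor")
```

On the example matrices, 15 of the 18 scalar products involve a zero factor, so only 3 of them affect the result.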
Key Insight
Only non-zero products can affect the answer.
When we multiply:
mat1[i][t] * mat2[t][j]
the product contributes nothing if:
mat1[i][t] == 0
or:
mat2[t][j] == 0
So instead of iterating over every possible triple (i, t, j), we can skip zero entries.
A useful loop order is:
row i in mat1
    shared index t
        column j in mat2
When mat1[i][t] is zero, skip the whole row contribution from mat2[t].
When mat2[t][j] is zero, skip that single product.
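As a small illustration of this loop order, the sketch below lists only the triples (i, t, j) whose product is non-zero. The generator name `contributing_triples` is made up for this example.

```python
def contributing_triples(mat1, mat2):
    """Yield (i, t, j) for every product where both factors are non-zero.

    Hypothetical helper: it follows the i -> t -> j loop order described above.
    """
    for i, row in enumerate(mat1):
        for t, a in enumerate(row):
            if a == 0:
                continue  # skips all of mat2[t] for this (i, t)
            for j, b in enumerate(mat2[t]):
                if b != 0:
                    yield i, t, j


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
triples = list(contributing_triples(mat1, mat2))
print(triples)
```

For the example matrices this yields three triples, one for each non-zero product that ends up in the answer.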
Algorithm
Let:
m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])
Create:
answer = [[0] * n for _ in range(m)]
Then:
- Iterate through each row i of mat1.
- Iterate through each shared index t.
- If mat1[i][t] == 0, skip it.
- Otherwise, iterate through row t of mat2.
- If mat2[t][j] == 0, skip it.
- Add the product to answer[i][j].
The update is:
answer[i][j] += mat1[i][t] * mat2[t][j]
Correctness
For matrix multiplication, each output cell must equal:
sum(mat1[i][t] * mat2[t][j] for t in range(k))
The algorithm considers every possible pair (i, t) from mat1 and every possible column j in row t of mat2.
If either factor is zero, the product contributes zero to the sum, so skipping it does not change the result.
If both factors are non-zero, the algorithm adds exactly:
mat1[i][t] * mat2[t][j]
to answer[i][j].
Therefore, for every cell (i, j), the algorithm adds exactly all non-zero contributions from the standard multiplication formula and omits only zero contributions. The final matrix is exactly the product of mat1 and mat2.
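The argument can also be spot-checked empirically. The sketch below compares a zero-skipping multiply against the plain summation formula on random sparse inputs. All function names, the 30% density, and the size range are assumptions chosen for illustration.

```python
import random


def multiply_skip_zeros(mat1, mat2):
    """Zero-skipping product using the i -> t -> j loop order."""
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    answer = [[0] * n for _ in range(m)]
    for i in range(m):
        for t in range(k):
            if mat1[i][t] == 0:
                continue
            for j in range(n):
                if mat2[t][j] != 0:
                    answer[i][j] += mat1[i][t] * mat2[t][j]
    return answer


def multiply_by_formula(mat1, mat2):
    """Reference product taken directly from the summation formula."""
    k, n = len(mat2), len(mat2[0])
    return [
        [sum(row[t] * mat2[t][j] for t in range(k)) for j in range(n)]
        for row in mat1
    ]


def random_sparse(rows, cols, density, rng):
    """Matrix where each entry is non-zero with probability at most `density`."""
    return [
        [rng.randint(-9, 9) if rng.random() < density else 0 for _ in range(cols)]
        for _ in range(rows)
    ]


def check(trials=200, seed=0):
    """Return True if both implementations agree on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        m, k, n = (rng.randint(1, 6) for _ in range(3))
        a = random_sparse(m, k, 0.3, rng)
        b = random_sparse(k, n, 0.3, rng)
        if multiply_skip_zeros(a, b) != multiply_by_formula(a, b):
            return False
    return True


print(check())
```

A passing run does not replace the proof; it only guards against implementation slips such as a swapped index.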
Complexity
Let:
| Symbol | Meaning |
|---|---|
| m | Number of rows in mat1 |
| k | Number of columns in mat1, also rows in mat2 |
| n | Number of columns in mat2 |
Worst case:
| Metric | Value | Why |
|---|---|---|
| Time | O(mkn) | If there are no zeros, we do normal multiplication |
| Space | O(mn) | Output matrix |
With sparse inputs, the loop skips many products.
A sharper practical view is:
sum over non-zero mat1[i][t] of number of non-zero values in mat2[t]
That is the real amount of useful multiplication work.
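That count can be computed directly. The following sketch evaluates the sum for a given pair of matrices; `useful_multiplications` is a hypothetical helper name.

```python
def useful_multiplications(mat1, mat2):
    """Count products where both factors are non-zero.

    For each non-zero mat1[i][t], add the number of non-zero values in mat2[t].
    """
    nonzeros_per_row = [sum(1 for v in row if v != 0) for row in mat2]
    return sum(
        nonzeros_per_row[t]
        for row in mat1
        for t, a in enumerate(row)
        if a != 0
    )


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
useful = useful_multiplications(mat1, mat2)
dense = len(mat1) * len(mat1[0]) * len(mat2[0])
print(useful, "useful products out of", dense)
```

For the example matrices the useful count is 3, compared with m * k * n = 18 products in the dense triple loop.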
Implementation
class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:
        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])
        answer = [[0] * n for _ in range(m)]
        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue
                for j in range(n):
                    if mat2[t][j] == 0:
                        continue
                    answer[i][j] += mat1[i][t] * mat2[t][j]
        return answer
Code Explanation
We first read the dimensions.
m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])
The result matrix has m rows and n columns.
answer = [[0] * n for _ in range(m)]
Then we scan every entry of mat1, looking for non-zero values.
for i in range(m):
    for t in range(k):
If mat1[i][t] is zero, then it contributes nothing to output row i for that t.
if mat1[i][t] == 0:
    continue
Otherwise, it can contribute to every column j where mat2[t][j] is non-zero.
for j in range(n):
Again, skip zero values from mat2.
if mat2[t][j] == 0:
    continue
When both values are non-zero, add their product.
answer[i][j] += mat1[i][t] * mat2[t][j]
Finally, return the completed matrix.
return answer
More Sparse-Friendly Implementation
We can compress mat2 by storing only non-zero values in each row.
This avoids scanning all n columns for each non-zero value in mat1.
class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:
        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])
        # rows2[r] holds (column, value) pairs for the non-zero entries of mat2[r].
        rows2 = []
        for r in range(k):
            row = []
            for c in range(n):
                if mat2[r][c] != 0:
                    row.append((c, mat2[r][c]))
            rows2.append(row)
        answer = [[0] * n for _ in range(m)]
        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue
                # Visit only the non-zero entries of row t of mat2.
                for j, value2 in rows2[t]:
                    answer[i][j] += mat1[i][t] * value2
        return answer
This version is usually better when mat2 is also sparse.
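One possible variation, not taken from the text above, is to compress mat1 in the same way so that the main loop only touches stored non-zero pairs. The helper name `compress_rows` and the function `multiply_compressed` are assumptions for this sketch. Building the compressed rows still scans every entry once, so this mainly tidies the loops and does not change the asymptotic cost.

```python
def compress_rows(mat):
    """Return, for each row, a list of (column, value) pairs for non-zero entries.

    Hypothetical helper used only in this sketch.
    """
    return [[(c, v) for c, v in enumerate(row) if v != 0] for row in mat]


def multiply_compressed(mat1, mat2):
    """Multiply using compressed rows for both inputs."""
    n = len(mat2[0])
    rows1 = compress_rows(mat1)
    rows2 = compress_rows(mat2)
    answer = [[0] * n for _ in mat1]
    for i, row in enumerate(rows1):
        for t, a in row:              # non-zero mat1[i][t]
            for j, b in rows2[t]:     # non-zero mat2[t][j]
                answer[i][j] += a * b
    return answer


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(multiply_compressed(mat1, mat2))
```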
Testing
def run_tests():
    s = Solution()
    assert s.multiply(
        [
            [1, 0, 0],
            [-1, 0, 3],
        ],
        [
            [7, 0, 0],
            [0, 0, 0],
            [0, 0, 1],
        ],
    ) == [
        [7, 0, 0],
        [-7, 0, 3],
    ]
    assert s.multiply(
        [[0]],
        [[5]],
    ) == [
        [0],
    ]
    assert s.multiply(
        [[2]],
        [[3]],
    ) == [
        [6],
    ]
    assert s.multiply(
        [
            [1, 2],
            [3, 4],
        ],
        [
            [5, 6],
            [7, 8],
        ],
    ) == [
        [19, 22],
        [43, 50],
    ]
    assert s.multiply(
        [
            [1, 0, 2],
        ],
        [
            [3],
            [4],
            [5],
        ],
    ) == [
        [13],
    ]
    print("all tests passed")
run_tests()
Test meaning:
| Test | Why |
|---|---|
| Sparse example | Checks zero skipping and normal multiplication |
| Zero single cell | Product remains zero |
| Non-zero single cell | Smallest non-zero multiplication |
| Dense matrix | Confirms correctness when no sparsity exists |
| Row times column | Checks rectangular dimensions |