
LeetCode 311: Sparse Matrix Multiplication

A clear explanation of Sparse Matrix Multiplication using non-zero entries to avoid wasted work.

Problem Restatement

We are given two sparse matrices mat1 and mat2.

A sparse matrix contains many zero values.

We need to return the matrix product:

mat1 * mat2

If mat1 has size m x k, and mat2 has size k x n, then the result has size m x n.

For each result cell:

answer[i][j]

we compute the dot product of row i from mat1 and column j from mat2.
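As a small illustration, a single cell can be computed straight from this definition. The helper name `cell` is hypothetical. It builds column j of mat2 by indexing each row of mat2 and then takes the dot product with row i of mat1.

```python
def cell(mat1: list[list[int]], mat2: list[list[int]], i: int, j: int) -> int:
    """Dot product of row i of mat1 with column j of mat2 (hypothetical helper)."""
    row = mat1[i]
    column = [mat2[t][j] for t in range(len(mat2))]  # column j of mat2
    return sum(a * b for a, b in zip(row, column))


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(cell(mat1, mat2, 1, 2))  # 3
```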

The task is to return the product of two sparse matrices, ideally taking advantage of sparsity so that zero entries do not cost useless work.

Input and Output

| Item | Meaning |
|---|---|
| Input | Two integer matrices mat1 and mat2 |
| mat1 size | m x k |
| mat2 size | k x n |
| Output | Product matrix of size m x n |
| Sparse property | Many values are zero |

Function shape:

def multiply(
    mat1: list[list[int]],
    mat2: list[list[int]],
) -> list[list[int]]:
    ...

Examples

Example:

mat1 = [
    [1, 0, 0],
    [-1, 0, 3],
]

mat2 = [
    [7, 0, 0],
    [0, 0, 0],
    [0, 0, 1],
]

The result has size 2 x 3.

For cell (0, 0):

1 * 7 + 0 * 0 + 0 * 0 = 7

For cell (0, 1):

1 * 0 + 0 * 0 + 0 * 0 = 0

For cell (0, 2):

1 * 0 + 0 * 0 + 0 * 1 = 0

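For row 1, [-1, 0, 3], the dot products with the three columns of mat2 are:

-1 * 7 + 0 * 0 + 3 * 0 = -7
-1 * 0 + 0 * 0 + 3 * 0 = 0
-1 * 0 + 0 * 0 + 3 * 1 = 3

So row 1 of the result is:

[-7, 0, 3]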

Output:

[
    [7, 0, 0],
    [-7, 0, 3],
]

First Thought: Standard Matrix Multiplication

The usual formula is:

answer[i][j] = sum(mat1[i][t] * mat2[t][j] for t in range(k))

A direct implementation uses three loops:

for i in range(m):
    for j in range(n):
        for t in range(k):
            answer[i][j] += mat1[i][t] * mat2[t][j]

This is correct.

But it always performs m * k * n multiplications, including every product where one of the factors is zero.

For sparse matrices, most of these products are zero and contribute nothing to the answer.
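To make the waste concrete, here is a sketch of the three-loop version that also counts how many products had a zero factor. The function name `multiply_counting_zeros` and its extra return values are illustrative additions, not part of the problem's interface.

```python
def multiply_counting_zeros(
    mat1: list[list[int]],
    mat2: list[list[int]],
) -> tuple[list[list[int]], int, int]:
    """Standard triple-loop product; also counts total and zero-factor products."""
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    answer = [[0] * n for _ in range(m)]
    total = 0
    wasted = 0  # products where at least one factor is zero

    for i in range(m):
        for j in range(n):
            for t in range(k):
                total += 1
                if mat1[i][t] == 0 or mat2[t][j] == 0:
                    wasted += 1
                answer[i][j] += mat1[i][t] * mat2[t][j]

    return answer, total, wasted


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
answer, total, wasted = multiply_counting_zeros(mat1, mat2)
print(answer)         # [[7, 0, 0], [-7, 0, 3]]
print(total, wasted)  # 18 15
```

On the example, only 3 of the 18 products are non-zero.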

Key Insight

Only non-zero products can affect the answer.

When we multiply:

mat1[i][t] * mat2[t][j]

the product contributes nothing if:

mat1[i][t] == 0

or:

mat2[t][j] == 0

So instead of iterating over every possible triple (i, t, j), we can skip zero entries.

A useful loop order, from outermost to innermost, is:

row i in mat1
shared index t
column j in mat2

When mat1[i][t] is zero, skip the whole row contribution from mat2[t].

When mat2[t][j] is zero, skip that single product.
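The i, t, j order has a simple reading: row i of the answer is a weighted sum of the rows of mat2, with weights taken from row i of mat1. A sketch of that view for a single output row, skipping zero weights (the helper name `output_row` is hypothetical, and this sketch does not skip zeros inside mat2):

```python
def output_row(mat1: list[list[int]], mat2: list[list[int]], i: int) -> list[int]:
    """Row i of mat1 * mat2, built as a weighted sum of rows of mat2."""
    n = len(mat2[0])
    row = [0] * n
    for t, weight in enumerate(mat1[i]):
        if weight == 0:
            continue  # zero weight: row t of mat2 contributes nothing
        for j in range(n):
            row[j] += weight * mat2[t][j]
    return row


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
print(output_row(mat1, mat2, 1))  # [-7, 0, 3]
```

Here row 1 is -1 times row 0 of mat2 plus 3 times row 2 of mat2; row 1 of mat2 is never touched.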

Algorithm

Let:

m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])

Create:

answer = [[0] * n for _ in range(m)]

Then:

  1. Iterate through each row i of mat1.
  2. Iterate through each shared index t.
  3. If mat1[i][t] == 0, skip it.
  4. Otherwise, iterate through row t of mat2.
  5. If mat2[t][j] == 0, skip it.
  6. Add the product to answer[i][j].

The update is:

answer[i][j] += mat1[i][t] * mat2[t][j]

Correctness

For matrix multiplication, each output cell must equal:

sum(mat1[i][t] * mat2[t][j] for t in range(k))

The algorithm considers every possible pair (i, t) from mat1 and every possible column j in row t of mat2.

If either factor is zero, the product contributes zero to the sum, so skipping it does not change the result.

If both factors are non-zero, the algorithm adds exactly:

mat1[i][t] * mat2[t][j]

to answer[i][j].

Therefore, for every cell (i, j), the algorithm adds exactly all non-zero contributions from the standard multiplication formula and omits only zero contributions. The final matrix is exactly the product of mat1 and mat2.
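The argument can also be spot-checked empirically. A sketch, assuming small random integer matrices are a reasonable sample: compare the zero-skipping loop against a direct transcription of the formula. All function names here are hypothetical.

```python
import random


def multiply_standard(mat1, mat2):
    """Reference: answer[i][j] = sum(mat1[i][t] * mat2[t][j] for t in range(k))."""
    k = len(mat2)
    return [
        [sum(mat1[i][t] * mat2[t][j] for t in range(k)) for j in range(len(mat2[0]))]
        for i in range(len(mat1))
    ]


def multiply_skip_zeros(mat1, mat2):
    """The zero-skipping i, t, j loop from the Algorithm section."""
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    answer = [[0] * n for _ in range(m)]
    for i in range(m):
        for t in range(k):
            if mat1[i][t] == 0:
                continue
            for j in range(n):
                if mat2[t][j] != 0:
                    answer[i][j] += mat1[i][t] * mat2[t][j]
    return answer


def random_sparse(rows, cols, density, rng):
    """Random integer matrix; each entry is drawn from -9..9 with probability `density`, else 0."""
    return [
        [rng.randint(-9, 9) if rng.random() < density else 0 for _ in range(cols)]
        for _ in range(rows)
    ]


rng = random.Random(0)
for _ in range(200):
    m, k, n = rng.randint(1, 6), rng.randint(1, 6), rng.randint(1, 6)
    a = random_sparse(m, k, 0.3, rng)
    b = random_sparse(k, n, 0.3, rng)
    assert multiply_skip_zeros(a, b) == multiply_standard(a, b)
print("skip-zero version matches the standard formula")
```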

Complexity

Let:

| Symbol | Meaning |
|---|---|
| m | Number of rows in mat1 |
| k | Number of columns in mat1, also rows in mat2 |
| n | Number of columns in mat2 |

Worst case:

| Metric | Value | Why |
|---|---|---|
| Time | O(mkn) | If there are no zeros, we do normal multiplication |
| Space | O(mn) | Output matrix |

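With sparse inputs, the loops skip many products.

A sharper practical view counts only the multiplications actually performed:

sum over non-zero mat1[i][t] of number of non-zero values in mat2[t]

That is the real amount of useful multiplication work. Note that the first implementation below still scans all n columns of mat2[t] for every non-zero mat1[i][t], so its running time is O(mn + mk + z * n), where z is the number of non-zero entries in mat1. The compressed version later removes that scan.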
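The useful-work formula can be evaluated directly from the inputs. A sketch (the function name `useful_products` is hypothetical) that compares it with the m * k * n products of the standard method:

```python
def useful_products(mat1: list[list[int]], mat2: list[list[int]]) -> int:
    """Sum over non-zero mat1[i][t] of the number of non-zero values in mat2[t]."""
    nonzeros_in_row = [sum(1 for v in row if v != 0) for row in mat2]
    return sum(
        nonzeros_in_row[t]
        for row in mat1
        for t, value in enumerate(row)
        if value != 0
    )


mat1 = [[1, 0, 0], [-1, 0, 3]]
mat2 = [[7, 0, 0], [0, 0, 0], [0, 0, 1]]
m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
print(useful_products(mat1, mat2), "useful vs", m * k * n, "total")  # 3 useful vs 18 total
```

For a fully dense input the two numbers coincide, which matches the O(mkn) worst case.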

Implementation

class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:

        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])

        answer = [[0] * n for _ in range(m)]

        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue

                for j in range(n):
                    if mat2[t][j] == 0:
                        continue

                    answer[i][j] += mat1[i][t] * mat2[t][j]

        return answer

Code Explanation

We first read the dimensions.

m = len(mat1)
k = len(mat1[0])
n = len(mat2[0])

The result matrix has m rows and n columns.

answer = [[0] * n for _ in range(m)]

Then we scan every entry of mat1, looking for non-zero values.

for i in range(m):
    for t in range(k):

If mat1[i][t] is zero, every product mat1[i][t] * mat2[t][j] in output row i is zero, so the whole inner loop over j can be skipped.

if mat1[i][t] == 0:
    continue

Otherwise, it can contribute to every column j where mat2[t][j] is non-zero.

for j in range(n):

Again, skip zero values from mat2.

if mat2[t][j] == 0:
    continue

When both values are non-zero, add their product.

answer[i][j] += mat1[i][t] * mat2[t][j]

Finally return the completed matrix.

return answer

More Sparse-Friendly Implementation

We can compress mat2 by storing only non-zero values in each row.

This avoids scanning all n columns for each non-zero value in mat1.

class Solution:
    def multiply(
        self,
        mat1: list[list[int]],
        mat2: list[list[int]],
    ) -> list[list[int]]:

        m = len(mat1)
        k = len(mat1[0])
        n = len(mat2[0])

        rows2 = []

        for r in range(k):
            row = []

            for c in range(n):
                if mat2[r][c] != 0:
                    row.append((c, mat2[r][c]))

            rows2.append(row)

        answer = [[0] * n for _ in range(m)]

        for i in range(m):
            for t in range(k):
                if mat1[i][t] == 0:
                    continue

                for j, value2 in rows2[t]:
                    answer[i][j] += mat1[i][t] * value2

        return answer

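This version is usually better when mat2 is also sparse. Building rows2 takes O(kn) time once, and rows2 uses extra space proportional to k plus the number of non-zero entries in mat2. After that, each non-zero mat1[i][t] touches only the non-zero entries of mat2[t].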
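If mat1 is very sparse too, the same compression can be applied to it, so the outer loops visit only non-zero entries of mat1 instead of testing every one. A sketch under that assumption; `nonzero_rows` and `multiply_both_compressed` are hypothetical names, and `nonzero_rows` builds the same per-row (column, value) lists as rows2 above.

```python
def nonzero_rows(matrix: list[list[int]]) -> list[list[tuple[int, int]]]:
    """For each row, the (column, value) pairs of its non-zero entries."""
    return [[(c, v) for c, v in enumerate(row) if v != 0] for row in matrix]


def multiply_both_compressed(
    mat1: list[list[int]],
    mat2: list[list[int]],
) -> list[list[int]]:
    m, n = len(mat1), len(mat2[0])
    rows1 = nonzero_rows(mat1)  # row i -> [(t, mat1[i][t]), ...]
    rows2 = nonzero_rows(mat2)  # row t -> [(j, mat2[t][j]), ...]
    answer = [[0] * n for _ in range(m)]
    for i in range(m):
        for t, value1 in rows1[i]:
            for j, value2 in rows2[t]:
                answer[i][j] += value1 * value2
    return answer


print(nonzero_rows([[7, 0, 0], [0, 0, 0], [0, 0, 1]]))  # [[(0, 7)], [], [(2, 1)]]
print(multiply_both_compressed([[1, 0, 0], [-1, 0, 3]], [[7, 0, 0], [0, 0, 0], [0, 0, 1]]))
# [[7, 0, 0], [-7, 0, 3]]
```

The innermost statement now runs exactly once per useful product; the remaining cost is building the two compressed views and allocating the m x n output.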

Testing

def run_tests():
    s = Solution()

    assert s.multiply(
        [
            [1, 0, 0],
            [-1, 0, 3],
        ],
        [
            [7, 0, 0],
            [0, 0, 0],
            [0, 0, 1],
        ],
    ) == [
        [7, 0, 0],
        [-7, 0, 3],
    ]

    assert s.multiply(
        [[0]],
        [[5]],
    ) == [
        [0],
    ]

    assert s.multiply(
        [[2]],
        [[3]],
    ) == [
        [6],
    ]

    assert s.multiply(
        [
            [1, 2],
            [3, 4],
        ],
        [
            [5, 6],
            [7, 8],
        ],
    ) == [
        [19, 22],
        [43, 50],
    ]

    assert s.multiply(
        [
            [1, 0, 2],
        ],
        [
            [3],
            [4],
            [5],
        ],
    ) == [
        [13],
    ]

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
|---|---|
| Sparse example | Checks zero skipping and normal multiplication |
| Zero single cell | Product remains zero |
| Non-zero single cell | Smallest non-zero multiplication |
| Dense matrix | Confirms correctness when no sparsity exists |
| Row times column | Checks rectangular dimensions |