A clear explanation of Reshape the Matrix using index mapping from the original matrix to the reshaped matrix.
Problem Restatement
We are given a matrix mat with m rows and n columns.
We are also given two integers:
| Value | Meaning |
|---|---|
| r | Desired number of rows |
| c | Desired number of columns |
We need to reshape the matrix into size r x c.
The reshaped matrix must preserve the original row-traversing order.
That means we read the original matrix from left to right, row by row, and place the values into the new matrix from left to right, row by row.
If reshaping is impossible, return the original matrix.
Reshaping is possible only when the total number of elements stays the same. The problem examples include mat = [[1,2],[3,4]], r = 1, c = 4, output [[1,2,3,4]]; and the impossible case r = 2, c = 4, which returns the original matrix.
Input and Output
| Item | Meaning |
|---|---|
| Input | A matrix mat, and integers r and c |
| Output | The reshaped matrix, or the original matrix if reshaping is impossible |
| Preserve | Row-major order |
| Return original when | m * n != r * c |
Example function shape:

    def matrixReshape(mat: list[list[int]], r: int, c: int) -> list[list[int]]:
        ...

Examples
Example 1:

    mat = [
        [1, 2],
        [3, 4],
    ]
    r = 1
    c = 4

Read the original matrix in row-major order:

    1, 2, 3, 4

Place those values into a 1 x 4 matrix:

    [[1, 2, 3, 4]]

So the answer is:

    [[1, 2, 3, 4]]

Example 2:

    mat = [
        [1, 2],
        [3, 4],
    ]
    r = 2
    c = 4

The original matrix has 2 * 2 = 4 elements.
The requested matrix needs 2 * 4 = 8 elements.
Since the element counts do not match, reshaping is impossible.
So we return the original matrix:

    [
        [1, 2],
        [3, 4],
    ]

First Thought: Flatten Then Rebuild
The simplest approach is:
- Flatten the matrix into one list.
- Split that list into rows of length c.

For example:

    mat = [[1, 2], [3, 4]]

Flattened form:

    [1, 2, 3, 4]

If r = 1 and c = 4, the result is:

    [[1, 2, 3, 4]]

This is easy to understand, but it uses an extra list.
Key Insight
We can map every element by its linear index.
In row-major order, each element has a single position number:

    0, 1, 2, 3, ...

For the original matrix with n columns:

    old_row = index // n
    old_col = index % n

For the new matrix with c columns:

    new_row = index // c
    new_col = index % c

So we can walk through all linear indices and copy each value from the old position to the new position.
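To make the mapping concrete, here is a small sketch that lists where each linear index lands when a 2 x 3 matrix is reshaped to 3 x 2. The helper name `index_mapping` and the sample shapes are illustrative assumptions, not part of the problem statement. It relies on the built-in `divmod(index, n)`, which returns `(index // n, index % n)` in one call.

```python
def index_mapping(m: int, n: int, c: int) -> list[tuple[tuple[int, int], tuple[int, int]]]:
    """Pair each old (row, col) with its new (row, col) for a reshape to c columns.

    Hypothetical helper for illustration; assumes m * n is divisible by c.
    """
    pairs = []
    for index in range(m * n):
        old_pos = divmod(index, n)  # (index // n, index % n)
        new_pos = divmod(index, c)  # (index // c, index % c)
        pairs.append((old_pos, new_pos))
    return pairs


# Reshape a 2 x 3 matrix into 3 x 2 and show each move.
for old_pos, new_pos in index_mapping(2, 3, 2):
    print(old_pos, "->", new_pos)
```

For instance, linear index 2 sits at (0, 2) in the 2 x 3 matrix and moves to (1, 0) in the 3 x 2 matrix.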
Algorithm
- Let m = len(mat) and n = len(mat[0]).
- Check whether m * n == r * c.
- If not, return mat.
- Create an empty result matrix with r rows and c columns.
- For each linear index from 0 to m * n - 1:
  - Convert it to the old matrix position.
  - Convert it to the new matrix position.
  - Copy the value.
- Return the result matrix.
Correctness
The reshape operation must preserve row-major order.
Every element in the original matrix has a unique row-major index:

    index = old_row * n + old_col

The algorithm uses this same index to place the element into the reshaped matrix:

    new_row = index // c
    new_col = index % c

Therefore, the first element in row-major order goes to the first position of the result, the second element goes to the second position, and so on.
If m * n != r * c, an r x c matrix cannot hold exactly the original elements, so the problem requires returning the original matrix.
If the counts match, every original element is copied exactly once to exactly one result position. Thus the algorithm returns the correct reshaped matrix.
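The argument above can also be spot-checked empirically. The following is a minimal sketch, not a proof: it reshapes random matrices and confirms that the row-major reading of the result equals the row-major reading of the input. The names `reshape`, `row_major`, and `check_random_reshapes`, along with the size and value ranges, are assumptions chosen for illustration.

```python
import random


def reshape(mat, r, c):
    """Index-mapping reshape, a local copy of the algorithm described above."""
    m, n = len(mat), len(mat[0])
    if m * n != r * c:
        return mat
    result = [[0] * c for _ in range(r)]
    for index in range(m * n):
        result[index // c][index % c] = mat[index // n][index % n]
    return result


def row_major(mat):
    """Read a matrix left to right, row by row."""
    return [value for row in mat for value in row]


def check_random_reshapes(trials=200, seed=0):
    """Return True if every random valid reshape preserves row-major order."""
    rng = random.Random(seed)
    for _ in range(trials):
        m, n = rng.randint(1, 6), rng.randint(1, 6)
        mat = [[rng.randint(-9, 9) for _ in range(n)] for _ in range(m)]
        total = m * n
        # Pick a target row count that divides the element count.
        r = rng.choice([d for d in range(1, total + 1) if total % d == 0])
        c = total // r
        out = reshape(mat, r, c)
        # The result must have the requested shape ...
        if len(out) != r or any(len(row) != c for row in out):
            return False
        # ... and the same values in the same row-major order.
        if row_major(out) != row_major(mat):
            return False
    return True


print(check_random_reshapes())
```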
Complexity
Let m be the number of rows and n be the number of columns in mat.
| Metric | Value | Why |
|---|---|---|
| Time | O(mn) | Every element is copied once |
| Space | O(mn) | The output matrix stores all elements |
The extra working space aside from the output is O(1).
Implementation

    class Solution:
        def matrixReshape(self, mat: list[list[int]], r: int, c: int) -> list[list[int]]:
            m = len(mat)
            n = len(mat[0])
            if m * n != r * c:
                return mat
            result = [[0] * c for _ in range(r)]
            for index in range(m * n):
                old_row = index // n
                old_col = index % n
                new_row = index // c
                new_col = index % c
                result[new_row][new_col] = mat[old_row][old_col]
            return result

Code Explanation
We first read the original matrix shape:

    m = len(mat)
    n = len(mat[0])

Then we check whether the reshape is valid:

    if m * n != r * c:
        return mat

If the total element count changes, we cannot reshape.
Next, we create the output matrix:

    result = [[0] * c for _ in range(r)]

Then each linear index is mapped to both matrices.
Original position:

    old_row = index // n
    old_col = index % n

New position:

    new_row = index // c
    new_col = index % c

Then we copy the value:

    result[new_row][new_col] = mat[old_row][old_col]

Flatten-Based Version
Python also allows a compact implementation using a flattened list.

    class Solution:
        def matrixReshape(self, mat: list[list[int]], r: int, c: int) -> list[list[int]]:
            m = len(mat)
            n = len(mat[0])
            if m * n != r * c:
                return mat
            values = []
            for row in mat:
                for value in row:
                    values.append(value)
            result = []
            for i in range(0, len(values), c):
                result.append(values[i:i + c])
            return result

This version is easier to read, but it stores an extra flattened list.
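If the intermediate list is a concern, one possible variant streams the values instead of storing them. This is a sketch under assumptions rather than part of the original solution: the function name `matrix_reshape_iter` is hypothetical, and it relies on the standard-library `itertools.chain.from_iterable` and `itertools.islice`.

```python
from itertools import chain, islice


def matrix_reshape_iter(mat, r, c):
    """Reshape by streaming values in row-major order (illustrative variant)."""
    m, n = len(mat), len(mat[0])
    if m * n != r * c:
        return mat
    # Lazy row-major iterator over all values; no flattened list is built.
    values = chain.from_iterable(mat)
    # Each islice call consumes the next c values from the shared iterator.
    return [list(islice(values, c)) for _ in range(r)]


print(matrix_reshape_iter([[1, 2, 3, 4]], 2, 2))
```

The trade-off is readability: the shared iterator is consumed as a side effect, which is compact but less explicit than the index-mapping loop.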
Testing

    def run_tests():
        s = Solution()
        assert s.matrixReshape([[1, 2], [3, 4]], 1, 4) == [[1, 2, 3, 4]]
        original = [[1, 2], [3, 4]]
        assert s.matrixReshape(original, 2, 4) == original
        assert s.matrixReshape([[1, 2, 3, 4]], 2, 2) == [[1, 2], [3, 4]]
        assert s.matrixReshape([[1], [2], [3], [4]], 2, 2) == [[1, 2], [3, 4]]
        assert s.matrixReshape([[1]], 1, 1) == [[1]]
        print("all tests passed")

    run_tests()

| Test | Why |
|---|---|
| [[1, 2], [3, 4]], 1 x 4 | Valid reshape sample |
| [[1, 2], [3, 4]], 2 x 4 | Impossible reshape |
| 1 x 4 to 2 x 2 | Splits one row into two rows |
| 4 x 1 to 2 x 2 | Combines column values into rows |
| [[1]], 1 x 1 | Smallest matrix |