A clear explanation of removing all occurrences of a value from an array in place using a write pointer.
Problem Restatement
We are given an integer array nums and an integer val.
We need to remove every occurrence of val from nums in place. After removal, the first k elements of nums should contain only the values that are not equal to val.
Return k, the number of remaining elements.
The values after the first k positions do not matter. The order of the remaining elements may be changed. The problem requires O(1) extra memory.
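Because the order of the kept elements is free and everything past position k is ignored, a checker for this contract can compare the first k values as a multiset. The sketch below uses a hypothetical helper named `accepts`. It only illustrates the contract and is not the official judge.

```python
from collections import Counter


def accepts(original: list[int], val: int, nums: list[int], k: int) -> bool:
    """Hypothetical checker: does (nums, k) satisfy the removal contract?"""
    kept = [x for x in original if x != val]
    if k != len(kept):
        return False
    # Order inside the prefix is free, so compare multisets, not lists.
    # Everything at index k and beyond is ignored.
    return Counter(nums[:k]) == Counter(kept)


# Example 1 below: any arrangement with [2, 2] in front is accepted.
print(accepts([3, 2, 2, 3], 3, [2, 2, 3, 3], 2))  # True
print(accepts([3, 2, 2, 3], 3, [2, 3, 2, 3], 2))  # False
```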
Input and Output
| Item | Meaning |
|---|---|
| Input | An integer array nums and an integer val |
| Output | The number of elements not equal to val |
| Required mutation | Put the kept elements in the first k positions |
| Extra space | O(1) |
Function shape:
```python
def removeElement(nums: list[int], val: int) -> int:
    ...
```
Examples
Example 1:
```python
nums = [3, 2, 2, 3]
val = 3
```
We remove all 3s.
The remaining values are:
```
[2, 2]
```
So we return:
```
2
```
The first two elements of nums should be 2 and 2.
Example 2:
```python
nums = [0, 1, 2, 2, 3, 0, 4, 2]
val = 2
```
We remove all 2s.
The remaining values are:
```
[0, 1, 3, 0, 4]
```
So we return:
```
5
```
The first five elements of nums should contain those five values.
First Thought: Brute Force
A simple solution is to create a new array that stores only values different from val.
```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        kept = []

        for num in nums:
            if num != val:
                kept.append(num)

        for i in range(len(kept)):
            nums[i] = kept[i]

        return len(kept)
```
This gives the right answer, but it uses another array.
Problem With Brute Force
The problem asks us to modify nums in place with constant extra memory.
The brute force version uses O(n) extra memory because kept can hold up to n elements (all of them, when no value equals val).
We need the same filtering idea, but we should write the kept values directly into the front of nums.
Key Insight
Use a write pointer.
The write pointer tells us where the next valid value should go.
As we scan the array, every value different from val should be kept. Every value equal to val should be skipped.
| Pointer | Meaning |
|---|---|
| read | Scans every value in the original array |
| write | Marks the next position for a kept value |
At all times, nums[0:write] contains the values we have decided to keep.
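To make both pointers from the table explicit, here is a sketch that names the read index and asserts the invariant after every step. The function name `remove_element_checked` is hypothetical. The extra check copies the input and rebuilds the expected prefix each step, so this version is O(n²) time and O(n) memory and is meant only for understanding.

```python
def remove_element_checked(nums: list[int], val: int) -> int:
    """Write-pointer removal with the loop invariant asserted at each step."""
    original = nums[:]  # kept only to state the invariant; not O(1) memory
    write = 0
    for read in range(len(nums)):
        if nums[read] != val:
            nums[write] = nums[read]
            write += 1
        # write never passes read, so no unread value was overwritten.
        assert write <= read + 1
        # The prefix holds exactly the kept values seen so far, in order.
        assert nums[:write] == [x for x in original[: read + 1] if x != val]
    return write


nums = [0, 1, 2, 2, 3, 0, 4, 2]
k = remove_element_checked(nums, 2)
print(k, nums[:k])  # 5 [0, 1, 3, 0, 4]
```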
Algorithm
Start with:
```python
write = 0
```
Then scan every number in nums.
For each number:
```python
if num != val:
    nums[write] = num
    write += 1
```
When the loop ends, write is the number of elements not equal to val.
Return write.
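To see the pointer move, the sketch below runs the same loop on Example 1 and prints one line per element. The name `remove_element_trace` and the print format are illustrative choices, not part of the problem.

```python
def remove_element_trace(nums: list[int], val: int) -> int:
    """Same write-pointer loop, printing one line per step (illustrative)."""
    write = 0
    for read, num in enumerate(nums):
        if num != val:
            nums[write] = num
            write += 1
            action = "keep"
        else:
            action = "skip"
        print(f"read={read} num={num} {action} write={write} nums={nums}")
    return write


k = remove_element_trace([3, 2, 2, 3], 3)
print("k =", k)
```

This prints:

```
read=0 num=3 skip write=0 nums=[3, 2, 2, 3]
read=1 num=2 keep write=1 nums=[2, 2, 2, 3]
read=2 num=2 keep write=2 nums=[2, 2, 2, 3]
read=3 num=3 skip write=2 nums=[2, 2, 2, 3]
k = 2
```

The trailing 2 and 3 sit past position k, so they are leftovers the problem ignores.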
Correctness
At the start, write = 0, so the kept prefix is empty.
During the scan, when we see a value equal to val, we skip it. This is correct because that value should not appear in the first k positions.
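When we see a value different from val, we copy it into nums[write]. This appends it to the kept prefix. Then we increase write, so the prefix length grows by one. Because write only increases when a value is kept, it never exceeds the current read position, so this copy only overwrites a position that has already been scanned and no unread value is lost.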
After processing every element, the prefix nums[0:write] contains exactly the elements that are not equal to val. Therefore, write is exactly the required answer k.
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | We scan the array once |
| Space | O(1) | We only use one integer pointer |
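Because the problem allows the kept elements to be reordered, another common in-place variant (not the approach this article develops) overwrites each match with the last unprocessed element and shrinks the logical length. It is still O(n) time and O(1) space, but it performs one write per removed element instead of one per kept element, which helps when val is rare. The name `remove_element_swap` is hypothetical.

```python
def remove_element_swap(nums: list[int], val: int) -> int:
    """Order-changing variant: overwrite a match with the last live element."""
    i = 0
    n = len(nums)  # nums[0:n] is the part not yet discarded
    while i < n:
        if nums[i] == val:
            # Move the last live element into slot i and drop the tail slot.
            # Do not advance i: the moved value has not been checked yet.
            nums[i] = nums[n - 1]
            n -= 1
        else:
            i += 1
    return n


nums = [3, 2, 2, 3]
k = remove_element_swap(nums, 3)
print(k, sorted(nums[:k]))  # 2 [2, 2]
```

The trade-off is that the kept values no longer keep their original order, which this problem explicitly permits.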
Implementation
```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        write = 0

        for num in nums:
            if num != val:
                nums[write] = num
                write += 1

        return write
```
Code Explanation
We start with write = 0.
```python
write = 0
```
This means no valid elements have been written yet.
Then we scan each value:
```python
for num in nums:
```
If the value should stay, we place it at the next available position:
```python
nums[write] = num
```
Then we move write forward:
```python
write += 1
```
If the value equals val, we do nothing. Skipping it removes it from the valid prefix.
Finally:
```python
return write
```
This returns the number of values kept.
Testing
```python
def check(nums: list[int], val: int, expected: list[int]) -> None:
    original = nums[:]
    k = Solution().removeElement(nums, val)

    assert k == len(expected), (original, val, k, expected)
    assert nums[:k] == expected, (original, val, nums[:k], expected)

def run_tests():
    check([3, 2, 2, 3], 3, [2, 2])
    check([0, 1, 2, 2, 3, 0, 4, 2], 2, [0, 1, 3, 0, 4])
    check([], 1, [])
    check([1, 1, 1], 1, [])
    check([1, 2, 3], 4, [1, 2, 3])
    check([4], 4, [])
    check([4], 3, [4])

    print("all tests passed")

run_tests()
```
Test meaning:
| Test | Why |
|---|---|
| [3,2,2,3], val = 3 | Basic removal |
| [0,1,2,2,3,0,4,2], val = 2 | Multiple removed values |
| Empty array | No values to process |
| All values removed | Return 0 |
| No values removed | Return original length |
| Single removed value | Minimum non-empty removal |
| Single kept value | Minimum non-empty kept case |
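The fixed cases above can be complemented with a small randomized check against the brute-force filter idea from earlier. The Solution class from the Implementation section is repeated so the block runs on its own; the helper name `random_check`, the seed, the array sizes, and the value range are arbitrary illustrative choices.

```python
import random


class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        write = 0
        for num in nums:
            if num != val:
                nums[write] = num
                write += 1
        return write


def random_check(trials: int = 1000, seed: int = 0) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        # Small value range so val appears often, including all/none cases.
        nums = [rng.randint(0, 3) for _ in range(rng.randint(0, 10))]
        val = rng.randint(0, 3)
        expected = [x for x in nums if x != val]  # brute-force reference
        work = nums[:]
        k = Solution().removeElement(work, val)
        assert k == len(expected), (nums, val, k)
        assert work[:k] == expected, (nums, val, work[:k])
    print("random checks passed")


random_check()
```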