
LeetCode 27: Remove Element

A clear explanation of removing all occurrences of a value from an array in place using a write pointer.

Problem Restatement

We are given an integer array nums and an integer val.

We need to remove every occurrence of val from nums in place. After removal, the first k elements of nums should contain only the values that are not equal to val.

Return k, the number of remaining elements.

The values after the first k positions do not matter. The order of the remaining elements may be changed. The problem requires O(1) extra memory.
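Because the kept values may be reordered and the tail is ignored, a fair checker should compare the first k values as a multiset rather than element by element. Below is a minimal sketch of such a check. The helper name `is_valid_removal` is hypothetical. It is not part of the problem's API.

```python
from collections import Counter


def is_valid_removal(original: list[int], val: int, nums: list[int], k: int) -> bool:
    """Hypothetical checker: accept any order of the kept prefix and ignore the tail."""
    expected = [x for x in original if x != val]
    # k must equal the number of values that are not val.
    if k != len(expected):
        return False
    # Compare the prefix as a multiset, since the order of kept values may change.
    return Counter(nums[:k]) == Counter(expected)


# Reordered kept values with arbitrary leftovers in the tail still count as valid.
print(is_valid_removal([0, 1, 2, 2, 3, 0, 4, 2], 2, [4, 0, 3, 1, 0, 9, 9, 9], 5))  # True
```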

Input and Output

Item | Meaning
--- | ---
Input | An integer array nums and an integer val
Output | The number of elements not equal to val
Required mutation | Put the kept elements in the first k positions
Extra space | O(1)

Function shape:

def removeElement(nums: list[int], val: int) -> int:
    ...

Examples

Example 1:

nums = [3, 2, 2, 3]
val = 3

We remove all 3s.

The remaining values are:

[2, 2]

So we return:

2

The first two elements of nums should be 2 and 2.

Example 2:

nums = [0, 1, 2, 2, 3, 0, 4, 2]
val = 2

We remove all 2s.

The remaining values are:

[0, 1, 3, 0, 4]

So we return:

5

The first five elements of nums should contain those five values.

First Thought: Brute Force

A simple solution is to create a new array that stores only values different from val.

class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        kept = []

        for num in nums:
            if num != val:
                kept.append(num)

        for i in range(len(kept)):
            nums[i] = kept[i]

        return len(kept)

This gives the right answer, but it uses another array.
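Python can express the same filtering idea more compactly with a list comprehension and slice assignment. This is still the brute-force approach: the comprehension builds a temporary list of up to n kept values before copying them back into the front of nums.

```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        # The comprehension allocates a new list of kept values (O(n) extra memory).
        kept = [num for num in nums if num != val]
        # Equal-length slice assignment copies them into the front of the same list,
        # leaving the list's length and its tail unchanged.
        nums[:len(kept)] = kept
        return len(kept)
```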

Problem With Brute Force

The problem asks us to modify nums in place with constant extra memory.

The brute force version uses O(n) extra memory because kept can hold up to n elements (all of them, when val does not appear in nums).

We need the same filtering idea, but we should write the kept values directly into the front of nums.

Key Insight

Use a write pointer.

The write pointer tells us where the next valid value should go.

As we scan the array, every value different from val should be kept. Every value equal to val should be skipped.

Pointer | Meaning
--- | ---
read | Scans every value in the original array
write | Marks the next position for a kept value

At all times, nums[0:write] contains the values we have decided to keep, in the same relative order as in the original array.
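To make this invariant concrete, here is an instrumented sketch that checks it after every step and spells out the read index explicitly. The name `remove_element_checked` is hypothetical, and the `original` copy exists only so the assertions have something to compare against. That copy costs O(n) memory, so treat this as a debugging aid, not a solution.

```python
def remove_element_checked(nums: list[int], val: int) -> int:
    original = nums[:]  # debugging copy only; this is not O(1) memory
    write = 0
    for read in range(len(nums)):
        if nums[read] != val:
            nums[write] = nums[read]
            write += 1
        # Invariant: the prefix holds exactly the kept values seen so far, in order.
        assert nums[:write] == [x for x in original[: read + 1] if x != val]
        # The write pointer never gets ahead of the elements already read.
        assert write <= read + 1
    return write


nums = [0, 1, 2, 2, 3, 0, 4, 2]
print(remove_element_checked(nums, 2), nums[:5])  # 5 [0, 1, 3, 0, 4]
```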

Algorithm

Start with:

write = 0

Then scan every number in nums.

For each number:

if num != val:
    nums[write] = num
    write += 1

When the loop ends, write is the number of elements not equal to val.

Return write.
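A trace shows how the two positions move and what is left in the array afterward. This sketch runs the algorithm on Example 2 and prints the state after each step. `trace_remove` is a hypothetical helper added for illustration.

```python
def trace_remove(nums: list[int], val: int) -> int:
    write = 0
    for read, num in enumerate(nums):
        if num != val:
            nums[write] = num  # append to the kept prefix
            write += 1
            action = "keep"
        else:
            action = "skip"  # leave write where it is
        print(f"read={read} num={num} {action} write={write} nums={nums}")
    return write


nums = [0, 1, 2, 2, 3, 0, 4, 2]
k = trace_remove(nums, 2)
print(k, nums)
```

The last line printed is `5 [0, 1, 3, 0, 4, 0, 4, 2]`: the first five positions hold the kept values, and the last three are leftovers that the problem ignores.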

Correctness

At the start, write = 0, so the kept prefix is empty.

During the scan, when we see a value equal to val, we skip it. This is correct because that value should not appear in the first k positions.

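When we see a value different from val, we copy it into nums[write]. This appends it to the kept prefix. Then we increase write, so the prefix length grows by one. The copy never overwrites a value we still need: write grows at most once per scanned element, so it never gets ahead of the read position, and nums[write] is either the current element or a slot we have already scanned.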

After processing every element, the prefix nums[0:write] contains exactly the elements that are not equal to val. Therefore, write is exactly the required answer k.

Complexity

Metric | Value | Why
--- | --- | ---
Time | O(n) | We scan the array once
Space | O(1) | We use only the write index and the loop variable
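The write-pointer version performs one array write per kept element, so when val is rare it rewrites almost the whole array. Because the problem allows the remaining elements to be reordered, a different technique can reduce this to one write per removed element: overwrite each occurrence of val with the last live element and shrink the logical length. This is a sketch of that alternative, not the method this article uses. It has the same O(n) time and O(1) space.

```python
class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        i = 0
        n = len(nums)  # nums[0:n] is the part not yet known to be removed
        while i < n:
            if nums[i] == val:
                # Replace the removed value with the last live element.
                # Do not advance i: the moved-in value has not been checked yet.
                nums[i] = nums[n - 1]
                n -= 1
            else:
                i += 1
        return n
```

For Example 2 this leaves the first five positions as [0, 1, 4, 0, 3]: the same values as before, in a different order, which the problem accepts.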

Implementation

class Solution:
    def removeElement(self, nums: list[int], val: int) -> int:
        write = 0

        for num in nums:
            if num != val:
                nums[write] = num
                write += 1

        return write

Code Explanation

We start with write = 0.

write = 0

This means no valid elements have been written yet.

Then we scan each value:

for num in nums:

If the value should stay, we place it at the next available position:

nums[write] = num

Then we move write forward:

write += 1

If the value equals val, we do nothing. Skipping it keeps it out of the kept prefix; its slot may later be overwritten by a kept value.

Finally:

return write

This returns the number of values kept.

Testing

def check(nums: list[int], val: int, expected: list[int]) -> None:
    original = nums[:]
    k = Solution().removeElement(nums, val)

    assert k == len(expected), (original, val, k, expected)
    assert nums[:k] == expected, (original, val, nums[:k], expected)

def run_tests():
    check([3, 2, 2, 3], 3, [2, 2])
    check([0, 1, 2, 2, 3, 0, 4, 2], 2, [0, 1, 3, 0, 4])
    check([], 1, [])
    check([1, 1, 1], 1, [])
    check([1, 2, 3], 4, [1, 2, 3])
    check([4], 4, [])
    check([4], 3, [4])

    print("all tests passed")

run_tests()
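The hand-written cases cover the edges. A randomized comparison against a plain filter can add confidence on many more inputs. This sketch is self-contained, so it repeats the write-pointer logic as a standalone function; the names `remove_element` and `random_check`, the seed, and the value ranges are arbitrary choices for illustration.

```python
import random


def remove_element(nums: list[int], val: int) -> int:
    # Same write-pointer logic as the Implementation section.
    write = 0
    for num in nums:
        if num != val:
            nums[write] = num
            write += 1
    return write


def random_check(trials: int = 1000, seed: int = 0) -> None:
    rng = random.Random(seed)  # fixed seed so failures are reproducible
    for _ in range(trials):
        # A small value range makes val appear often, sometimes everywhere.
        nums = [rng.randint(0, 3) for _ in range(rng.randint(0, 10))]
        val = rng.randint(0, 3)
        expected = [x for x in nums if x != val]
        original = nums[:]
        k = remove_element(nums, val)
        assert k == len(expected), (original, val, k)
        assert nums[:k] == expected, (original, val, nums[:k])
        assert len(nums) == len(original)  # the list itself is never resized


random_check()
print("random tests passed")
```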

Test meaning:

Test | Why
--- | ---
[3,2,2,3], val = 3 | Basic removal
[0,1,2,2,3,0,4,2], val = 2 | Multiple removed values
Empty array | No values to process
All values removed | Return 0
No values removed | Return original length
Single removed value | Minimum non-empty removal
Single kept value | Minimum non-empty kept case