A detailed explanation of removing the nth node from the end of a singly linked list using two pointers and a dummy node.
Problem Restatement
We are given the head of a linked list and an integer n.
We need to remove the nth node from the end of the list and return the head of the modified list.
The follow-up asks whether we can solve it in one pass. The constraints say the list size is sz, where 1 <= sz <= 30, node values are between 0 and 100, and 1 <= n <= sz.
Input and Output
| Item | Meaning |
|---|---|
| Input | Head of a singly linked list and integer n |
| Output | Head of the linked list after removing the target node |
| Target | The nth node from the end |
| Constraint | 1 <= n <= list size |
| Follow-up | Solve in one pass |
Example function shape:
```python
def removeNthFromEnd(head: ListNode, n: int) -> ListNode:
    ...
```
Examples
Example 1:
```text
head = [1, 2, 3, 4, 5]
n = 2
```
The second node from the end is 4. Remove it, and the output is:
```text
[1, 2, 3, 5]
```
Example 2:
```text
head = [1]
n = 1
```
The only node is also the first node from the end. Removing it leaves an empty list, so the output is:
```text
[]
```
Example 3:
```text
head = [1, 2]
n = 1
```
The first node from the end is 2. Remove it, and the output is:
```text
[1]
```
First Thought: Count the Length
The direct solution is to count the number of nodes first.
Suppose the list length is length.
The nth node from the end is then at this index from the start, using zero-based indexing:
```text
length - n
```
For example:
```text
head = [1, 2, 3, 4, 5]
n = 2
length = 5
```
The index to remove is:
```text
5 - 2 = 3
```
Index 3 contains value 4.
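As a quick sanity check of the index arithmetic, the sketch below uses a plain Python list as a stand-in for the linked list. This is only an illustration, because a real linked list has no random access. It also shows that Python's negative index -n names the same element as index length - n.

```python
# A plain Python list standing in for the linked list, for illustration only.
values = [1, 2, 3, 4, 5]
n = 2

length = len(values)
index = length - n  # zero-based index of the nth node from the end

print(index, values[index])  # 3 4

# Python's negative indexing counts from the end, so -n names the same element.
print(values[-n] == values[index])  # True

# Dropping that index gives the expected output for Example 1.
result = values[:index] + values[index + 1:]
print(result)  # [1, 2, 3, 5]
```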
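Code:
```python
from typing import Optional

# Assumes ListNode is already defined (see the Implementation section).


class Solution:
    def removeNthFromEnd(self, head: Optional[ListNode], n: int) -> Optional[ListNode]:
        # The dummy node lets us remove the head with the same code as any other node.
        dummy = ListNode(0, head)

        # First pass: count the nodes.
        length = 0
        current = head
        while current:
            length += 1
            current = current.next

        # Second pass: walk length - n nodes from dummy,
        # stopping just before zero-based index length - n.
        steps = length - n
        current = dummy
        for _ in range(steps):
            current = current.next

        # Link around the target node.
        current.next = current.next.next
        return dummy.next
```
This works, but it uses two passes: one pass to count the length, then another pass to remove the node.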
Key Insight
Use two pointers with a fixed gap of n.
Create two pointers, fast and slow.
Move fast ahead by n nodes.
Then move both fast and slow together.
When fast reaches the last node, slow will be right before the node we need to remove.
A dummy node helps because the node to remove may be the head.
Without a dummy node, removing the first node needs special handling.
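For contrast, here is a sketch of the same two-pointer idea without a dummy node. The helper names remove_nth_from_end_no_dummy, build_list and to_list are hypothetical. When n equals the list length, fast runs past the last node during the first loop, so that case has to return head.next separately.

```python
from typing import List, Optional


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


def build_list(values: List[int]) -> Optional[ListNode]:
    dummy = ListNode()
    current = dummy
    for value in values:
        current.next = ListNode(value)
        current = current.next
    return dummy.next


def to_list(head: Optional[ListNode]) -> List[int]:
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result


def remove_nth_from_end_no_dummy(head: Optional[ListNode], n: int) -> Optional[ListNode]:
    # Hypothetical variant: both pointers start at head instead of a dummy node.
    fast = head
    for _ in range(n):
        fast = fast.next

    # Special case: fast walked past the last node, so the target is the head.
    if fast is None:
        return head.next

    slow = head
    while fast.next:
        fast = fast.next
        slow = slow.next

    # slow is just before the target; link around it.
    slow.next = slow.next.next
    return head


print(to_list(remove_nth_from_end_no_dummy(build_list([1, 2, 3, 4, 5]), 2)))  # [1, 2, 3, 5]
print(to_list(remove_nth_from_end_no_dummy(build_list([1, 2]), 2)))  # [2]
```

The extra if branch is exactly the special handling that the dummy node makes unnecessary.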
With a dummy node, every removal looks like:
```python
slow.next = slow.next.next
```
Visual Walkthrough
Use:
```text
head = [1, 2, 3, 4, 5]
n = 2
```
Add a dummy node:
```text
dummy -> 1 -> 2 -> 3 -> 4 -> 5
```
Start both pointers at dummy:
```python
fast = dummy
slow = dummy
```
Move fast ahead by 2 nodes:
```text
dummy -> 1 -> 2 -> 3 -> 4 -> 5
              ^
              fast
^
slow
```
Now move both together until fast.next is None.
Step 1:
```text
dummy -> 1 -> 2 -> 3 -> 4 -> 5
                   ^
                   fast
         ^
         slow
```
Step 2:
```text
dummy -> 1 -> 2 -> 3 -> 4 -> 5
                        ^
                        fast
              ^
              slow
```
Step 3:
```text
dummy -> 1 -> 2 -> 3 -> 4 -> 5
                             ^
                             fast
                   ^
                   slow
```
Now fast is at the last node, so the loop stops.
slow.next is the node to remove: 4.
Remove it:
```python
slow.next = slow.next.next
```
Result:
```text
1 -> 2 -> 3 -> 5
```
Algorithm
- Create a dummy node whose next points to head.
- Set both fast and slow to dummy.
- Move fast forward n times.
- Move both pointers forward while fast.next exists.
- Now slow.next is the target node.
- Remove it by linking around it.
- Return dummy.next.
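The steps above can be traced in code to reproduce the visual walkthrough. The sketch below is one way to do this. trace_remove is a hypothetical helper that records where slow and fast sit after each move, naming the dummy node "dummy" and every other node by its value.

```python
from typing import List, Optional


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


def trace_remove(values: List[int], n: int) -> List[str]:
    """Run the algorithm and record where slow and fast sit after each move."""
    # Create a dummy node whose next points to the head of the built list.
    dummy = ListNode(0)
    tail = dummy
    for value in values:
        tail.next = ListNode(value)
        tail = tail.next

    def name(node: ListNode) -> str:
        return "dummy" if node is dummy else str(node.val)

    log = []
    # Both pointers start at dummy; fast moves ahead n times.
    fast = slow = dummy
    for _ in range(n):
        fast = fast.next
    log.append(f"after gap: slow={name(slow)} fast={name(fast)}")

    # Move both pointers while fast.next exists.
    step = 0
    while fast.next:
        fast = fast.next
        slow = slow.next
        step += 1
        log.append(f"step {step}: slow={name(slow)} fast={name(fast)}")

    # slow.next is the target; link around it.
    log.append(f"remove {name(slow.next)}")
    slow.next = slow.next.next

    # Read the list back starting from dummy.next.
    remaining = []
    node = dummy.next
    while node:
        remaining.append(str(node.val))
        node = node.next
    log.append("result: " + " -> ".join(remaining))
    return log


for line in trace_remove([1, 2, 3, 4, 5], 2):
    print(line)
```

For head = [1, 2, 3, 4, 5] and n = 2, the log should match the walkthrough. slow ends on 3, fast ends on 5, and node 4 is removed.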
Correctness
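After moving fast forward n times, fast is exactly n nodes ahead of slow, so there are n - 1 nodes strictly between them. Because 1 <= n <= sz, fast stops on a real node rather than None, so reading fast.next in the next loop is safe.
Then both pointers move together one node at a time, preserving this gap.
When fast reaches the last node, which is the 1st node from the end, slow is n nodes behind it. That puts slow.next n - 1 nodes behind the last node, which makes slow.next the nth node from the end. So slow is exactly one node before the target.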
So slow.next is the node that must be removed.
The assignment
```python
slow.next = slow.next.next
```
removes exactly that node while preserving the rest of the list.
The dummy node guarantees this also works when the removed node is the original head.
Therefore the algorithm returns the correct modified list.
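The argument can also be checked mechanically. The sketch below uses check_gap_invariant, a hypothetical helper, which labels each node with its position: the dummy is 0 and the real nodes are 1 to sz. With that labelling the gap is just fast.val - slow.val. The helper asserts the gap before every move and confirms that slow.next lands on position sz - n + 1, the nth node from the end. It runs for every list size the constraints allow.

```python
from typing import Optional


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


def check_gap_invariant(size: int, n: int) -> None:
    """Assert the facts used in the correctness argument for one (size, n) pair."""
    # Label each node with its position: dummy is 0, real nodes are 1..size.
    dummy = ListNode(0)
    tail = dummy
    for position in range(1, size + 1):
        tail.next = ListNode(position)
        tail = tail.next

    fast = slow = dummy
    for _ in range(n):
        fast = fast.next
    # Since n <= size, fast stops on a real node, never on None.
    assert fast is not None

    while fast.next:
        assert fast.val - slow.val == n  # the gap holds before every move
        fast = fast.next
        slow = slow.next

    assert fast.val - slow.val == n  # and still holds at the end
    assert fast.val == size  # fast is on the last node
    # The nth node from the end is at position size - n + 1.
    assert slow.next.val == size - n + 1


for size in range(1, 31):
    for n in range(1, size + 1):
        check_gap_invariant(size, n)
print("invariants hold for every size from 1 to 30 and every valid n")
```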
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(sz) | Each pointer moves through the list at most once |
| Space | O(1) | Only a few pointers are stored |
Here sz is the number of nodes in the list.
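To make the time bound concrete, the sketch below counts pointer moves. count_moves is a hypothetical helper. fast always ends on the last node, so it advances exactly sz times in total: n times in the first loop and sz - n times in the second. slow advances sz - n times. The only extra storage is the two pointers, whatever the list size.

```python
from typing import Optional, Tuple


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


def count_moves(size: int, n: int) -> Tuple[int, int]:
    """Return (fast_moves, slow_moves) for the dummy-node two-pointer loops."""
    dummy = ListNode(0)
    tail = dummy
    for value in range(1, size + 1):
        tail.next = ListNode(value)
        tail = tail.next

    fast = slow = dummy
    fast_moves = slow_moves = 0
    for _ in range(n):
        fast = fast.next
        fast_moves += 1
    while fast.next:
        fast = fast.next
        fast_moves += 1
        slow = slow.next
        slow_moves += 1
    return fast_moves, slow_moves


print(count_moves(5, 2))    # (5, 3)
print(count_moves(30, 30))  # (30, 0)
```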
Implementation
```python
from typing import Optional


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


class Solution:
    def removeNthFromEnd(
        self,
        head: Optional[ListNode],
        n: int
    ) -> Optional[ListNode]:
        dummy = ListNode(0, head)
        fast = dummy
        slow = dummy

        for _ in range(n):
            fast = fast.next

        while fast.next:
            fast = fast.next
            slow = slow.next

        slow.next = slow.next.next
        return dummy.next
```
Code Explanation
Create a dummy node:
```python
dummy = ListNode(0, head)
```
This handles removal of the head node.
Initialize both pointers:
```python
fast = dummy
slow = dummy
```
Move fast ahead by n nodes:
```python
for _ in range(n):
    fast = fast.next
```
Move both pointers until fast reaches the last node:
```python
while fast.next:
    fast = fast.next
    slow = slow.next
```
Now slow.next is the node to remove.
Remove it:
```python
slow.next = slow.next.next
```
Return the real head:
```python
return dummy.next
```
Testing
```python
def build_list(values):
    dummy = ListNode()
    current = dummy
    for value in values:
        current.next = ListNode(value)
        current = current.next
    return dummy.next


def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result


def run_tests():
    s = Solution()

    head = build_list([1, 2, 3, 4, 5])
    assert to_list(s.removeNthFromEnd(head, 2)) == [1, 2, 3, 5]

    head = build_list([1])
    assert to_list(s.removeNthFromEnd(head, 1)) == []

    head = build_list([1, 2])
    assert to_list(s.removeNthFromEnd(head, 1)) == [1]

    head = build_list([1, 2])
    assert to_list(s.removeNthFromEnd(head, 2)) == [2]

    head = build_list([1, 2, 3])
    assert to_list(s.removeNthFromEnd(head, 3)) == [2, 3]

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [1, 2, 3, 4, 5], n = 2 | Standard middle removal |
| [1], n = 1 | Remove only node |
| [1, 2], n = 1 | Remove tail |
| [1, 2], n = 2 | Remove head |
| [1, 2, 3], n = 3 | Head removal with longer list |
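Beyond these hand-picked cases, one option is a randomized differential test. It compares the two-pointer removal against Python list deletion, which serves as the reference. The sketch below copies the Implementation's steps into a plain function, remove_nth_from_end (a hypothetical name), so the block runs on its own. It draws inputs within the stated constraints: 1 to 30 nodes, values 0 to 100, and 1 <= n <= sz.

```python
import random
from typing import List, Optional


class ListNode:
    def __init__(self, val: int = 0, next: Optional["ListNode"] = None):
        self.val = val
        self.next = next


def remove_nth_from_end(head: Optional[ListNode], n: int) -> Optional[ListNode]:
    # Same steps as Solution.removeNthFromEnd, written as a plain function.
    dummy = ListNode(0, head)
    fast = slow = dummy
    for _ in range(n):
        fast = fast.next
    while fast.next:
        fast = fast.next
        slow = slow.next
    slow.next = slow.next.next
    return dummy.next


def build_list(values: List[int]) -> Optional[ListNode]:
    dummy = ListNode()
    current = dummy
    for value in values:
        current.next = ListNode(value)
        current = current.next
    return dummy.next


def to_list(head: Optional[ListNode]) -> List[int]:
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result


def reference(values: List[int], n: int) -> List[int]:
    # Python list deletion: index len(values) - n is the nth element from the end.
    expected = list(values)
    del expected[len(expected) - n]
    return expected


def random_check(trials: int = 1000, seed: int = 0) -> int:
    rng = random.Random(seed)  # fixed seed so any failure can be reproduced
    for _ in range(trials):
        size = rng.randint(1, 30)  # 1 <= sz <= 30
        values = [rng.randint(0, 100) for _ in range(size)]  # values in 0..100
        n = rng.randint(1, size)  # 1 <= n <= sz
        got = to_list(remove_nth_from_end(build_list(values), n))
        assert got == reference(values, n), (values, n, got)
    return trials


print(random_check(), "random cases passed")
```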