A detailed explanation of finding all unique triplets that sum to zero using sorting and two pointers.
Problem Restatement
We are given an integer array nums.
We need to return all unique triplets:

```
[nums[i], nums[j], nums[k]]
```

such that:

```
i != j
i != k
j != k
```

and:

```
nums[i] + nums[j] + nums[k] == 0
```

The solution set must not contain duplicate triplets. The order of the output and the order of values inside each triplet do not matter. The constraints are 3 <= nums.length <= 3000 and -10^5 <= nums[i] <= 10^5.
Input and Output
| Item | Meaning |
|---|---|
| Input | An integer array nums |
| Output | All unique triplets whose sum is 0 |
| Triplet rule | The three indices must be different |
| Duplicate rule | The answer must not contain duplicate triplets |
| Constraint | 3 <= nums.length <= 3000 |
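The output rules in the table can also be written as a checker. The sketch below uses a hypothetical helper, is_valid_answer, which is not part of the original solution. It checks that every triplet sums to zero and that each triplet can be drawn from nums using three distinct positions. It also checks that no two triplets hold the same values in any order. It does not check that the answer is complete.

```python
from collections import Counter


def is_valid_answer(nums: list[int], result: list[list[int]]) -> bool:
    """Hypothetical checker for the output rules (validity only, not completeness)."""
    available = Counter(nums)
    seen = set()
    for triplet in result:
        if len(triplet) != 3 or sum(triplet) != 0:
            return False
        # Distinct indices: each value may be used at most as often as it occurs.
        needed = Counter(triplet)
        if any(available[v] < c for v, c in needed.items()):
            return False
        # Order inside a triplet does not matter, so compare sorted forms.
        key = tuple(sorted(triplet))
        if key in seen:
            return False
        seen.add(key)
    return True


print(is_valid_answer([-1, 0, 1, 2, -1, -4], [[-1, -1, 2], [0, 1, -1]]))  # True
print(is_valid_answer([0, 1, 1], [[0, 1, -1]]))  # False: nums has no -1
print(is_valid_answer([-1, 0, 1], [[-1, 0, 1], [1, 0, -1]]))  # False: duplicate
```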
Example function shape:

```python
def threeSum(nums: list[int]) -> list[list[int]]:
    ...
```

Examples
Example 1:

```
nums = [-1, 0, 1, 2, -1, -4]
```

The valid unique triplets are:

```
[-1, -1, 2]
[-1, 0, 1]
```

Output:

```
[[-1, -1, 2], [-1, 0, 1]]
```

Example 2:

```
nums = [0, 1, 1]
```

The only possible triplet is:

```
[0, 1, 1]
```

Its sum is:

```
2
```

Output:

```
[]
```

Example 3:

```
nums = [0, 0, 0]
```

The only possible unique triplet is:

```
[0, 0, 0]
```

Output:

```
[[0, 0, 0]]
```

First Thought: Try Every Triplet
The direct method is to check every group of three indices.

For every i < j < k, check whether:

```
nums[i] + nums[j] + nums[k] == 0
```

To remove duplicates, we can sort each valid triplet and store it in a set.
```python
class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        found = set()
        n = len(nums)
        # Try every index triple i < j < k.
        for i in range(n):
            for j in range(i + 1, n):
                for k in range(j + 1, n):
                    if nums[i] + nums[j] + nums[k] == 0:
                        # Sort so equal value groups map to the same set key.
                        triplet = tuple(sorted([nums[i], nums[j], nums[k]]))
                        found.add(triplet)
        return [list(t) for t in found]
```

This is correct, but too slow.
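The set is needed because different index triples can produce the same values. The short check below runs on Example 1. It uses itertools.combinations, which yields index triples i < j < k in the same order as the nested loops. Three index triples sum to zero, but they collapse into two unique triplets once each is sorted.

```python
from itertools import combinations

nums = [-1, 0, 1, 2, -1, -4]

# Every index triple i < j < k whose values sum to zero.
index_triples = [
    (i, j, k)
    for i, j, k in combinations(range(len(nums)), 3)
    if nums[i] + nums[j] + nums[k] == 0
]
print(index_triples)  # [(0, 1, 2), (0, 3, 4), (1, 2, 4)]

# Sorting each value triplet makes equal groups compare equal.
unique = {tuple(sorted((nums[i], nums[j], nums[k]))) for i, j, k in index_triples}
print(sorted(unique))  # [(-1, -1, 2), (-1, 0, 1)]
```

Index triples (0, 1, 2) and (1, 2, 4) both hold the values -1, 0, and 1, so only one copy survives in the set.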
Problem With Brute Force
The brute force solution checks all triples.
| Metric | Value |
|---|---|
| Time | O(n^3) |
| Space | O(r) |
Here, r is the number of unique triplets stored in the result.
With n up to 3000, there are 3000 choose 3, about 4.5 × 10^9, index triples to check, so O(n^3) is not practical.
Key Insight
Sort the array first.
After sorting, we can fix one number and solve the remaining part as a two-sum problem using two pointers.
Suppose we fix:

```
a = nums[i]
```

Then we need two numbers b and c with:

```
b + c = -a
```

Because the array is sorted, we can search for b and c to the right of i using two pointers:

```
left = i + 1
right = len(nums) - 1
```

If:

```
a + nums[left] + nums[right] < 0
```

the sum is too small, so we move left rightward to get a larger value.

If:

```
a + nums[left] + nums[right] > 0
```

the sum is too large, so we move right leftward to get a smaller value.
If the sum is exactly 0, we record the triplet.
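For a single fixed index, the scan can be sketched as a standalone function. The name pairs_for_fixed is hypothetical and used only here. The function assumes nums is already sorted and returns every pair of values that completes nums[i] to zero. It deliberately omits the duplicate handling described next, so repeated values can produce repeated pairs.

```python
def pairs_for_fixed(nums: list[int], i: int) -> list[tuple[int, int]]:
    """Two-pointer search for b + c == -nums[i] in nums[i+1:]; nums must be sorted."""
    target = -nums[i]
    left, right = i + 1, len(nums) - 1
    pairs = []
    while left < right:
        pair_sum = nums[left] + nums[right]
        if pair_sum < target:
            left += 1   # sum too small: need a larger value
        elif pair_sum > target:
            right -= 1  # sum too large: need a smaller value
        else:
            pairs.append((nums[left], nums[right]))
            left += 1   # both values used; move inward
            right -= 1
    return pairs


nums = sorted([-1, 0, 1, 2, -1, -4])     # [-4, -1, -1, 0, 1, 2]
print(pairs_for_fixed(nums, 1))          # [(-1, 2), (0, 1)]
print(pairs_for_fixed(nums, 0))          # []
print(pairs_for_fixed([-2, 1, 1, 1, 1], 0))  # [(1, 1), (1, 1)]
```

Comparing b + c with -a is the same test as comparing a + b + c with 0. The last call shows the same pair (1, 1) being reported twice, which is the repetition that duplicate skipping removes.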
Handling Duplicates
Sorting also makes duplicate handling easier.
When choosing the fixed number nums[i], skip it if it equals the previous fixed number:

```python
if i > 0 and nums[i] == nums[i - 1]:
    continue
```

This prevents generating the same group again. Every triplet that the repeated value could start was already found from its first occurrence.

After finding a valid triplet, move both pointers inward by one, then skip any values equal to the ones just used:

```python
left += 1
right -= 1
while left < right and nums[left] == nums[left - 1]:
    left += 1
while left < right and nums[right] == nums[right + 1]:
    right -= 1
```

This prevents repeated triplets with the same fixed number.
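Toggling both checks shows their effect. This sketch runs the sorted two-pointer scan with a hypothetical skip_duplicates flag that turns both checks on or off. The early nums[i] > 0 exit is left out because it only prunes work and does not change which triplets are found.

```python
def three_sum(nums: list[int], skip_duplicates: bool = True) -> list[list[int]]:
    nums = sorted(nums)
    result = []
    for i in range(len(nums) - 2):
        # Skip a fixed value that equals the previous fixed value.
        if skip_duplicates and i > 0 and nums[i] == nums[i - 1]:
            continue
        left, right = i + 1, len(nums) - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                result.append([nums[i], nums[left], nums[right]])
                left += 1
                right -= 1
                if skip_duplicates:
                    # Skip values equal to the ones just recorded.
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1
    return result


print(three_sum([0, 0, 0, 0], skip_duplicates=False))      # [[0, 0, 0], [0, 0, 0]]
print(three_sum([0, 0, 0, 0]))                             # [[0, 0, 0]]
print(three_sum([-2, 0, 0, 2, 2], skip_duplicates=False))  # [[-2, 0, 2], [-2, 0, 2]]
print(three_sum([-2, 0, 0, 2, 2]))                         # [[-2, 0, 2]]
```

In [0, 0, 0, 0] the repeat comes from the fixed value: i = 1 rediscovers the triplet found at i = 0. In [-2, 0, 0, 2, 2] the repeat comes from the pointer values within a single fixed index.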
Visual Walkthrough
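Use:

```
nums = [-1, 0, 1, 2, -1, -4]
```

Sort it:

```
[-4, -1, -1, 0, 1, 2]
```

Start with:

```
i = 0
a = -4
left = 1
right = 5
```

Sum:

```
-4 + -1 + 2 = -3
```

Too small, so move left. The next sums (-3, -2, -1) are all still negative, so left keeps moving until it meets right.

No triplet with -4 works.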
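Next:

```
i = 1
a = -1
left = 2
right = 5
```

Sum:

```
-1 + -1 + 2 = 0
```

Record:

```
[-1, -1, 2]
```

Move both pointers inward.

Now:

```
left = 3
right = 4
```

Sum:

```
-1 + 0 + 1 = 0
```

Record:

```
[-1, 0, 1]
```

Moving both pointers again makes them cross, so this scan ends. Next, i = 2 is also -1, so skip it. At i = 3, a = 0 and 0 + 1 + 2 = 3 is too large. Right then moves leftward until it meets left, and no new triplet appears.

The final answer is:

```
[[-1, -1, 2], [-1, 0, 1]]
```

Algorithm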
- Sort nums.
- Create an empty result list.
- For each index i from 0 to len(nums) - 3:
  - if nums[i] is a duplicate fixed value, skip it
  - if nums[i] > 0, stop early because all later numbers are also positive
  - set left = i + 1
  - set right = len(nums) - 1
  - while left < right:
    - compute the sum
    - if the sum is less than 0, move left
    - if the sum is greater than 0, move right
    - otherwise, record the triplet, move both pointers inward, and skip duplicates
- Return the result.
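The walkthrough condenses some intermediate checks. To see every check, the steps above can be instrumented. three_sum_steps is a hypothetical helper written for this illustration. It follows the listed steps on a sorted copy of the input and records each (i, left, right, sum) it examines.

```python
def three_sum_steps(
    nums: list[int],
) -> tuple[list[list[int]], list[tuple[int, int, int, int]]]:
    """Follow the listed steps and record (i, left, right, sum) for every check."""
    nums = sorted(nums)  # sorted copy; the caller's list is left unchanged
    result, steps = [], []
    n = len(nums)
    for i in range(n - 2):
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # duplicate fixed value
        if nums[i] > 0:
            break  # every later value is positive too
        left, right = i + 1, n - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            steps.append((i, left, right, total))
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                result.append([nums[i], nums[left], nums[right]])
                left += 1
                right -= 1
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1
    return result, steps


result, steps = three_sum_steps([-1, 0, 1, 2, -1, -4])
for step in steps:
    print(step)
print(result)  # [[-1, -1, 2], [-1, 0, 1]]
```

For Example 1 the trace has seven checks: four negative sums at i = 0, two zero sums at i = 1, and one sum of 3 at i = 3. Index 2 does not appear because it is skipped as a duplicate.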
Correctness
After sorting, every triplet can be written in nondecreasing order.
The outer loop chooses the first value of the triplet. For a fixed i, the two-pointer scan searches all pairs to the right of i.
When the current sum is too small, nums[left] cannot belong to any remaining valid pair. Its largest possible partner is nums[right], and even that sum is below zero, so discarding it by increasing left loses nothing. Symmetrically, when the sum is too large, nums[right] is too large even when paired with the smallest remaining value nums[left], so decreasing right loses nothing.

So the two-pointer scan finds every valid pair for the fixed i.
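The duplicate checks skip repeated fixed values, and they skip repeated pointer values after a valid triplet is found. A repeated fixed value searches a subset of the positions already searched from its first occurrence, so it can only rediscover recorded triplets. For a fixed i, once a pair is recorded, any other pair with the same left value must also have the same right value, because the required sum is fixed. That pair would therefore be a duplicate. Skipping these positions removes duplicates without losing any unique triplet.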
Therefore the algorithm returns exactly all unique triplets whose sum is zero.
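An argument like this is easy to get subtly wrong, so a randomized comparison against the brute force is a cheap sanity check. It is evidence, not a proof. The sketch restates the two-pointer algorithm as a plain function, two_pointer, and compares it with a set-based brute_force on small random arrays. The function names, value range, array sizes, and seed are all arbitrary choices made for this illustration.

```python
import random
from itertools import combinations


def brute_force(nums: list[int]) -> set[tuple[int, int, int]]:
    """Reference answer: all value triples from distinct positions, deduplicated."""
    return {tuple(sorted(t)) for t in combinations(nums, 3) if sum(t) == 0}


def two_pointer(nums: list[int]) -> list[tuple[int, int, int]]:
    nums = sorted(nums)
    found = []
    for i in range(len(nums) - 2):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        if nums[i] > 0:
            break
        left, right = i + 1, len(nums) - 1
        while left < right:
            total = nums[i] + nums[left] + nums[right]
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                found.append((nums[i], nums[left], nums[right]))
                left += 1
                right -= 1
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1
    return found


random.seed(0)
for _ in range(2000):
    nums = [random.randint(-5, 5) for _ in range(random.randint(3, 12))]
    result = two_pointer(nums)
    # No repeated triplets, and exactly the brute-force set.
    assert len(result) == len(set(result))
    assert set(result) == brute_force(nums)
print("randomized check passed")
```

The small value range (-5 to 5) is chosen on purpose. It makes repeated values, and therefore the duplicate-skipping paths, very common.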
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n^2) | Sorting costs O(n log n), then each fixed index uses a linear two-pointer scan |
| Space | O(1) extra, plus sorting | Ignoring the output list; Python's list.sort (Timsort) can use up to O(n) temporary memory |
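The quadratic bound can be made concrete. For a fixed i, the gap right - left starts at n - 2 - i and shrinks by at least one on every iteration. The inner loop therefore runs at most n - 2 - i times, and at most (n - 1)(n - 2) / 2 times in total. The sketch below counts iterations using count_iterations, an instrumented copy of the scan named here only for illustration. The early exit is left out so that the count covers the full scan. The sketch also compares the bound with the brute-force triple count for n = 3000.

```python
import random
from math import comb


def count_iterations(nums: list[int]) -> int:
    """Count inner-loop iterations of the two-pointer scan (instrumented sketch)."""
    nums = sorted(nums)
    n = len(nums)
    steps = 0
    for i in range(n - 2):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        left, right = i + 1, n - 1
        while left < right:
            steps += 1
            total = nums[i] + nums[left] + nums[right]
            if total < 0:
                left += 1
            elif total > 0:
                right -= 1
            else:
                left += 1
                right -= 1
                while left < right and nums[left] == nums[left - 1]:
                    left += 1
                while left < right and nums[right] == nums[right + 1]:
                    right -= 1
    return steps


print(comb(3000, 3))  # 4495501000: index triples the brute force checks
print(comb(2999, 2))  # 4495501: bound on two-pointer iterations for n = 3000

random.seed(1)
nums = [random.randint(-10**5, 10**5) for _ in range(1000)]
steps = count_iterations(nums)
assert steps <= comb(999, 2)  # (n - 1)(n - 2) / 2 with n = 1000
print(steps)
```

For n = 3000 the two counts differ by a factor of exactly n / 3 = 1000.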
Implementation
```python
class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        nums.sort()  # sorts in place; equal values become adjacent
        result = []
        n = len(nums)

        for i in range(n - 2):
            # Skip a fixed value equal to the previous one.
            if i > 0 and nums[i] == nums[i - 1]:
                continue
            # All remaining values are positive: no zero sum is possible.
            if nums[i] > 0:
                break

            left = i + 1
            right = n - 1

            while left < right:
                total = nums[i] + nums[left] + nums[right]
                if total < 0:
                    left += 1
                elif total > 0:
                    right -= 1
                else:
                    result.append([nums[i], nums[left], nums[right]])
                    left += 1
                    right -= 1
                    # Skip values equal to the ones just recorded.
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1

        return result
```

Code Explanation
Sort the array:

```python
nums.sort()
```

This enables two pointers and makes duplicates adjacent.

Loop through possible fixed values. The last fixed index is n - 3, because two positions must remain to the right of i:

```python
for i in range(n - 2):
```

Skip duplicate fixed values:

```python
if i > 0 and nums[i] == nums[i - 1]:
    continue
```

Stop early when the fixed value is positive:

```python
if nums[i] > 0:
    break
```

Since the array is sorted, every later value is also positive, so no zero-sum triplet can appear.

Set the two pointers:

```python
left = i + 1
right = n - 1
```

Compute the current sum:

```python
total = nums[i] + nums[left] + nums[right]
```

Move pointers based on the sum.

Too small:

```python
left += 1
```

Too large:

```python
right -= 1
```

Exactly zero:

```python
result.append([nums[i], nums[left], nums[right]])
```

Then move both pointers inward and skip values equal to the ones just used.
Testing
```python
def normalize(result):
    # Triplet order does not matter, so compare a sorted list of tuples.
    return sorted([tuple(x) for x in result])


def run_tests():
    s = Solution()

    assert normalize(s.threeSum([-1, 0, 1, 2, -1, -4])) == [
        (-1, -1, 2),
        (-1, 0, 1),
    ]
    assert normalize(s.threeSum([0, 1, 1])) == []
    assert normalize(s.threeSum([0, 0, 0])) == [(0, 0, 0)]
    assert normalize(s.threeSum([0, 0, 0, 0])) == [(0, 0, 0)]
    assert normalize(s.threeSum([-2, 0, 1, 1, 2])) == [
        (-2, 0, 2),
        (-2, 1, 1),
    ]
    assert normalize(s.threeSum([1, 2, -2, -1])) == []

    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [-1, 0, 1, 2, -1, -4] | Standard example |
| [0, 1, 1] | No valid triplet |
| [0, 0, 0] | One all-zero triplet |
| [0, 0, 0, 0] | Duplicate all-zero triplets must collapse |
| [-2, 0, 1, 1, 2] | Duplicate values can still form unique triplets |
| [1, 2, -2, -1] | Mixed signs but no valid answer |