
LeetCode 18: 4Sum

A detailed explanation of finding all unique quadruplets that sum to a target using sorting and two pointers.

Problem Restatement

We are given an integer array nums and an integer target.

We need to return all unique quadruplets:

[nums[a], nums[b], nums[c], nums[d]]

such that:

0 <= a, b, c, d < n

where n is the length of nums, the four indices are distinct, and:

nums[a] + nums[b] + nums[c] + nums[d] == target

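The answer may be returned in any order, but duplicate quadruplets must not appear. The constraints are 1 <= nums.length <= 200, -10^9 <= nums[i] <= 10^9, and -10^9 <= target <= 10^9. A sum of four such values can reach 4 * 10^9 in absolute value, which is outside the 32-bit integer range. Python integers are unbounded, so this does not affect the Python code here, but it matters in languages with fixed-width integers.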

Input and Output

Item | Meaning
Input | An integer array nums and an integer target
Output | All unique quadruplets whose sum equals target
Index rule | The four indices must be distinct
Duplicate rule | The result must not contain duplicate quadruplets
Constraint | 1 <= nums.length <= 200

Example function shape:

def fourSum(nums: list[int], target: int) -> list[list[int]]:
    ...

Examples

Example 1:

nums = [1, 0, -1, 0, -2, 2]
target = 0

The unique quadruplets are:

[-2, -1, 1, 2]
[-2, 0, 0, 2]
[-1, 0, 0, 1]

Output:

[[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]

Example 2:

nums = [2, 2, 2, 2, 2]
target = 8

The only unique quadruplet is:

[2, 2, 2, 2]

Output:

[[2, 2, 2, 2]]

First Thought: Try Every Quadruplet

The direct method is to check every group of four indices.

For every a, b, c, and d, compute the sum and keep the quadruplet if the sum equals target.

To avoid duplicates, sort each valid quadruplet and store it in a set.

class Solution:
    def fourSum(self, nums: list[int], target: int) -> list[list[int]]:
        found = set()
        n = len(nums)

        for a in range(n):
            for b in range(a + 1, n):
                for c in range(b + 1, n):
                    for d in range(c + 1, n):
                        total = nums[a] + nums[b] + nums[c] + nums[d]

                        if total == target:
                            quad = tuple(sorted([
                                nums[a],
                                nums[b],
                                nums[c],
                                nums[d],
                            ]))
                            found.add(quad)

        return [list(q) for q in found]

This is correct, but too slow.

Problem With Brute Force

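There are C(n, 4) = n(n - 1)(n - 2)(n - 3) / 24 ways to choose four indices, which is O(n^4).

With n = 200, that is 64,684,950 quadruplets, each summed (and sorted when it matches). This is too slow in Python for typical time limits.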

Metric | Value
Time | O(n^4)
Space | O(r)

Here, r is the number of unique quadruplets stored.
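A more compact version of the same brute force can be sketched with the standard library's itertools.combinations. Sorting first means every combination comes out in nondecreasing order, so the per-quadruplet sort is no longer needed. It is still O(n^4) and is only meant as a reference; the function name four_sum_brute is our own.

```python
from itertools import combinations


def four_sum_brute(nums: list[int], target: int) -> list[list[int]]:
    # combinations() keeps input order, so sorting first makes every
    # 4-tuple nondecreasing; equal value quadruplets then collapse
    # to a single entry in the set.
    found = {
        quad
        for quad in combinations(sorted(nums), 4)
        if sum(quad) == target
    }
    return [list(q) for q in sorted(found)]


print(four_sum_brute([1, 0, -1, 0, -2, 2], 0))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```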

Key Insight

This problem is an extension of 3Sum.

Sort the array first.

Then fix the first two numbers using two loops.

After fixing:

nums[i]
nums[j]

we need two more numbers whose sum equals:

target - nums[i] - nums[j]

Since the array is sorted, we can find the remaining pair using two pointers:

left = j + 1
right = n - 1

If the four-number sum nums[i] + nums[j] + nums[left] + nums[right] is less than target, move left rightward to increase it.

If it is greater than target, move right leftward to decrease it.

If it equals target, record the quadruplet, move both pointers inward, and skip repeated pointer values.
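The inner search is an ordinary two-pointer pair search on a sorted suffix. Here is a minimal standalone sketch of just that step, assuming the list is already sorted. The helper name pairs_with_sum, its lo parameter (the first usable index, playing the role of j + 1), and goal (standing for target - nums[i] - nums[j]) are illustrative names, not part of the original solution.

```python
def pairs_with_sum(nums: list[int], lo: int, goal: int) -> list[tuple[int, int]]:
    """Distinct value pairs (x, y), x <= y, from sorted nums[lo:] with x + y == goal."""
    pairs = []
    left, right = lo, len(nums) - 1
    while left < right:
        s = nums[left] + nums[right]
        if s < goal:
            left += 1   # sum too small: take a larger left value
        elif s > goal:
            right -= 1  # sum too large: take a smaller right value
        else:
            pairs.append((nums[left], nums[right]))
            left += 1
            right -= 1
            # Skip values equal to the ones just used.
            while left < right and nums[left] == nums[left - 1]:
                left += 1
            while left < right and nums[right] == nums[right + 1]:
                right -= 1
    return pairs


nums = [-2, -1, 0, 0, 1, 2]
# Fixing i = 0 and j = 1 uses -2 and -1, so the pair must sum to 0 - (-2) - (-1) = 3.
print(pairs_with_sum(nums, 2, 3))  # [(1, 2)]
```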

Handling Duplicates

Sorting puts equal values next to each other.

We skip duplicate values at every choice level.

For the first fixed number:

if i > 0 and nums[i] == nums[i - 1]:
    continue

For the second fixed number:

if j > i + 1 and nums[j] == nums[j - 1]:
    continue

After recording a valid quadruplet, move both pointers and skip repeated values:

while left < right and nums[left] == nums[left - 1]:
    left += 1

while left < right and nums[right] == nums[right + 1]:
    right -= 1

This ensures each distinct quadruplet of values is recorded exactly once.
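To see exactly which indices the two skip checks let through, here is a small sketch. The helper name first_of_each_run and its start parameter are hypothetical: start = 0 reproduces the check on i, and start = i + 1 reproduces the check on j. For simplicity it scans the whole list and ignores the n - 3 and n - 2 loop bounds.

```python
def first_of_each_run(sorted_nums: list[int], start: int = 0) -> list[int]:
    """Indices k >= start that survive the check `k > start and nums[k] == nums[k - 1]`.

    start=0 mirrors the check on i; start=i + 1 mirrors the check on j.
    """
    return [
        k
        for k in range(start, len(sorted_nums))
        if not (k > start and sorted_nums[k] == sorted_nums[k - 1])
    ]


nums = [-2, -1, 0, 0, 1, 2]
print(first_of_each_run(nums))           # [0, 1, 2, 4, 5]
print(first_of_each_run(nums, start=3))  # [3, 4, 5]
print(first_of_each_run([2, 2, 2, 2, 2], start=1))  # [1]
```

With start = 3 (the first j after i = 2), index 3 is kept even though nums[3] == nums[2]. Comparing against i + 1 rather than 0 is what lets the second value equal the first, as in [2, 2, 2, 2].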

Visual Walkthrough

Use:

nums = [1, 0, -1, 0, -2, 2]
target = 0

Sort:

[-2, -1, 0, 0, 1, 2]

Fix:

i = 0
nums[i] = -2

Then fix:

j = 1
nums[j] = -1

Now search with:

left = 2
right = 5

Current sum:

-2 + -1 + 0 + 2 = -1

Too small, so move left.

Now:

left = 3
right = 5

Current sum:

-2 + -1 + 0 + 2 = -1

Still too small. Move left.

Now:

left = 4
right = 5

Current sum:

-2 + -1 + 1 + 2 = 0

Record:

[-2, -1, 1, 2]

After the match, left and right cross, so the scan for j = 1 ends. Continue with the next j.

Fix:

i = 0
nums[i] = -2
j = 2
nums[j] = 0

Use left = 3, right = 5.

Current sum:

-2 + 0 + 0 + 2 = 0

Record:

[-2, 0, 0, 2]

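For i = 0, j = 3 is skipped because nums[3] == nums[2]. Later, fixing i = 1, j = 2 starts at -1 + 0 + 0 + 2 = 1, which is too large, so right moves to 4 and the sum becomes -1 + 0 + 0 + 1 = 0, giving: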

[-1, 0, 0, 1]

Final result:

[[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
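To reproduce this walkthrough mechanically, the sketch below instruments the same loops to log every two-pointer comparison. The name trace_four_sum and the (i, j, left, right, total, action) tuple layout are our own choices for illustration.

```python
def trace_four_sum(nums: list[int], target: int) -> list[tuple[int, int, int, int, int, str]]:
    """Log each comparison as (i, j, left, right, total, action)."""
    nums = sorted(nums)  # sorted copy; the caller's list is unchanged
    n = len(nums)
    steps = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total < target:
                    steps.append((i, j, left, right, total, "move left"))
                    left += 1
                elif total > target:
                    steps.append((i, j, left, right, total, "move right"))
                    right -= 1
                else:
                    steps.append((i, j, left, right, total, "record"))
                    left += 1
                    right -= 1
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1
    return steps


for step in trace_four_sum([1, 0, -1, 0, -2, 2], 0):
    print(step)
```

The first four logged entries are the i = 0 steps shown above. The remaining entries are the i = 1 scan (one right move, then the match) and a final i = 2, j = 3 scan whose only sum, 3, is too large.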

Algorithm

  1. Sort nums.
  2. Create an empty result list.
  3. Loop i from 0 to n - 4 (inclusive).
    • skip nums[i] if i > 0 and it equals nums[i - 1]
  4. For each i, loop j from i + 1 to n - 3 (inclusive).
    • skip nums[j] if j > i + 1 and it equals nums[j - 1]
  5. For each pair (i, j), set:
    • left = j + 1
    • right = n - 1
  6. While left < right:
    • compute total = nums[i] + nums[j] + nums[left] + nums[right]
    • if total < target, increase left
    • if total > target, decrease right
    • if total == target:
      • record the quadruplet
      • increase left and decrease right
      • skip pointer values equal to the ones just used
  7. Return the result.

Correctness

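After sorting, every quadruplet that sums to target can be written in nondecreasing order, so it corresponds to some indices i < j < left < right.

The outer loops choose the first and second values of a quadruplet. For each fixed pair (i, j), the two-pointer scan searches all valid pairs to the right of j.

When the sum is smaller than target, nums[left] is already paired with nums[right], the largest value still in the window. Every other partner in the window is smaller or equal, so nums[left] cannot belong to a matching pair, and increasing left discards nothing useful.

When the sum is larger than target, nums[right] is already paired with nums[left], the smallest value still in the window. Every other partner is larger or equal, so nums[right] cannot belong to a matching pair, and decreasing right discards nothing useful.

After a match, moving both pointers is also safe: any other pair in the window that reuses the value nums[left] must pair it with the value nums[right], which would only repeat the quadruplet just recorded. Therefore, for each fixed pair, the two-pointer scan finds every matching remaining pair.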

The duplicate checks skip equal values at the same decision level. Equal values at the same level produce the same value quadruplets, so skipping them removes duplicate outputs without removing any unique quadruplet.

Therefore the algorithm returns exactly all unique quadruplets whose sum equals target.
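The argument above can also be checked empirically. This sketch compares the two-pointer method against an exhaustive itertools.combinations reference on many small random inputs with lots of repeated values. The function names, value ranges, trial count, and seed are arbitrary choices for illustration; passing trials are evidence, not a proof.

```python
import random
from itertools import combinations


def four_sum_two_pointer(nums: list[int], target: int) -> list[list[int]]:
    """Sorting + two pointers, following the steps described in this article."""
    nums = sorted(nums)  # sorted copy; the caller's list is unchanged
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total < target:
                    left += 1
                elif total > target:
                    right -= 1
                else:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    left += 1
                    right -= 1
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1
    return result


def four_sum_reference(nums: list[int], target: int) -> list[tuple[int, ...]]:
    """Exhaustive O(n^4) reference: every distinct sorted 4-tuple with the right sum."""
    return sorted({q for q in combinations(sorted(nums), 4) if sum(q) == target})


def cross_check(trials: int = 2000, seed: int = 0) -> int:
    """Compare both versions on random inputs; return the number of trials run."""
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(-4, 4) for _ in range(rng.randint(0, 9))]
        target = rng.randint(-6, 6)
        fast = [tuple(q) for q in four_sum_two_pointer(nums, target)]
        # Same set of quadruplets, and no quadruplet reported twice.
        assert sorted(fast) == four_sum_reference(nums, target), (nums, target)
        assert len(fast) == len(set(fast)), (nums, target)
    return trials


print(cross_check(), "random trials agree")
```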

Complexity

Metric | Value | Why
Time | O(n^3) | O(n^2) fixed pairs (i, j), each followed by an O(n) two-pointer scan
Space | O(1) extra | Ignoring the output list and the sort's own auxiliary memory

Sorting costs O(n log n), which is dominated by O(n^3).
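To put numbers on the bound: each loop iteration of the scan moves at least one pointer, and for a fixed (i, j) the gap right - left starts at (n - 1) - (j + 1) = n - 2 - j, so the scan makes at most that many comparisons. The sketch below (the function name is our own) sums this over all (i, j) pairs, which is exact when values are distinct and no quadruplet matches, and an upper bound otherwise, since matches and duplicate skips only shorten the work.

```python
from math import comb


def max_scan_steps(n: int) -> int:
    """Upper bound on two-pointer comparisons: sum of (n - 2 - j) over all (i, j) pairs."""
    return sum(n - 2 - j for i in range(n - 3) for j in range(i + 1, n - 2))


n = 200
print(max_scan_steps(n))  # 1293699, equal to comb(n - 1, 3)
print(comb(n, 4))         # 64684950 index quadruplets for the brute force
```

So at the maximum input size, the two-pointer method does roughly 1.3 million comparisons, compared with about 65 million quadruplets for the brute force.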

Implementation

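class Solution:
    def fourSum(self, nums: list[int], target: int) -> list[list[int]]:
        # Sort in place so equal values are adjacent and two pointers work.
        nums.sort()

        n = len(nums)
        result = []

        # First fixed value; i stops at n - 4 so three values remain after it.
        for i in range(n - 3):
            # Skip a first value already handled in an earlier iteration.
            if i > 0 and nums[i] == nums[i - 1]:
                continue

            # Second fixed value, always to the right of i.
            for j in range(i + 1, n - 2):
                # Skip a repeated second value for this i.
                if j > i + 1 and nums[j] == nums[j - 1]:
                    continue

                # Two pointers over the remaining suffix nums[j + 1:].
                left = j + 1
                right = n - 1

                while left < right:
                    total = (
                        nums[i]
                        + nums[j]
                        + nums[left]
                        + nums[right]
                    )

                    if total < target:
                        left += 1   # need a larger sum
                    elif total > target:
                        right -= 1  # need a smaller sum
                    else:
                        result.append([
                            nums[i],
                            nums[j],
                            nums[left],
                            nums[right],
                        ])

                        # Move past the pair just recorded.
                        left += 1
                        right -= 1

                        # Skip pointer values equal to the ones just used.
                        while left < right and nums[left] == nums[left - 1]:
                            left += 1

                        while left < right and nums[right] == nums[right + 1]:
                            right -= 1

        return result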
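A common optional refinement, not part of the solution above, is to prune whole i or j iterations using the sorted order. If the smallest sum that can start at a position already exceeds target, no later position can do better, so the loop breaks. If the largest sum that can start there is still below target, that position is skipped. This sketch (the name four_sum_pruned is our own) keeps the same O(n^3) worst case but often skips many scans.

```python
def four_sum_pruned(nums: list[int], target: int) -> list[list[int]]:
    nums = sorted(nums)  # sorted copy; the caller's list is unchanged
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        # Smallest sum starting at i is too big; larger i only make it bigger.
        if nums[i] + nums[i + 1] + nums[i + 2] + nums[i + 3] > target:
            break
        # Largest sum starting at i is too small; try the next i.
        if nums[i] + nums[n - 3] + nums[n - 2] + nums[n - 1] < target:
            continue
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue
            # Same two bounds for the second fixed value.
            if nums[i] + nums[j] + nums[j + 1] + nums[j + 2] > target:
                break
            if nums[i] + nums[j] + nums[n - 2] + nums[n - 1] < target:
                continue
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total < target:
                    left += 1
                elif total > target:
                    right -= 1
                else:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    left += 1
                    right -= 1
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1
    return result


print(four_sum_pruned([1, 0, -1, 0, -2, 2], 0))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

On the first example, the i = 2 iteration stops immediately because its smallest possible sum, 0 + 0 + 1 + 2 = 3, already exceeds the target.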

Code Explanation

Sort the input:

nums.sort()

This lets us use two pointers and skip duplicates. The sort is in place, so the caller's nums list is reordered as a side effect.

Choose the first fixed value:

for i in range(n - 3):

Skip repeated first values:

if i > 0 and nums[i] == nums[i - 1]:
    continue

Choose the second fixed value:

for j in range(i + 1, n - 2):

Skip repeated second values:

if j > i + 1 and nums[j] == nums[j - 1]:
    continue

Use two pointers for the last two values:

left = j + 1
right = n - 1

Compute the current sum:

total = nums[i] + nums[j] + nums[left] + nums[right]

Move according to the sum:

if total < target:
    left += 1
elif total > target:
    right -= 1

When the sum matches, record the quadruplet and skip duplicates.

Testing

def normalize(result):
    return sorted([tuple(x) for x in result])

def run_tests():
    s = Solution()

    assert normalize(s.fourSum([1, 0, -1, 0, -2, 2], 0)) == [
        (-2, -1, 1, 2),
        (-2, 0, 0, 2),
        (-1, 0, 0, 1),
    ]

    assert normalize(s.fourSum([2, 2, 2, 2, 2], 8)) == [
        (2, 2, 2, 2),
    ]

    assert normalize(s.fourSum([], 0)) == []
    assert normalize(s.fourSum([1, 2, 3], 6)) == []
    assert normalize(s.fourSum([0, 0, 0, 0], 0)) == [
        (0, 0, 0, 0),
    ]

    assert normalize(s.fourSum([-3, -1, 0, 2, 4, 5], 2)) == [
        (-3, -1, 2, 4),
    ]

    print("all tests passed")

run_tests()

Test | Why
[1, 0, -1, 0, -2, 2], target 0 | Standard example
[2, 2, 2, 2, 2], target 8 | Duplicate values collapse to one quadruplet
[], target 0 | Defensive empty input
[1, 2, 3], target 6 | Fewer than four numbers
[0, 0, 0, 0], target 0 | All-zero quadruplet
[-3, -1, 0, 2, 4, 5], target 2 | Mixed negative and positive values