A clear explanation of finding the kth smallest pair distance using sorting, binary search on the answer, and a two-pointer count.
Problem Restatement
We are given an integer array nums and an integer k.
For every pair of indices:

    0 <= i < j < len(nums)

the distance between the pair is:

    abs(nums[i] - nums[j])

We need to return the kth smallest pair distance.
The official problem defines the distance of two integers as their absolute difference and asks for the kth smallest distance among all pairs nums[i] and nums[j] where i < j.
Input and Output
| Item | Meaning |
|---|---|
| Input | Integer array nums |
| Input | Integer k |
| Output | The kth smallest pair distance |
| Pair rule | Use pairs (i, j) where i < j |
| Distance | abs(nums[i] - nums[j]) |
The function shape is:

    class Solution:
        def smallestDistancePair(self, nums: list[int], k: int) -> int:
            ...

Examples
Example 1:

    nums = [1, 3, 1]
    k = 1

All pairs are:
| Pair | Distance |
|---|---|
| (1, 3) | 2 |
| (1, 1) | 0 |
| (3, 1) | 2 |
Sorted distances:

    [0, 2, 2]

The 1st smallest distance is 0.
Output:

    0

Example 2:

    nums = [1, 6, 1]
    k = 3

All distances are:

    0, 5, 5

The 3rd smallest distance is 5.
Output:

    5

First Thought: Generate All Pair Distances
A direct approach is to compute every pair distance.

    distances = []
    for i in range(len(nums)):
        for j in range(i + 1, len(nums)):
            distances.append(abs(nums[i] - nums[j]))

Then sort distances and return:

    distances[k - 1]

This is simple and correct.
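As a minimal sketch, the brute-force idea can be packaged into one function that is handy later as a reference answer for small inputs. The name kth_smallest_distance_brute_force is hypothetical and not part of the original solution.

```python
from itertools import combinations


def kth_smallest_distance_brute_force(nums: list[int], k: int) -> int:
    """Return the kth smallest pair distance by listing every pair (k is 1-based)."""
    # combinations(nums, 2) yields each pair of positions i < j exactly once.
    distances = sorted(abs(a - b) for a, b in combinations(nums, 2))
    return distances[k - 1]


print(kth_smallest_distance_brute_force([1, 3, 1], 1))  # 0
print(kth_smallest_distance_brute_force([1, 6, 1], 3))  # 5
```

Both calls reproduce the two examples from the problem statement.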
Problem With Brute Force
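If nums has length n, the number of pairs is:

    n * (n - 1) / 2

So generating all distances costs O(n^2) time and O(n^2) space.
Sorting those distances costs:

    O(n^2 log n)

This is too expensive for large inputs. For example, n = 10,000 already gives about 50 million distances to store and sort.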
We need to find the kth distance without explicitly storing every pair distance.
Key Insight
The answer is a distance value.
The smallest possible distance is 0.
The largest possible distance is:

    max(nums) - min(nums)

After sorting nums, this is nums[-1] - nums[0], and we can binary search over this distance range.
For a guessed distance mid, count how many pairs have distance less than or equal to mid.
If at least k pairs have distance <= mid, then the kth smallest distance is at most mid.
If fewer than k pairs have distance <= mid, then the kth smallest distance is greater than mid.
This gives a monotonic condition suitable for binary search.
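To make the monotonic condition concrete, here is a small illustrative sketch that counts pairs naively for every candidate distance. The helper count_at_most_naive is a hypothetical name, and it checks every pair on purpose, so it is only meant for tiny inputs.

```python
from itertools import combinations


def count_at_most_naive(nums: list[int], d: int) -> int:
    """Count pairs (i < j) whose distance is at most d by checking every pair."""
    return sum(1 for a, b in combinations(nums, 2) if abs(a - b) <= d)


nums = [1, 6, 1]
k = 3
widest = max(nums) - min(nums)

# count(d) for every candidate distance d in [0, widest].
counts = [count_at_most_naive(nums, d) for d in range(widest + 1)]
print(counts)  # [1, 1, 1, 1, 1, 3]

# The kth smallest distance is the first d whose count reaches k.
answer = next(d for d, c in enumerate(counts) if c >= k)
print(answer)  # 5
```

The counts never decrease as d grows, so the first d with count >= k can be found by binary search instead of a linear scan.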
Counting Pairs With Distance at Most mid
Sort the array first.
For each right index, maintain the smallest left index such that:

    nums[right] - nums[left] <= mid

Because the array is sorted, all indices from left to right - 1 form valid pairs with right.
The number of valid pairs ending at right is:

    right - left

If the distance is too large, move left forward. A larger right never needs a smaller left, so left only moves forward during the whole pass and the count takes O(n) time.
This is a two-pointer sliding window over the sorted array.
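The window movement can be sketched with a version of the count that also records what happens at each right index. The function count_pairs_trace and its trace format are illustrative assumptions, not part of the original solution.

```python
def count_pairs_trace(
    sorted_nums: list[int], mid: int
) -> tuple[int, list[tuple[int, int, int]]]:
    """Return the total count and a trace of (right, left, pairs_added)."""
    trace = []
    total = 0
    left = 0
    for right in range(len(sorted_nums)):
        # Drop values that are more than mid below sorted_nums[right].
        while sorted_nums[right] - sorted_nums[left] > mid:
            left += 1
        added = right - left  # pairs (left, right), ..., (right - 1, right)
        total += added
        trace.append((right, left, added))
    return total, trace


total, trace = count_pairs_trace([1, 1, 3], 1)
print(total)  # 1
print(trace)  # [(0, 0, 0), (1, 0, 1), (2, 2, 0)]
```

In the trace, left jumps from 0 to 2 when right reaches the value 3, because both 1s are more than 1 away from it.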
Algorithm
- Sort nums.
- Set low = 0 and high = nums[-1] - nums[0].
- Binary search while low < high:
  - Set mid = (low + high) // 2.
  - Count how many pairs have distance <= mid.
  - If the count is at least k, set high = mid.
  - Otherwise, set low = mid + 1.
- Return low.
The count function:
- Set left = 0.
- Set count = 0.
- For each right:
  - While nums[right] - nums[left] > distance, increment left.
  - Add right - left to count.
- Return count.
Correctness
After sorting, every pair distance can be written as:

    nums[j] - nums[i]

where i < j, because nums[j] >= nums[i].
For a fixed distance d, define count(d) as the number of pairs with distance at most d.
If d increases, no previously valid pair becomes invalid. Therefore count(d) is monotonic non-decreasing.
The kth smallest distance is the smallest distance d such that at least k pairs have distance at most d.
The binary search maintains this target value inside [low, high].
When count(mid) >= k, there are enough pairs with distance at most mid, so the answer is mid or smaller. We set high = mid.
When count(mid) < k, fewer than k pairs have distance at most mid, so the kth smallest distance must be larger. We set low = mid + 1.
The count function is correct because for each right, it advances left until nums[right] - nums[left] <= d. Since the array is sorted, every index between left and right - 1 also gives a distance at most d, and every index before left gives a distance greater than d. Thus it adds exactly the number of valid pairs ending at right.
When the binary search finishes, low == high, and this value is the smallest distance with at least k valid pairs. That is exactly the kth smallest pair distance.
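The argument above can also be checked empirically. The following sketch compares a binary search version against a full sort of all distances on small random inputs. All function names here are hypothetical, and the input sizes and value range are arbitrary choices made for illustration.

```python
import random
from itertools import combinations


def kth_distance_binary_search(nums: list[int], k: int) -> int:
    """Binary search on the answer with a two-pointer count."""
    values = sorted(nums)

    def count_at_most(d: int) -> int:
        count = 0
        left = 0
        for right in range(len(values)):
            while values[right] - values[left] > d:
                left += 1
            count += right - left
        return count

    low, high = 0, values[-1] - values[0]
    while low < high:
        mid = (low + high) // 2
        if count_at_most(mid) >= k:
            high = mid
        else:
            low = mid + 1
    return low


def kth_distance_reference(nums: list[int], k: int) -> int:
    """Reference answer from the full sorted list of distances."""
    return sorted(abs(a - b) for a, b in combinations(nums, 2))[k - 1]


def cross_check(trials: int = 200, seed: int = 0) -> bool:
    """Return True if both methods agree on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(2, 8)
        nums = [rng.randint(0, 20) for _ in range(n)]
        k = rng.randint(1, n * (n - 1) // 2)
        if kth_distance_binary_search(nums, k) != kth_distance_reference(nums, k):
            return False
    return True


print(cross_check())
```

A random check like this does not replace the proof, but it is a cheap way to catch off-by-one mistakes in the search bounds.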
Complexity
Let n be the length of nums, and let:

    W = max(nums) - min(nums)

| Metric | Value | Why |
|---|---|---|
| Time | O(n log n + n log W) | Sort once, then each binary search step counts pairs in O(n) |
| Space | O(1) extra | Sorting aside, only pointers and counters are used |
In Python, sorting may use implementation-dependent auxiliary memory.
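The log W factor is the number of binary search iterations, each of which runs one O(n) count. The sketch below counts those iterations for a search over [0, width]. The function binary_search_steps and its feasible_from parameter are hypothetical stand-ins for the real feasibility check count(mid) >= k.

```python
def binary_search_steps(width: int, feasible_from: int) -> int:
    """Count loop iterations when searching [0, width] for the first feasible value.

    feasible_from stands in for the true answer: a guess is feasible
    when guess >= feasible_from.
    """
    low, high = 0, width
    steps = 0
    while low < high:
        mid = (low + high) // 2
        if mid >= feasible_from:
            high = mid
        else:
            low = mid + 1
        steps += 1
    return steps


width = 1_000_000
worst = max(
    binary_search_steps(width, target) for target in (0, 1, width // 2, width)
)
# width.bit_length() equals ceil(log2(width + 1)), the bound on iterations.
print(worst, width.bit_length())
```

Under these assumptions, a value range of one million needs at most 20 counting passes, regardless of how many pairs the array has.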
Implementation

    class Solution:
        def smallestDistancePair(self, nums: list[int], k: int) -> int:
            # Sorting lets each pair distance be written as nums[j] - nums[i].
            nums.sort()

            def count_pairs_at_most(distance: int) -> int:
                # Count pairs ending at each right with difference <= distance.
                count = 0
                left = 0
                for right in range(len(nums)):
                    while nums[right] - nums[left] > distance:
                        left += 1
                    count += right - left
                return count

            # Binary search for the smallest distance with at least k pairs.
            low = 0
            high = nums[-1] - nums[0]
            while low < high:
                mid = (low + high) // 2
                if count_pairs_at_most(mid) >= k:
                    high = mid
                else:
                    low = mid + 1
            return low

Code Explanation
First, sort the array:

    nums.sort()

This lets us compute pair distance without abs.
The helper function counts pairs whose distance is at most a given value:

    def count_pairs_at_most(distance: int) -> int:

The left pointer tracks the first valid index for the current right pointer:

    left = 0

For each right, shrink the window until the distance fits:

    while nums[right] - nums[left] > distance:
        left += 1

Then all starts from left through right - 1 are valid:

    count += right - left

The binary search range is the possible answer range:

    low = 0
    high = nums[-1] - nums[0]

If at least k pairs have distance at most mid, try smaller:

    if count_pairs_at_most(mid) >= k:
        high = mid

Otherwise, the answer is larger:

    else:
        low = mid + 1

Return the smallest feasible distance:

    return low

Example Walkthrough
Use:

    nums = [1, 3, 1]
    k = 1

After sorting:

    nums = [1, 1, 3]

Search range:

    low = 0
    high = 2

Check mid = 1.
Count pairs with distance at most 1:
| Pair | Distance | Counted |
|---|---|---|
| (1, 1) | 0 | Yes |
| (1, 3) | 2 | No |
| (1, 3) | 2 | No |
Count is 1.
Since count >= k, the answer is at most 1.
Set:

    high = 1

Now low = 0 and high = 1, so check mid = 0.
Pairs with distance at most 0:
| Pair | Distance | Counted |
|---|---|---|
| (1, 1) | 0 | Yes |
| (1, 3) | 2 | No |
| (1, 3) | 2 | No |
Count is 1.
Since count >= k, set:

    high = 0

Now:

    low == high == 0

Return:

    0

Testing

    def test_smallest_distance_pair():
        s = Solution()
        assert s.smallestDistancePair([1, 3, 1], 1) == 0
        assert s.smallestDistancePair([1, 6, 1], 3) == 5
        assert s.smallestDistancePair([1, 1, 1], 2) == 0
        assert s.smallestDistancePair([1, 2, 3, 4], 3) == 1
        assert s.smallestDistancePair([9, 10, 7, 10, 6, 1, 5, 4], 18) == 4
        print("all tests passed")

    test_smallest_distance_pair()

Test coverage:
| Test | Why |
|---|---|
| Duplicate values | Confirms zero distance |
| kth is largest distance | Confirms high-end binary search |
| All equal values | Confirms all distances are zero |
| Many small distances | Confirms pair counting |
| Unsorted input | Confirms sorting step works |