A clear explanation of sorting values after applying a quadratic function using two pointers.
Problem Restatement
We are given a sorted integer array nums and three integers a, b, and c.
For every number x in nums, apply this function:
f(x) = a * x * x + b * x + c

Return all transformed values in sorted order.
The input array is already sorted in ascending order. The follow-up asks for an O(n) solution, rather than transforming everything and sorting afterward. The official examples include nums = [-4,-2,2,4], a = 1, b = 3, c = 5, with output [3,9,15,33], and the same nums with a = -1, producing [-23,-5,1,7].
Input and Output
| Item | Meaning |
|---|---|
| Input | Sorted array nums, integers a, b, c |
| Output | Sorted transformed array |
| Function | f(x) = ax² + bx + c |
| Target time | O(n) |
Example function shape:
```python
def sortTransformedArray(nums: list[int], a: int, b: int, c: int) -> list[int]:
    ...
```

Examples
Example 1:
```python
nums = [-4, -2, 2, 4]
a = 1
b = 3
c = 5
```

Transform each value:
| x | f(x) |
|---|---|
| -4 | 16 - 12 + 5 = 9 |
| -2 | 4 - 6 + 5 = 3 |
| 2 | 4 + 6 + 5 = 15 |
| 4 | 16 + 12 + 5 = 33 |
The transformed values are:

[9, 3, 15, 33]

Sorted:

[3, 9, 15, 33]

Example 2:
```python
nums = [-4, -2, 2, 4]
a = -1
b = 3
c = 5
```

Transform each value:
| x | f(x) |
|---|---|
| -4 | -16 - 12 + 5 = -23 |
| -2 | -4 - 6 + 5 = -5 |
| 2 | -4 + 6 + 5 = 7 |
| 4 | -16 + 12 + 5 = 1 |
The transformed values are:

[-23, -5, 7, 1]

Sorted:

[-23, -5, 1, 7]

First Thought: Transform Then Sort
The simplest method is:
- Transform every number.
- Sort the transformed array.
```python
class Solution:
    def sortTransformedArray(
        self,
        nums: list[int],
        a: int,
        b: int,
        c: int,
    ) -> list[int]:
        transformed = []
        for x in nums:
            transformed.append(a * x * x + b * x + c)
        transformed.sort()
        return transformed
```

This is correct and short.
But sorting costs O(n log n) time.

The follow-up asks for O(n), so we should use the fact that nums is already sorted.
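For reference, the same baseline fits in a single sorted call over a generator expression. This is a sketch with a hypothetical function name, sort_transformed_baseline; because it is obviously correct, it also works well as a brute-force oracle when testing a faster version.

```python
def sort_transformed_baseline(nums: list[int], a: int, b: int, c: int) -> list[int]:
    # Transform every element, then let Python's built-in sort order them: O(n log n).
    return sorted(a * x * x + b * x + c for x in nums)


print(sort_transformed_baseline([-4, -2, 2, 4], 1, 3, 5))   # [3, 9, 15, 33]
print(sort_transformed_baseline([-4, -2, 2, 4], -1, 3, 5))  # [-23, -5, 1, 7]
```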
Key Insight
The function is quadratic when a ≠ 0, so its graph is a parabola whose vertex (turning point) is at x = -b / (2a).
When a > 0, the parabola opens upward. The smallest values are near the vertex, and the largest values are toward the two ends.
So among the remaining numbers, which always form a contiguous range of nums, the largest transformed value is at either the left end or the right end.
When a < 0, the parabola opens downward. The largest values are near the vertex, and the smallest values are toward the two ends.
So among the remaining numbers, the smallest transformed value is at either the left end or the right end.
When a = 0, the function is linear, hence monotonic over the sorted input, so both the smallest and the largest remaining values lie at the ends. Either branch would work; the algorithm below handles a = 0 together with a > 0 in the a >= 0 branch.
This lets us use two pointers:

```python
left = 0
right = len(nums) - 1
```

At each step, compare:

```python
f(nums[left])
f(nums[right])
```

Then place the larger or smaller one into the answer, depending on the sign of a.
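To see the shape concretely, the sketch below computes the vertex x = -b / (2a) for both official examples and prints the transformed values in input order. The helper names vertex_x and transform are illustrative, not part of the problem.

```python
def vertex_x(a: int, b: int) -> float:
    # x-coordinate of the parabola's turning point, -b / (2a); only defined for a != 0.
    return -b / (2 * a)


def transform(nums: list[int], a: int, b: int, c: int) -> list[int]:
    # Apply f to each element, keeping the input order (not sorted).
    return [a * x * x + b * x + c for x in nums]


nums = [-4, -2, 2, 4]
print(vertex_x(1, 3), transform(nums, 1, 3, 5))    # -1.5 [9, 3, 15, 33]
print(vertex_x(-1, 3), transform(nums, -1, 3, 5))  # 1.5 [-23, -5, 7, 1]
```

For a = 1 the values fall until x passes -1.5 and rise afterward, so the maximum (33) is at an end. For a = -1 they rise until x passes 1.5 and then fall, so the minimum (-23) is at an end.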
Algorithm
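Create:

```python
ans = [0] * len(nums)
```

Use two pointers:

```python
left = 0
right = len(nums) - 1
```

If a >= 0:

- Larger values are at the ends.
- Fill ans from right to left, starting at write = len(nums) - 1.
- Compare f(nums[left]) and f(nums[right]).
- Put the larger value into ans[write], move the pointer it came from inward, and decrement write.

If a < 0:

- Smaller values are at the ends.
- Fill ans from left to right, starting at write = 0.
- Compare f(nums[left]) and f(nums[right]).
- Put the smaller value into ans[write], move the pointer it came from inward, and increment write.

Repeat until left > right.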
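The steps above can be traced on the two examples with a small instrumented version of the loop. This sketch (trace_placements is a hypothetical helper) records, for every step, the index of ans being written, the value placed, and which pointer moved.

```python
def trace_placements(nums: list[int], a: int, b: int, c: int) -> list[tuple[int, int, str]]:
    """Record (write_index, value, side) for each step of the two-pointer fill."""
    def f(x: int) -> int:
        return a * x * x + b * x + c

    n = len(nums)
    left, right = 0, n - 1
    # a >= 0: fill from the back with the larger end value.
    # a < 0: fill from the front with the smaller end value.
    write, step = (n - 1, -1) if a >= 0 else (0, 1)
    log = []
    while left <= right:
        lv, rv = f(nums[left]), f(nums[right])
        take_left = lv > rv if a >= 0 else lv < rv
        if take_left:
            log.append((write, lv, "left"))
            left += 1
        else:
            log.append((write, rv, "right"))
            right -= 1
        write += step
    return log


# [(3, 33, 'right'), (2, 15, 'right'), (1, 9, 'left'), (0, 3, 'right')]
print(trace_placements([-4, -2, 2, 4], 1, 3, 5))
# [(0, -23, 'left'), (1, -5, 'left'), (2, 1, 'right'), (3, 7, 'right')]
print(trace_placements([-4, -2, 2, 4], -1, 3, 5))
```

For a = 1 the slots are filled from index 3 down to 0; for a = -1 from 0 up to 3. When both ends hold equal values (including the last step, where left == right), the else branch takes the right end; either choice places the same value.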
Correctness
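The transformed values of a sorted input under a quadratic function form a sequence that is non-increasing and then non-decreasing when a > 0, or non-decreasing and then non-increasing when a < 0 (either part may be empty). When a = 0 the sequence is monotonic.
For a >= 0, the remaining values always occupy a contiguous range nums[left..right], and any contiguous piece of such a sequence has the same shape, so the maximum remaining transformed value must be at one of the two ends. The algorithm chooses the larger end value and places it into the last unfilled position of ans. This position requires the largest remaining value, so the placement is correct.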
After removing that chosen end, the same argument applies to the remaining subarray.
For a < 0, the minimum remaining transformed value must be at one of the two ends. The algorithm chooses the smaller end value and places it into the first unfilled position of ans. This position requires the smallest remaining value, so the placement is correct.
Each step places exactly one transformed value in its final sorted position. Since the loop processes every input value once, the final array is sorted and contains exactly all transformed values.
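The argument can also be spot-checked empirically. The sketch below (check_greedy_invariant is a hypothetical helper) re-runs the pointer loop and, before each move, asserts that the chosen end value is the maximum (a >= 0) or minimum (a < 0) of the whole remaining range. It rescans the range on every step, so it is O(n²) and only meant as a test aid.

```python
import random


def check_greedy_invariant(nums: list[int], a: int, b: int, c: int) -> None:
    """Assert that every end value the loop picks is the extreme of the remaining range."""
    def f(x: int) -> int:
        return a * x * x + b * x + c

    left, right = 0, len(nums) - 1
    while left <= right:
        remaining = [f(x) for x in nums[left:right + 1]]
        lv, rv = f(nums[left]), f(nums[right])
        if a >= 0:
            # Upward parabola or line: the larger end must be the overall maximum.
            assert max(lv, rv) == max(remaining)
            take_left = lv > rv
        else:
            # Downward parabola: the smaller end must be the overall minimum.
            assert min(lv, rv) == min(remaining)
            take_left = lv < rv
        if take_left:
            left += 1
        else:
            right -= 1


random.seed(0)
for _ in range(500):
    nums = sorted(random.randint(-20, 20) for _ in range(random.randint(1, 12)))
    a, b, c = (random.randint(-5, 5) for _ in range(3))
    check_greedy_invariant(nums, a, b, c)
print("invariant held on all random cases")
```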
Complexity
Let n = len(nums).
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | The loop runs n times, and each iteration does O(1) work: two evaluations of f and one write |
| Space | O(n) | The answer array stores n values; beyond the output, only a few pointers are used |
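The loop evaluates f on both ends in every iteration, so f runs about 2n times. If that matters, one option (a sketch, not the article's implementation; sort_transformed_precomputed is a hypothetical name) is to transform every element once up front and run the same pointer logic over the precomputed list. It also folds the two mirrored branches into one loop by choosing a fill direction.

```python
def sort_transformed_precomputed(nums: list[int], a: int, b: int, c: int) -> list[int]:
    # Evaluate f once per element (n evaluations instead of about 2n).
    vals = [a * x * x + b * x + c for x in nums]
    n = len(vals)
    ans = [0] * n
    left, right = 0, n - 1
    # a >= 0: fill from the back with the larger end value.
    # a < 0: fill from the front with the smaller end value.
    fill_from_back = a >= 0
    write, step = (n - 1, -1) if fill_from_back else (0, 1)
    while left <= right:
        lv, rv = vals[left], vals[right]
        take_left = lv > rv if fill_from_back else lv < rv
        if take_left:
            ans[write] = lv
            left += 1
        else:
            ans[write] = rv
            right -= 1
        write += step
    return ans


print(sort_transformed_precomputed([-4, -2, 2, 4], 1, 3, 5))   # [3, 9, 15, 33]
print(sort_transformed_precomputed([-4, -2, 2, 4], -1, 3, 5))  # [-23, -5, 1, 7]
```

The trade-off is an extra list of n values; total space stays O(n).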
Implementation
```python
class Solution:
    def sortTransformedArray(
        self,
        nums: list[int],
        a: int,
        b: int,
        c: int,
    ) -> list[int]:
        def f(x: int) -> int:
            return a * x * x + b * x + c

        n = len(nums)
        ans = [0] * n
        left = 0
        right = n - 1

        if a >= 0:
            # Upward parabola or line: the largest remaining value is at an end,
            # so fill ans from the back.
            write = n - 1
            while left <= right:
                left_value = f(nums[left])
                right_value = f(nums[right])
                if left_value > right_value:
                    ans[write] = left_value
                    left += 1
                else:
                    ans[write] = right_value
                    right -= 1
                write -= 1
        else:
            # Downward parabola: the smallest remaining value is at an end,
            # so fill ans from the front.
            write = 0
            while left <= right:
                left_value = f(nums[left])
                right_value = f(nums[right])
                if left_value < right_value:
                    ans[write] = left_value
                    left += 1
                else:
                    ans[write] = right_value
                    right -= 1
                write += 1

        return ans
```

Code Explanation
The helper function computes the transformed value:

```python
def f(x: int) -> int:
    return a * x * x + b * x + c
```

The answer array is preallocated:

```python
ans = [0] * n
```

The two pointers start at both ends:

```python
left = 0
right = n - 1
```

When a >= 0, the larger values are found at the ends, so we fill from the back:

```python
write = n - 1
```

Each time, we compare both end values and write the larger one. On a tie the else branch takes the right end, which is fine because both values are equal:

```python
if left_value > right_value:
    ans[write] = left_value
    left += 1
else:
    ans[write] = right_value
    right -= 1
```

When a < 0, the smaller values are found at the ends, so we fill from the front:

```python
write = 0
```

Each time, we compare both end values and write the smaller one:

```python
if left_value < right_value:
    ans[write] = left_value
    left += 1
else:
    ans[write] = right_value
    right -= 1
```

Testing
```python
def run_tests():
    s = Solution()

    assert s.sortTransformedArray(
        [-4, -2, 2, 4],
        1,
        3,
        5,
    ) == [3, 9, 15, 33]

    assert s.sortTransformedArray(
        [-4, -2, 2, 4],
        -1,
        3,
        5,
    ) == [-23, -5, 1, 7]

    assert s.sortTransformedArray(
        [-4, -2, 2, 4],
        0,
        3,
        5,
    ) == [-7, -1, 11, 17]

    assert s.sortTransformedArray(
        [-4, -2, 2, 4],
        0,
        -3,
        5,
    ) == [-7, -1, 11, 17]

    assert s.sortTransformedArray(
        [1],
        2,
        3,
        4,
    ) == [9]

    print("all tests passed")


run_tests()
```

Test meaning:
| Test | Why |
|---|---|
| a > 0 | Upward parabola, fill from the end |
| a < 0 | Downward parabola, fill from the front |
| a = 0, b > 0 | Linear increasing case |
| a = 0, b < 0 | Linear decreasing case |
| One element | Minimum input size |
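The hand-picked tests above can be complemented by a randomized comparison against transform-then-sort. This sketch assumes a standalone function, sort_transformed, that mirrors Solution.sortTransformedArray so the block runs on its own; the input ranges, seed, and trial count are arbitrary choices.

```python
import random


def sort_transformed(nums: list[int], a: int, b: int, c: int) -> list[int]:
    # Same two-pointer logic as Solution.sortTransformedArray, as a plain function.
    def f(x: int) -> int:
        return a * x * x + b * x + c

    n = len(nums)
    ans = [0] * n
    left, right = 0, n - 1
    if a >= 0:
        write = n - 1
        while left <= right:
            lv, rv = f(nums[left]), f(nums[right])
            if lv > rv:
                ans[write] = lv
                left += 1
            else:
                ans[write] = rv
                right -= 1
            write -= 1
    else:
        write = 0
        while left <= right:
            lv, rv = f(nums[left]), f(nums[right])
            if lv < rv:
                ans[write] = lv
                left += 1
            else:
                ans[write] = rv
                right -= 1
            write += 1
    return ans


def random_test(trials: int = 1000, seed: int = 1) -> None:
    rng = random.Random(seed)
    for _ in range(trials):
        # Sorted input with duplicates allowed; coefficients may be negative or zero.
        nums = sorted(rng.randint(-50, 50) for _ in range(rng.randint(1, 20)))
        a, b, c = rng.randint(-10, 10), rng.randint(-10, 10), rng.randint(-10, 10)
        expected = sorted(a * x * x + b * x + c for x in nums)
        assert sort_transformed(nums, a, b, c) == expected, (nums, a, b, c)
    print(f"{trials} random cases passed")


random_test()
```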