A clear explanation of finding the nth magical number using binary search, greatest common divisor, least common multiple, and inclusion-exclusion.
Problem Restatement
A positive integer is called magical if it is divisible by either a or b.
Given three integers n, a, and b, return the nth magical number.
Because the answer can be very large, return it modulo:

    10**9 + 7

The constraints are large: n can be up to 10^9, while a and b can be up to 4 * 10^4, so we cannot generate magical numbers one by one.
Input and Output
| Item | Meaning |
|---|---|
| Input | Three integers n, a, and b |
| Output | The nth positive integer divisible by a or b |
| Modulo | Return the answer modulo 10^9 + 7 |
| Constraint | 1 <= n <= 10^9 |
| Constraint | 2 <= a, b <= 4 * 10^4 |
Example function shape:

    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        ...

Examples
Example 1:

    Input: n = 1, a = 2, b = 3
    Output: 2

The magical numbers are:

    2, 3, 4, 6, 8, 9, ...

The first magical number is 2.
Example 2:

    Input: n = 4, a = 2, b = 3
    Output: 6

The magical numbers are:

    2, 3, 4, 6, 8, 9, ...

The fourth magical number is 6.
Example 3:

    Input: n = 5, a = 2, b = 4
    Output: 10

Every number divisible by 4 is also divisible by 2.
So the magical numbers are simply multiples of 2:

    2, 4, 6, 8, 10, ...

The fifth magical number is 10.
First Thought: Generate Magical Numbers
A direct idea is to generate all positive integers and count the ones divisible by a or b.
For each number x, we check:

    x % a == 0 or x % b == 0

When we have found n magical numbers, we return the current number.
Code idea:

    class Solution:
        def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
            count = 0  # magical numbers found so far
            x = 1      # current candidate
            while True:
                if x % a == 0 or x % b == 0:
                    count += 1
                    if count == n:
                        return x % (10**9 + 7)
                x += 1

This is correct for small inputs, but it is far too slow for the given constraints.
Problem With Direct Generation
The value of n can be as large as 10^9.
If we count magical numbers one by one, the loop may need billions of iterations.
We need a way to count how many magical numbers are at or below some value x, without generating them individually.
Key Insight
For any integer x, we can count how many magical numbers are <= x.
Numbers divisible by a:

    x // a

Numbers divisible by b:

    x // b

But numbers divisible by both a and b are counted twice.
The numbers divisible by both are exactly the numbers divisible by:

    lcm(a, b)

So the count is:

    count(x) = x // a + x // b - x // lcm(a, b)

This is inclusion-exclusion.
As x gets larger, count(x) never decreases. That gives us a monotonic function, so we can binary search for the smallest x such that:

    count(x) >= n

That smallest x is the nth magical number.
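As a minimal sketch of the counting formula, the helper below (the name count_magical is an illustrative assumption, not part of the original solution) evaluates count(x) for one small case so the inclusion-exclusion terms can be checked by hand.

```python
from math import gcd


def count_magical(x: int, a: int, b: int) -> int:
    """Count positive integers <= x that are divisible by a or b."""
    lcm_ab = a * b // gcd(a, b)
    # Multiples of a, plus multiples of b, minus multiples of both.
    return x // a + x // b - x // lcm_ab


# Worked case with a = 2, b = 3, x = 10:
#   multiples of 2: 2, 4, 6, 8, 10 -> 5
#   multiples of 3: 3, 6, 9        -> 3
#   multiples of 6: 6              -> 1
print(count_magical(10, 2, 3))  # 5 + 3 - 1 = 7
```

The seven numbers counted are 2, 3, 4, 6, 8, 9, and 10, which matches the sequence listed in the examples.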
Algorithm
First compute the least common multiple:

    lcm_ab = a * b // gcd(a, b)

Then binary search over possible answers.
A safe lower bound for the answer is 1.
A safe upper bound is:

    n * min(a, b)

This is safe because every multiple of min(a, b) is magical. Therefore, the nth magical number cannot be larger than the nth multiple of the smaller divisor.
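The bound argument can be checked numerically. The sketch below uses two hypothetical helpers, lcm and upper_bound_is_safe, and simply confirms on a grid of small inputs that at least n magical numbers lie at or below n * min(a, b).

```python
from math import gcd


def lcm(a: int, b: int) -> int:
    """Least common multiple via the gcd identity."""
    return a * b // gcd(a, b)


def upper_bound_is_safe(n: int, a: int, b: int) -> bool:
    """True if at least n magical numbers are <= n * min(a, b)."""
    hi = n * min(a, b)
    count = hi // a + hi // b - hi // lcm(a, b)
    return count >= n


print(lcm(6, 4))  # 12

# Exhaustive check over a small grid of inputs.
all_safe = all(
    upper_bound_is_safe(n, a, b)
    for n in range(1, 50)
    for a in range(2, 12)
    for b in range(2, 12)
)
print(all_safe)  # True
```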
During binary search:
- Let mid be the middle of the current range.
- Count how many magical numbers are <= mid.
- If the count is at least n, mid may be the answer, so move left.
- Otherwise, mid is too small, so move right.
When the search ends, left is the smallest valid answer.
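The loop described above is an instance of a common pattern: finding the smallest value for which a monotone condition holds. A minimal sketch follows, assuming a hypothetical helper named smallest_true that takes the condition as a function; the article's own implementation inlines the condition instead.

```python
from typing import Callable


def smallest_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Return the smallest x in [lo, hi] with pred(x) True.

    Assumes pred is monotone (False ... False True ... True)
    and that pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid       # mid may be the answer, keep it in range
        else:
            lo = mid + 1   # mid is too small, discard it
    return lo


# Example 2 from the text: n = 4, a = 2, b = 3, so lcm(a, b) = 6.
n, a, b, lcm_ab = 4, 2, 3, 6
answer = smallest_true(
    1,
    n * min(a, b),
    lambda x: x // a + x // b - x // lcm_ab >= n,
)
print(answer)  # 6
```

For this input the search visits mid = 4 (count 3), mid = 6 (count 4), and mid = 5 (count 3) before settling on 6.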
Return:

    left % MOD

Correctness
Define:

    count(x) = x // a + x // b - x // lcm(a, b)

This counts exactly the number of positive integers at most x that are divisible by a or b.
The term x // a counts multiples of a.
The term x // b counts multiples of b.
Numbers divisible by both are counted in both terms, so they are counted twice. These numbers are precisely the multiples of lcm(a, b), so subtracting x // lcm(a, b) removes the duplicate count.
Therefore, count(x) is correct.
The function count(x) is monotonic. If x increases, the set of positive integers at most x can only gain elements. It never loses magical numbers.
Because of this monotonicity, binary search can find the smallest integer x such that count(x) >= n.
Let this smallest integer be ans.
Since count(ans) >= n, there are at least n magical numbers less than or equal to ans.
Since ans is the smallest such integer, count(ans - 1) < n.
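So before ans, there are fewer than n magical numbers, and at ans the count reaches at least n.
Moving from ans - 1 to ans adds at most one number to the count, and only when ans is magical. So ans is magical and count(ans) = n, which means ans itself is the nth magical number.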
Thus, the algorithm returns the required answer.
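The two facts used in the argument, that count(ans) equals n and count(ans - 1) equals n - 1, can be spot-checked against direct generation on small inputs. This is only a sanity check under assumed helper names (count_magical and nth_by_generation), not a replacement for the proof.

```python
from math import gcd


def count_magical(x: int, a: int, b: int) -> int:
    """Count positive integers <= x divisible by a or b."""
    return x // a + x // b - x // (a * b // gcd(a, b))


def nth_by_generation(n: int, a: int, b: int) -> int:
    """Find the nth magical number by scanning integers one at a time."""
    count, x = 0, 0
    while count < n:
        x += 1
        if x % a == 0 or x % b == 0:
            count += 1
    return x


for a in range(2, 9):
    for b in range(2, 9):
        for n in range(1, 30):
            ans = nth_by_generation(n, a, b)
            # The count reaches exactly n at ans ...
            assert count_magical(ans, a, b) == n
            # ... and is still n - 1 just before it.
            assert count_magical(ans - 1, a, b) == n - 1

print("count(ans) == n and count(ans - 1) == n - 1 on all checked inputs")
```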
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(log(n * min(a, b))) | Binary search over the answer range |
| Space | O(1) | Only a few integer variables are stored |
The count function is O(1).
The gcd computation is O(log(min(a, b))), which is small compared with the binary search bound.
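To make the logarithmic bound concrete, the sketch below counts loop iterations for the largest search range the constraints allow. The function name worst_case_iterations is an assumption for illustration; it always keeps the left part of the range, which is never the smaller half in this style of loop.

```python
def worst_case_iterations(lo: int, hi: int) -> int:
    """Count iterations of the `while lo < hi` loop in the worst case."""
    steps = 0
    while lo < hi:
        mid = (lo + hi) // 2
        hi = mid  # the left part [lo, mid] is at least as large as the right part
        steps += 1
    return steps


# Largest range allowed by the constraints: n = 10^9, min(a, b) = 4 * 10^4.
print(worst_case_iterations(1, 10**9 * 4 * 10**4))  # 46
```

So even at the maximum constraints the search needs fewer than fifty evaluations of the counting formula.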
Implementation
    import math

    class Solution:
        def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
            MOD = 10**9 + 7
            lcm_ab = a * b // math.gcd(a, b)
            left = 1
            right = n * min(a, b)
            while left < right:
                mid = (left + right) // 2
                count = mid // a + mid // b - mid // lcm_ab
                if count >= n:
                    right = mid
                else:
                    left = mid + 1
            return left % MOD

Code Explanation
We import math because we need math.gcd:

    import math

The modulo is required by the problem:

    MOD = 10**9 + 7

We compute the least common multiple:

    lcm_ab = a * b // math.gcd(a, b)

This lets us count numbers divisible by both a and b.
The search range is:

    left = 1
    right = n * min(a, b)

The upper bound is valid because every multiple of the smaller divisor is magical.
Inside the binary search:

    mid = (left + right) // 2

We count magical numbers up to mid:

    count = mid // a + mid // b - mid // lcm_ab

If there are already at least n magical numbers by mid, then mid is large enough:

    if count >= n:
        right = mid

Otherwise, mid is too small:

    else:
        left = mid + 1

At the end, left is the smallest number with at least n magical numbers at or below it.
So we return:

    return left % MOD

Testing
    def run_tests():
        s = Solution()
        assert s.nthMagicalNumber(1, 2, 3) == 2
        assert s.nthMagicalNumber(4, 2, 3) == 6
        assert s.nthMagicalNumber(5, 2, 4) == 10
        assert s.nthMagicalNumber(3, 6, 4) == 8
        assert s.nthMagicalNumber(10, 2, 3) == 15
        assert s.nthMagicalNumber(1, 40000, 40000) == 40000
        print("all tests passed")

    run_tests()

Test meaning:
| Test | Why |
|---|---|
| n = 1, a = 2, b = 3 | Smallest position |
| n = 4, a = 2, b = 3 | Standard mixed multiples |
| n = 5, a = 2, b = 4 | One divisor is a multiple of the other |
| n = 3, a = 6, b = 4 | Overlap through LCM |
| n = 10, a = 2, b = 3 | Larger sequence check |
| n = 1, a = 40000, b = 40000 | Equal large divisors |
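
None of the cases in the table produces an answer large enough for the modulo to matter. One possible extra check, sketched below as a standalone function (nth_magical_number is a hypothetical free-function copy of the same algorithm), uses the maximum constraints, where the unreduced answer is 4 * 10^13.

```python
import math

MOD = 10**9 + 7


def nth_magical_number(n: int, a: int, b: int) -> int:
    """Same binary search as the solution above, written as a plain function."""
    lcm_ab = a * b // math.gcd(a, b)
    left, right = 1, n * min(a, b)
    while left < right:
        mid = (left + right) // 2
        if mid // a + mid // b - mid // lcm_ab >= n:
            right = mid
        else:
            left = mid + 1
    return left % MOD


# With a == b == 40000 the magical numbers are exactly the multiples of 40000,
# so the unreduced answer is 10^9 * 40000 = 4 * 10^13.
result = nth_magical_number(10**9, 40000, 40000)
assert result == (4 * 10**13) % MOD
print(result)  # 999720007
```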