
LeetCode 878: Nth Magical Number

A clear explanation of finding the nth magical number using binary search, greatest common divisor, least common multiple, and inclusion-exclusion.

Problem Restatement

A positive integer is called magical if it is divisible by either a or b.

Given three integers n, a, and b, return the nth magical number.

Because the answer can be very large, return it modulo:

10**9 + 7

The constraints are large: n can be up to 10^9, and a and b can each be up to 4 * 10^4, so the answer itself can be as large as n * min(a, b) = 4 * 10^13. Generating magical numbers one by one is therefore not an option.

Input and Output

| Item | Meaning |
| --- | --- |
| Input | Three integers n, a, and b |
| Output | The nth positive integer divisible by a or b |
| Modulo | Return the answer modulo 10^9 + 7 |
| Constraint | 1 <= n <= 10^9 |
| Constraint | 2 <= a, b <= 4 * 10^4 |

Example function shape:

def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
    ...

Examples

Example 1:

Input: n = 1, a = 2, b = 3
Output: 2

The magical numbers are:

2, 3, 4, 6, 8, 9, ...

The first magical number is 2.

Example 2:

Input: n = 4, a = 2, b = 3
Output: 6

The magical numbers are:

2, 3, 4, 6, 8, 9, ...

The fourth magical number is 6.

Example 3:

Input: n = 5, a = 2, b = 4
Output: 10

Every number divisible by 4 is also divisible by 2.

So the magical numbers are simply multiples of 2:

2, 4, 6, 8, 10, ...

The fifth magical number is 10.
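To check these examples by hand, a small brute-force filter is enough. The helper name `first_magical` and its `limit` parameter are illustrative assumptions, not part of the problem's API; this is only a sketch for small ranges.

```python
from typing import List


def first_magical(a: int, b: int, limit: int) -> List[int]:
    """Hypothetical helper: every magical number in [1, limit], by brute force."""
    return [x for x in range(1, limit + 1) if x % a == 0 or x % b == 0]


print(first_magical(2, 3, 10))  # [2, 3, 4, 6, 8, 9, 10]
print(first_magical(2, 4, 10))  # [2, 4, 6, 8, 10]
```

Indexing the list from 0 gives the examples: `first_magical(2, 3, 10)[3]` is 6 and `first_magical(2, 4, 10)[4]` is 10.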

First Thought: Generate Magical Numbers

A direct idea is to generate all positive integers and count the ones divisible by a or b.

For each number x, we check:

x % a == 0 or x % b == 0

When we have found n magical numbers, we return the current number.

Code idea:

class Solution:
    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        count = 0
        x = 1

        while True:
            if x % a == 0 or x % b == 0:
                count += 1

                if count == n:
                    return x % (10**9 + 7)

            x += 1

This is correct for small inputs, but it is far too slow for the given constraints.

Problem With Direct Generation

The value of n can be as large as 10^9.

The loop runs once per integer x up to the answer, and the answer can be as large as n * min(a, b) = 4 * 10^13, so it may need tens of trillions of iterations.

We need a way to count how many magical numbers are at or below some value x, without generating them individually.

Key Insight

For any integer x, we can count how many magical numbers are <= x.

Numbers divisible by a:

x // a

Numbers divisible by b:

x // b

But numbers divisible by both a and b are counted twice.

The numbers divisible by both are exactly the numbers divisible by:

lcm(a, b)
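A quick way to see why the overlap is governed by lcm(a, b) and not by a * b: for a = 6 and b = 4, the first common multiple is 12, not 24. The sketch below compares both sets up to a small limit; the helper names are hypothetical.

```python
import math
from typing import List


def common_multiples(a: int, b: int, limit: int) -> List[int]:
    # Hypothetical helper: numbers in [1, limit] divisible by both a and b.
    return [x for x in range(1, limit + 1) if x % a == 0 and x % b == 0]


def lcm_multiples(a: int, b: int, limit: int) -> List[int]:
    # Hypothetical helper: multiples of lcm(a, b) in [1, limit].
    step = a * b // math.gcd(a, b)
    return list(range(step, limit + 1, step))


print(common_multiples(6, 4, 50))  # [12, 24, 36, 48]
print(lcm_multiples(6, 4, 50))     # [12, 24, 36, 48]
```

Using a * b = 24 as the step would miss 12 and 36, which is why the formula needs the least common multiple.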

So the count is:

count(x) = x // a + x // b - x // lcm(a, b)

This is inclusion-exclusion.
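As a sketch, the counting formula can be written as a standalone function and cross-checked against a direct count for small x. Both function names are illustrative assumptions.

```python
import math


def count_magical(x: int, a: int, b: int) -> int:
    """Number of integers in [1, x] divisible by a or b (inclusion-exclusion)."""
    lcm_ab = a * b // math.gcd(a, b)
    return x // a + x // b - x // lcm_ab


def count_brute(x: int, a: int, b: int) -> int:
    """Same count, by checking every integer (only for small x)."""
    return sum(1 for k in range(1, x + 1) if k % a == 0 or k % b == 0)


# The formula agrees with the direct count for every x in a small range.
for x in range(200):
    assert count_magical(x, 6, 4) == count_brute(x, 6, 4)

print(count_magical(12, 2, 3))  # 12//2 + 12//3 - 12//6 = 6 + 4 - 2 = 8
```

The 8 numbers counted for x = 12 are 2, 3, 4, 6, 8, 9, 10, 12.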

As x gets larger, count(x) never decreases. That gives us a monotonic function, so we can binary search for the smallest x such that:

count(x) >= n

That smallest x is the nth magical number.
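"Smallest x such that a monotonic condition holds" is a general binary-search pattern. A minimal sketch, assuming the predicate is False for a prefix of the range and True afterwards, and that it is True at the upper end; the name `first_true` is hypothetical.

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi] with pred(x) True.

    Assumes pred is monotonic (False ... False True ... True) and pred(hi) is True.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid      # mid works, so the answer is mid or smaller
        else:
            lo = mid + 1  # mid fails, so the answer is strictly larger
    return lo


def at_least_four(x: int) -> bool:
    # count(x) >= 4 for a = 2, b = 3, where lcm(2, 3) = 6
    return x // 2 + x // 3 - x // 6 >= 4


print(first_true(1, 8, at_least_four))  # 6
```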

Algorithm

First compute the least common multiple:

lcm_ab = a * b // gcd(a, b)
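Here is a small sketch of the same formula using Python's `math.gcd`. The helper name `lcm` is hypothetical; dividing by the gcd before multiplying does not matter for Python's arbitrary-precision integers, but it keeps intermediate values smaller in languages with fixed-width integers.

```python
import math


def lcm(a: int, b: int) -> int:
    # Hypothetical helper: divide first, then multiply, to keep the intermediate small.
    return a // math.gcd(a, b) * b


print(lcm(6, 4))          # 12
print(lcm(2, 4))          # 4
print(lcm(40000, 39999))  # 1599960000

# Python 3.9+ also ships math.lcm, which should agree.
if hasattr(math, "lcm"):
    assert lcm(6, 4) == math.lcm(6, 4)
```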

Then binary search over possible answers.

The smallest possible answer is 1.

A safe largest possible answer is:

n * min(a, b)

This is safe because every multiple of min(a, b) is magical. Therefore, the nth magical number cannot be larger than the nth multiple of the smaller divisor.
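The bound can be checked numerically: at x = n * min(a, b) the count is always at least n, although the bound is often loose. The function below restates the article's count formula; the test cases are taken from this article's examples plus the maximum constraints.

```python
import math


def count_magical(x: int, a: int, b: int) -> int:
    lcm_ab = a * b // math.gcd(a, b)
    return x // a + x // b - x // lcm_ab


CASES = [(1, 2, 3), (4, 2, 3), (5, 2, 4), (3, 6, 4), (10**9, 40000, 40000)]

for n, a, b in CASES:
    hi = n * min(a, b)
    # The multiples of min(a, b) alone already give n magical numbers up to hi.
    assert count_magical(hi, a, b) >= n

# For n = 4, a = 2, b = 3 the bound is 8, but the answer is only 6.
print(count_magical(8, 2, 3))  # 5
```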

During binary search:

  1. Let mid be the middle of the current range.
  2. Count how many magical numbers are <= mid.
  3. If the count is at least n, mid may be the answer, so move left.
  4. Otherwise, mid is too small, so move right.

When the search ends, left is the smallest valid answer.

Return:

left % MOD
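To make the steps concrete, the sketch below runs the same search on Example 2 (n = 4, a = 2, b = 3) and prints each iteration. The function name `trace_search` is an illustrative assumption.

```python
import math


def trace_search(n: int, a: int, b: int) -> int:
    """Binary search for the nth magical number, printing each step (no modulo)."""
    lcm_ab = a * b // math.gcd(a, b)
    left, right = 1, n * min(a, b)
    while left < right:
        mid = (left + right) // 2
        count = mid // a + mid // b - mid // lcm_ab
        print(f"left={left} right={right} mid={mid} count={count}")
        if count >= n:
            right = mid      # enough magical numbers: keep mid as a candidate
        else:
            left = mid + 1   # too few: the answer is above mid
    return left


print(trace_search(4, 2, 3))
# left=1 right=8 mid=4 count=3
# left=5 right=8 mid=6 count=4
# left=5 right=6 mid=5 count=3
# 6
```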

Correctness

Define:

count(x) = x // a + x // b - x // lcm(a, b)

This counts exactly the number of positive integers at most x that are divisible by a or b.

The term x // a counts multiples of a.

The term x // b counts multiples of b.

Numbers divisible by both are counted in both terms, so they are counted twice. These numbers are precisely the multiples of lcm(a, b), so subtracting x // lcm(a, b) removes the duplicate count.

Therefore, count(x) is correct.

The function count(x) is non-decreasing: as x increases, the set of positive integers at most x only gains elements, so it never loses magical numbers.

Because of this monotonicity, binary search can find the smallest integer x such that count(x) >= n.

Let this smallest integer be ans.

Since count(ans) >= n, there are at least n magical numbers less than or equal to ans.

Since ans is the smallest such integer, count(ans - 1) < n.

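So before ans there are fewer than n magical numbers, and at ans the count reaches at least n.

Moving from ans - 1 to ans can add at most one number to the count, namely ans itself. Therefore count(ans - 1) = n - 1, count(ans) = n, and ans must be magical.

This means ans itself is the nth magical number.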

Thus, the algorithm returns the required answer.
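The argument above can be spot-checked exhaustively on small inputs: for every returned ans, count(ans - 1) should be n - 1, count(ans) should be n, and ans should be magical. This is a sketch over a small assumed range of a, b, and n, with hypothetical helper names.

```python
import math


def count_magical(x: int, a: int, b: int) -> int:
    lcm_ab = a * b // math.gcd(a, b)
    return x // a + x // b - x // lcm_ab


def nth_unmodded(n: int, a: int, b: int) -> int:
    # Same binary search as the article, without the final modulo.
    left, right = 1, n * min(a, b)
    while left < right:
        mid = (left + right) // 2
        if count_magical(mid, a, b) >= n:
            right = mid
        else:
            left = mid + 1
    return left


for a in range(2, 9):
    for b in range(2, 9):
        for n in range(1, 30):
            ans = nth_unmodded(n, a, b)
            assert count_magical(ans - 1, a, b) == n - 1  # fewer than n before ans
            assert count_magical(ans, a, b) == n          # exactly n at ans
            assert ans % a == 0 or ans % b == 0           # ans itself is magical

print("correctness properties hold on all small cases")
```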

Complexity

| Metric | Value | Why |
| --- | --- | --- |
| Time | O(log(n * min(a, b))) | Binary search over the answer range |
| Space | O(1) | Only a few integer variables are stored |

The count function is O(1).

The gcd computation is O(log(min(a, b))), which is small compared with the binary search bound.
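To put a number on the time bound, the sketch below counts loop iterations at the largest allowed inputs. Since each iteration roughly halves a range of n * min(a, b) candidates, the count should not exceed ceil(log2(4 * 10^13)) = 46. The function name `search_iterations` is hypothetical.

```python
import math


def search_iterations(n: int, a: int, b: int):
    """Run the binary search and return (answer, number of loop iterations)."""
    lcm_ab = a * b // math.gcd(a, b)
    left, right = 1, n * min(a, b)
    steps = 0
    while left < right:
        steps += 1
        mid = (left + right) // 2
        if mid // a + mid // b - mid // lcm_ab >= n:
            right = mid
        else:
            left = mid + 1
    return left, steps


n, a, b = 10**9, 40000, 40000
ans, steps = search_iterations(n, a, b)
bound = math.ceil(math.log2(n * min(a, b)))  # 46 for these inputs
print(ans, steps, bound)
assert steps <= bound
```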

Implementation

import math

class Solution:
    def nthMagicalNumber(self, n: int, a: int, b: int) -> int:
        MOD = 10**9 + 7

        lcm_ab = a * b // math.gcd(a, b)

        left = 1
        right = n * min(a, b)

        while left < right:
            mid = (left + right) // 2

            count = mid // a + mid // b - mid // lcm_ab

            if count >= n:
                right = mid
            else:
                left = mid + 1

        return left % MOD

Code Explanation

We import math because we need math.gcd:

import math

The modulo is required by the problem:

MOD = 10**9 + 7

We compute the least common multiple:

lcm_ab = a * b // math.gcd(a, b)

This lets us count numbers divisible by both a and b.

The search range is:

left = 1
right = n * min(a, b)

The upper bound is valid because every multiple of the smaller divisor is magical.

Inside the binary search:

mid = (left + right) // 2

We count magical numbers up to mid:

count = mid // a + mid // b - mid // lcm_ab

If there are already at least n magical numbers by mid, then mid is large enough:

if count >= n:
    right = mid

Otherwise, mid is too small:

else:
    left = mid + 1

At the end, left is the smallest x with count(x) >= n, that is, the smallest number with at least n magical numbers at or below it.

So we return:

return left % MOD

Testing

def run_tests():
    s = Solution()

    assert s.nthMagicalNumber(1, 2, 3) == 2
    assert s.nthMagicalNumber(4, 2, 3) == 6
    assert s.nthMagicalNumber(5, 2, 4) == 10
    assert s.nthMagicalNumber(3, 6, 4) == 8
    assert s.nthMagicalNumber(10, 2, 3) == 15
    assert s.nthMagicalNumber(1, 40000, 40000) == 40000

    print("all tests passed")

run_tests()
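Hand-picked tests can miss edge cases, so a randomized comparison against the brute-force loop from earlier is a cheap extra check. This is a sketch: the function names, the seed, and the small input ranges are assumptions chosen so that the brute force stays fast.

```python
import math
import random

MOD = 10**9 + 7


def nth_magical(n: int, a: int, b: int) -> int:
    # Binary search on the answer, as in the implementation above.
    lcm_ab = a * b // math.gcd(a, b)
    left, right = 1, n * min(a, b)
    while left < right:
        mid = (left + right) // 2
        if mid // a + mid // b - mid // lcm_ab >= n:
            right = mid
        else:
            left = mid + 1
    return left % MOD


def nth_magical_brute(n: int, a: int, b: int) -> int:
    # Direct generation; only usable for small n.
    count, x = 0, 0
    while count < n:
        x += 1
        if x % a == 0 or x % b == 0:
            count += 1
    return x % MOD


def random_cross_check(trials: int = 300, seed: int = 0) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 200)
        a = rng.randint(2, 50)
        b = rng.randint(2, 50)
        assert nth_magical(n, a, b) == nth_magical_brute(n, a, b), (n, a, b)
    return True


print(random_cross_check())  # True
```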

Test meaning:

| Test | Why |
| --- | --- |
| n = 1, a = 2, b = 3 | Smallest position |
| n = 4, a = 2, b = 3 | Standard mixed multiples |
| n = 5, a = 2, b = 4 | One divisor is a multiple of the other |
| n = 3, a = 6, b = 4 | Overlap through LCM |
| n = 10, a = 2, b = 3 | Larger sequence check |
| n = 1, a = 40000, b = 40000 | Equal large divisors |