A clear explanation of counting ordered triples whose bitwise AND is zero using pairwise AND counts.
Problem Restatement
We are given an integer array nums.
We need to count the number of ordered triples of indices (i, j, k) such that:
```python
nums[i] & nums[j] & nums[k] == 0
```
The indices may repeat. For example, (0, 0, 1) is allowed.
The triple is ordered, so (0, 1, 2) and (1, 0, 2) are counted as different triples if both satisfy the condition.
The official constraints are:
| Constraint | Value |
|---|---|
| nums.length | 1 <= nums.length <= 1000 |
| nums[i] | 0 <= nums[i] < 2^16 |
Source: LeetCode problem statement.
Input and Output
| Item | Meaning |
|---|---|
| Input | An integer array nums |
| Output | Number of ordered triples (i, j, k) |
| Condition | nums[i] & nums[j] & nums[k] == 0 |
| Repeated indices | Allowed |
Function shape:
```python
def countTriplets(nums: list[int]) -> int:
    ...
```
Examples
Example 1:
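```python
nums = [2, 1, 3]
```
In binary:
```
2 = 10
1 = 01
3 = 11
```
We need ordered triples whose bitwise AND is 0. The value 2 lacks the low bit, 1 lacks the high bit, and 3 has both, so a triple's AND is 0 exactly when it contains at least one 2 and at least one 1.
Of the 3^3 = 27 ordered triples, 8 contain no 2, 8 contain no 1, and 1 (all 3s) contains neither, so there are 27 - 8 - 8 + 1 = 12 valid triples.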
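To see the 12 triples concretely, a brute-force listing over all ordered index triples works for an input this small. This is only an illustrative sketch; the variable name valid is our own.

```python
from itertools import product

nums = [2, 1, 3]

# Every ordered index triple (i, j, k); product(..., repeat=3) allows repeated indices.
valid = [
    (i, j, k)
    for i, j, k in product(range(len(nums)), repeat=3)
    if nums[i] & nums[j] & nums[k] == 0
]
print(len(valid))   # 12
print(valid[:3])    # [(0, 0, 1), (0, 1, 0), (0, 1, 1)]
```

Every listed triple contains index 0 (value 2) and index 1 (value 1), matching the counting argument above.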
Answer:
```
12
```
Example 2:
```python
nums = [0]
```
The only ordered triple is:
```
(0, 0, 0)
```
And:
```python
0 & 0 & 0 == 0
```
Answer:
```
1
```
First Thought: Check Every Triple
The direct solution is to try every ordered triple.
```python
class Solution:
    def countTriplets(self, nums: list[int]) -> int:
        n = len(nums)
        answer = 0

        for i in range(n):
            for j in range(n):
                for k in range(n):
                    if nums[i] & nums[j] & nums[k] == 0:
                        answer += 1

        return answer
```
This follows the problem statement exactly.
Problem With Brute Force
The brute force solution checks n^3 triples.
Since nums.length can be 1000, this can require up to:
```
1000^3 = 1,000,000,000
```
checks.
That is far too slow for a typical time limit, especially in Python.
We need to reuse intermediate bitwise AND results.
Key Insight
Bitwise AND is associative:
```python
(nums[i] & nums[j]) & nums[k] == nums[i] & nums[j] & nums[k]
```
So instead of checking every triple directly, we can first count all pairwise AND results.
For every ordered pair (i, j), compute:
```python
pair = nums[i] & nums[j]
```
Store how many ordered pairs produce each pair value.
Then, for each num in nums, every stored pair value that satisfies
```python
pair & num == 0
```
forms valid triples with that num.
If pair_count[pair] = c, then that pair value contributes c triples for the current num.
This changes the problem from counting triples directly to counting pair results first.
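Tracing both phases on Example 1 shows how the 12 triples decompose. This is an illustrative sketch; the contributions dictionary is our own bookkeeping and relies on the values in nums being distinct.

```python
from collections import Counter

nums = [2, 1, 3]

# Phase 1: how many ordered pairs (including i == j) produce each AND value.
pair_count = Counter(a & b for a in nums for b in nums)
print(sorted(pair_count.items()))   # [(0, 2), (1, 3), (2, 3), (3, 1)]

# Phase 2: for each third value, add the counts of every compatible pair value.
contributions = {}
for num in nums:
    contributions[num] = sum(c for pair, c in pair_count.items() if pair & num == 0)
print(contributions)                # {2: 5, 1: 5, 3: 2}
print(sum(contributions.values()))  # 12
```

For num = 3, only pair value 0 is compatible, and it comes from the two ordered pairs (2, 1) and (1, 2).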
Algorithm
Create a frequency map pair_count.
For every ordered pair (a, b) of values in nums:
- Compute a & b.
- Increment its count in pair_count.

Then count valid triples:
- For every value pair in pair_count:
  - For every number num in nums:
    - If pair & num == 0, add pair_count[pair] to the answer.
Correctness
Every ordered triple (i, j, k) can be split into an ordered pair (i, j) and a third index k.
For the pair (i, j), the algorithm computes:
```python
pair = nums[i] & nums[j]
```
The triple is valid exactly when:
```python
pair & nums[k] == 0
```
The first phase counts how many ordered pairs produce each possible pair value. Therefore, when the second phase finds that pair & nums[k] == 0, all ordered pairs that produced that same pair value form valid triples with the current k.
The algorithm adds exactly that number of pairs.
Every valid ordered triple is counted exactly once: its pair (i, j) contributes to exactly one pair value, and that value is tested against nums[k] exactly once for each index k.
No invalid triple is counted, because the algorithm adds a pair count only when the final bitwise AND is zero.
Therefore, the algorithm returns the correct number of ordered triples.
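A randomized comparison against brute force is not a proof, but it is a cheap way to catch slips in either the argument or the code. This sketch assumes small random inputs (up to 8 values below 16, an arbitrary choice that makes zero ANDs common) and a fixed seed; brute_force and pair_counting are illustrative names.

```python
import random
from collections import Counter


def brute_force(nums: list[int]) -> int:
    # Direct O(n^3) count; iterating values with duplicates counts index triples.
    return sum(1 for a in nums for b in nums for c in nums if a & b & c == 0)


def pair_counting(nums: list[int]) -> int:
    # Two-phase method: count pairwise ANDs, then match each against a third value.
    pair_count = Counter(a & b for a in nums for b in nums)
    return sum(
        count
        for pair, count in pair_count.items()
        for num in nums
        if pair & num == 0
    )


random.seed(0)
for _ in range(200):
    n = random.randint(1, 8)
    nums = [random.randrange(16) for _ in range(n)]
    assert brute_force(nums) == pair_counting(nums), nums
print("methods agree")
```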
Complexity
Let n = len(nums) and let m be the number of distinct pairwise AND results.
| Metric | Value | Why |
|---|---|---|
| Time | O(n^2 + mn) | Count all ordered pairs, then test each distinct pair result with each number |
| Space | O(m) | Store counts of pairwise AND results |
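Since nums[i] < 2^16, every AND result is also below 2^16, and there are only n^2 ordered pairs.
So m <= min(n^2, 2^16). With n = 1000, the second phase performs at most 2^16 * 1000 = 65,536,000 checks, far fewer than the 10^9 of brute force.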
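The 16-bit bound also suggests a different second phase, often used for this problem but not part of the implementation in this article: for each num, the compatible pair values are exactly the submasks of num's 16-bit complement, so they can be enumerated directly instead of scanning every stored pair value. This sketch assumes every value fits in 16 bits; count_triplets_submask is a hypothetical name.

```python
def count_triplets_submask(nums: list[int]) -> int:
    # Assumption: 0 <= x < 2**16 for every x in nums.
    full = (1 << 16) - 1
    pair_count = [0] * (1 << 16)
    for a in nums:
        for b in nums:
            pair_count[a & b] += 1

    answer = 0
    for num in nums:
        comp = full & ~num  # bits that are 0 in num
        sub = comp
        # pair & num == 0 exactly when pair is a submask of comp.
        while True:
            answer += pair_count[sub]
            if sub == 0:
                break
            sub = (sub - 1) & comp  # next smaller submask of comp
    return answer


print(count_triplets_submask([2, 1, 3]))  # 12
```

Each num costs 2^(16 - popcount(num)) lookups, so this variant wins when numbers have many set bits and is worst for num = 0, which enumerates all 2^16 masks.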
Implementation
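```python
from collections import Counter


class Solution:
    def countTriplets(self, nums: list[int]) -> int:
        # Phase 1: count how many ordered pairs (i, j) give each AND value.
        pair_count = Counter()
        for a in nums:
            for b in nums:
                pair_count[a & b] += 1

        # Phase 2: a pair value that ANDs to zero with num contributes
        # all of its ordered pairs as valid triples.
        answer = 0
        for pair, count in pair_count.items():
            for num in nums:
                if pair & num == 0:
                    answer += count

        return answer
```
Code Explanation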
We store pairwise AND frequencies:
```python
pair_count = Counter()
```
Then compute all ordered pairs:
```python
for a in nums:
    for b in nums:
        pair_count[a & b] += 1
```
This includes repeated indices and ordered pairs, which matches the problem.
For example, both (i, j) and (j, i) are counted.
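A tiny input makes this visible. With nums = [1, 3] (an arbitrary example of ours), the ordered pairs (0, 1) and (1, 0) both produce 1 & 3 = 1, and the repeated-index pair (0, 0) produces 1 as well, so value 1 is counted three times and the counts sum to n^2 = 4.

```python
from collections import Counter

nums = [1, 3]

# Ordered pairs: 1&1 = 1, 1&3 = 1, 3&1 = 1, 3&3 = 3
pair_count = Counter(a & b for a in nums for b in nums)
print(pair_count)  # Counter({1: 3, 3: 1})
print(sum(pair_count.values()) == len(nums) ** 2)  # True
```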
Then we test each pair result against each third number:
```python
for pair, count in pair_count.items():
    for num in nums:
```
If the final AND is zero:
```python
if pair & num == 0:
```
then all pairs represented by count are valid with this num:
```python
answer += count
```
Finally, return the total:
```python
return answer
```
Testing
```python
def run_tests():
    s = Solution()

    assert s.countTriplets([2, 1, 3]) == 12
    assert s.countTriplets([0]) == 1
    assert s.countTriplets([0, 0]) == 8
    assert s.countTriplets([1]) == 0
    assert s.countTriplets([1, 2]) == 6

    print("all tests passed")


run_tests()
```
| Test | Expected | Why |
|---|---|---|
| [2, 1, 3] | 12 | Example 1 from the problem statement |
| [0] | 1 | A single zero forms one valid repeated-index triple |
| [0, 0] | 8 | All 2^3 ordered triples are valid |
| [1] | 0 | 1 & 1 & 1 is not zero |
| [1, 2] | 6 | Every triple except (1, 1, 1) and (2, 2, 2) ANDs to zero |
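A few boundary cases can be added on top of the table. This sketch uses a free function, count_triplets (a hypothetical name), that implements the same two-phase method so the block runs on its own; the extra inputs are our own choices, with expected values derived by hand.

```python
from collections import Counter


def count_triplets(nums: list[int]) -> int:
    # Same two-phase pair-counting method as the Solution class.
    pair_count = Counter(a & b for a in nums for b in nums)
    return sum(c for p, c in pair_count.items() for x in nums if p & x == 0)


# All-zero arrays: every one of the n^3 ordered triples is valid.
for n in range(1, 6):
    assert count_triplets([0] * n) == n ** 3

# The largest allowed value has all 16 bits set, so it never ANDs to zero alone.
assert count_triplets([2**16 - 1]) == 0

# With one zero, only the single triple that avoids it (all 2^16 - 1) fails: 8 - 1.
assert count_triplets([0, 2**16 - 1]) == 7

# Distinct powers of two: only the 3 all-same triples fail, so 27 - 3.
assert count_triplets([1, 2, 4]) == 24
print("extra tests passed")
```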