A dynamic programming solution for counting binary trees where every non-leaf node is the product of its children.
Problem Restatement
We are given an array arr of unique integers. Every value is greater than 1.
We need to count how many binary trees can be built using values from arr.
Each value may be used any number of times.
For every non-leaf node:
node.val == node.left.val * node.right.val
Return the number of possible binary trees modulo:
10**9 + 7
The left and right child order matters, so (2, 5) and (5, 2) are different child arrangements. The official examples include arr = [2, 4] with output 3, and arr = [2, 4, 5, 10] with output 7.
Input and Output
| Item | Meaning |
|---|---|
| Input | An array arr of unique integers greater than 1 |
| Output | Number of valid binary trees |
| Reuse rule | Values may be reused any number of times |
| Parent rule | A non-leaf value equals product of its two children |
| Modulo | 10**9 + 7 |
Examples
Example 1:
arr = [2, 4]
Valid trees are:
[2]
[4]
[4, 2, 2]
The answer is:
3
Example 2:
arr = [2, 4, 5, 10]
Single-node trees:
[2], [4], [5], [10]
Factor trees:
[4, 2, 2]
[10, 2, 5]
[10, 5, 2]
The answer is:
7
First Thought: Build Trees Recursively
For each value, we could try all possible left and right children.
If their product equals the root value, recursively build every left subtree and every right subtree.
This matches the definition, but many subproblems repeat.
For example, once we know how many trees can have root 2, that count can be reused whenever 2 appears as a child.
So we should store the count for each root value.
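The recursive idea with stored counts can be sketched as a memoized function. This is a minimal illustration, not the final solution: the names count_trees_top_down and count are hypothetical, and functools.lru_cache stands in for the "store the count for each root value" step.

```python
from functools import lru_cache


def count_trees_top_down(arr: list[int]) -> int:
    """Count factored binary trees with memoized recursion (illustrative sketch)."""
    MOD = 10**9 + 7
    values = set(arr)

    @lru_cache(maxsize=None)
    def count(root: int) -> int:
        total = 1  # the single-node tree
        for left in values:
            # left is a usable child only if it divides root and the
            # matching right factor is also an allowed value.
            if root % left == 0 and root // left in values:
                total += count(left) * count(root // left)
        return total % MOD

    return sum(count(v) for v in arr) % MOD
```

Because every value is greater than 1, both factors are strictly smaller than the root, so the recursion always terminates. Each root value is computed once and then served from the cache.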
Key Insight
Let:
dp[x] = number of valid trees with root value x
Every value can always form a single-node tree, so:
dp[x] starts at 1
For a non-leaf tree rooted at x, we need two child values a and b such that:
a * b == x
If both a and b exist in arr, then:
dp[x] += dp[a] * dp[b]
because every valid left subtree rooted at a can pair with every valid right subtree rooted at b.
Sorting helps because every value is greater than 1, so both factors of x are strictly smaller than x and their DP values are already computed when x is processed.
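To make the recurrence concrete, here is a small trace that keys dp by value rather than by index, matching the dp[x] notation above. The helper name trace_dp and the sample input [2, 4, 8] are assumptions chosen for illustration.

```python
def trace_dp(arr: list[int]) -> dict[int, int]:
    """Return dp[x] for every x in arr, keyed by value (no modulo, small inputs only)."""
    dp: dict[int, int] = {}
    for x in sorted(arr):
        dp[x] = 1  # single-node tree
        for a in sorted(arr):
            if a >= x:
                break  # children must be strictly smaller than the root
            b, rem = divmod(x, a)
            if rem == 0 and b in dp:
                dp[x] += dp[a] * dp[b]
    return dp


print(trace_dp([2, 4, 8]))  # {2: 1, 4: 2, 8: 5}
```

For root 8 the ordered child pairs are (2, 4) and (4, 2). Each contributes dp[2] * dp[4] = 2, and the single-node tree adds 1, which gives 5.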
Algorithm
Sort arr.
Create a hash map from value to index:
index = {value: i for i, value in enumerate(arr)}
Create a DP array:
dp = [1] * len(arr)
Each dp[i] starts at 1 for the single-node tree.
For each root value arr[i]:
- Try each smaller value arr[j] as the left child.
- If arr[i] % arr[j] != 0, skip it.
- Otherwise compute: right = arr[i] // arr[j]
- If right exists in the array, add dp[j] * dp[index[right]] to dp[i].
Return:
sum(dp) % MOD
Correctness
For every value x, dp[x] starts with one tree: the single-node tree containing only x.
For a non-leaf tree rooted at x, its left child root must be some value a in arr, and its right child root must be b = x / a. The tree is valid only when x is divisible by a and b also exists in arr.
The algorithm checks every possible left child value a. Whenever the matching right child b exists, it adds dp[a] * dp[b], which counts every combination of valid left subtree and valid right subtree.
Because the loop treats a as the left child, the reversed case is counted separately when applicable. This correctly handles ordered binary trees.
Sorting ensures that when computing dp[x], all factor roots are smaller than x, so their counts have already been computed.
Therefore, each dp[x] contains exactly the number of valid trees rooted at x. Summing all dp[x] gives the number of valid trees with any allowed root.
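One way to sanity-check this argument on small inputs is to materialize the trees instead of counting them. The sketch below is only practical for tiny arrays, and its names (enumerate_trees, is_valid) and the nested-tuple tree encoding are assumptions made for illustration: a leaf is (x,) and an internal node is (x, left_subtree, right_subtree).

```python
def enumerate_trees(arr: list[int]) -> list[tuple]:
    """Build every valid tree explicitly as nested tuples (small inputs only)."""
    values = sorted(arr)
    trees: dict[int, list[tuple]] = {}  # root value -> all trees with that root
    for x in values:
        shapes = [(x,)]  # the single-node tree
        for a in values:
            if a >= x:
                break
            b, rem = divmod(x, a)
            if rem == 0 and b in trees:
                # Pair every left subtree rooted at a with every right subtree rooted at b.
                shapes.extend((x, l, r) for l in trees[a] for r in trees[b])
        trees[x] = shapes
    return [t for x in values for t in trees[x]]


def is_valid(tree: tuple, allowed: set[int]) -> bool:
    """Check the product rule and value membership for one nested-tuple tree."""
    if len(tree) == 1:
        return tree[0] in allowed
    root, left, right = tree
    return (root in allowed
            and root == left[0] * right[0]
            and is_valid(left, allowed)
            and is_valid(right, allowed))
```

For arr = [2, 4, 5, 10] this yields 7 distinct tuples, including both (10, (2,), (5,)) and (10, (5,), (2,)), which matches the count from the examples.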
Complexity
Let n = len(arr).
| Metric | Value | Why |
|---|---|---|
| Time | O(n^2) | For each root, we try earlier values as possible factors |
| Space | O(n) | DP array and value-to-index map |
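A possible refinement, offered as a variant rather than as part of the solution described here: stop scanning once the candidate factor exceeds the square root of the root, and count each unordered pair (a, b) with a != b twice to cover both child orders. The worst case in terms of n alone is still quadratic, but fewer candidates are examined per root. The function name count_trees_half_scan is hypothetical.

```python
def count_trees_half_scan(arr: list[int]) -> int:
    """Variant that only scans factors a with a * a <= root (illustrative sketch)."""
    MOD = 10**9 + 7
    values = sorted(arr)
    dp: dict[int, int] = {}
    for x in values:
        ways = 1  # single-node tree
        for a in values:
            if a * a > x:
                break  # any remaining pair was already seen with its smaller factor
            b, rem = divmod(x, a)
            if rem == 0 and b in dp:
                pair = dp[a] * dp[b]
                # (a, b) and (b, a) are different trees unless a == b.
                ways += pair if a == b else 2 * pair
        dp[x] = ways % MOD
    return sum(dp.values()) % MOD
```

On the official examples this returns 3 for [2, 4] and 7 for [2, 4, 5, 10], the same as the full scan.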
Implementation
class Solution:
    def numFactoredBinaryTrees(self, arr: list[int]) -> int:
        MOD = 10**9 + 7
        arr.sort()
        n = len(arr)
        index = {value: i for i, value in enumerate(arr)}
        dp = [1] * n
        for i, root in enumerate(arr):
            for j in range(i):
                left = arr[j]
                if root % left != 0:
                    continue
                right = root // left
                if right in index:
                    dp[i] += dp[j] * dp[index[right]]
            dp[i] %= MOD
        return sum(dp) % MOD
Code Explanation
We sort the array first:
arr.sort()
This lets us compute smaller factor roots before larger parent roots.
The index map lets us check whether a factor exists:
index = {value: i for i, value in enumerate(arr)}
Each value can form one single-node tree:
dp = [1] * n
For each root, we try every smaller value as the left child:
for i, root in enumerate(arr):
    for j in range(i):
If left cannot divide root, it cannot be a child factor:
if root % left != 0:
    continue
Otherwise, the matching right child value is:
right = root // left
If that value exists, we add all combinations of left and right subtrees:
dp[i] += dp[j] * dp[index[right]]
Finally, sum all possible root counts:
return sum(dp) % MOD
Testing
def run_tests():
    s = Solution()
    assert s.numFactoredBinaryTrees([2, 4]) == 3
    assert s.numFactoredBinaryTrees([2, 4, 5, 10]) == 7
    assert s.numFactoredBinaryTrees([18, 3, 6, 2]) == 12
    assert s.numFactoredBinaryTrees([2]) == 1
    assert s.numFactoredBinaryTrees([2, 3, 6, 12]) == 12
    print("all tests passed")

run_tests()
| Test | Why |
|---|---|
| [2,4] | Official small example |
| [2,4,5,10] | Counts ordered children |
| [18,3,6,2] | Sorting should make input order irrelevant |
| [2] | Single-node tree only |
| [2,3,6,12] | Multiple factor combinations |