
LeetCode 654: Maximum Binary Tree

A clear explanation of constructing a maximum binary tree recursively using divide and conquer.

Problem Restatement

We are given an array of distinct integers called nums.

We must construct a binary tree using the following rules:

  1. The root is the maximum number in the array.
  2. The left subtree is built recursively from the elements to the left of the maximum number.
  3. The right subtree is built recursively from the elements to the right of the maximum number.

Return the root of the constructed tree.

Input and Output

| Item | Meaning |
|---|---|
| Input | An array of distinct integers |
| Output | The root of the constructed maximum binary tree |
| Root rule | The root is the maximum value in the current subarray |
| Left subtree | Built from elements left of the maximum |
| Right subtree | Built from elements right of the maximum |

Example function shape:

def constructMaximumBinaryTree(nums: List[int]) -> Optional[TreeNode]:
    ...

Examples

Consider:

nums = [3, 2, 1, 6, 0, 5]

The maximum value is:

6

So 6 becomes the root.

The left side is:

[3, 2, 1]

The right side is:

[0, 5]

Now build both sides recursively.

For the left side:

[3, 2, 1]

the maximum is 3.

So 3 becomes the left child of 6.

Its right side is:

[2, 1]

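The maximum there is 2, so 2 becomes the right child of 3. Nothing lies to the left of 3, so 3 has no left child.

The only element to the right of 2 is 1, so 2 gets right child 1.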

For the right side:

[0, 5]

the maximum is 5.

So 5 becomes the right child of 6.

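The only element to its left is 0, so 0 becomes the left child of 5. Nothing lies to the right of 5, so it has no right child.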

The final tree is:

        6
       / \
      3   5
       \  /
        2 0
         \
          1
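To see the order in which the recursion visits subarrays in this example, here is a small tracing sketch. `trace` is a hypothetical helper, not part of the solution: it applies the same split rule and records each subarray with the root chosen for it, indented by recursion depth.

```python
from typing import List


def trace(nums: List[int], left: int, right: int, depth: int = 0) -> List[str]:
    # Hypothetical helper: records each subarray and the root chosen for it.
    if left > right:
        return []  # empty range: no node
    max_index = max(range(left, right + 1), key=lambda i: nums[i])
    lines = ["  " * depth + f"{nums[left:right + 1]} -> root {nums[max_index]}"]
    lines += trace(nums, left, max_index - 1, depth + 1)   # left part
    lines += trace(nums, max_index + 1, right, depth + 1)  # right part
    return lines


print("\n".join(trace([3, 2, 1, 6, 0, 5], 0, 5)))
# [3, 2, 1, 6, 0, 5] -> root 6
#   [3, 2, 1] -> root 3
#     [2, 1] -> root 2
#       [1] -> root 1
#   [0, 5] -> root 5
#     [0] -> root 0
```

Each indented line is a child of the nearest less-indented line above it, and left subtrees are listed before right subtrees, matching the diagram.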

Another example:

nums = [1]

The maximum is 1, so the tree is simply:

1

First Thought: Follow the Definition Directly

The problem already describes the construction process.

For every subarray:

  1. Find the maximum value.
  2. Create a node with that value.
  3. Recursively build the left subtree.
  4. Recursively build the right subtree.

This is a divide-and-conquer problem.

Each recursive call handles a smaller subarray.
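Following the definition literally gives a short recursive version that slices the list at every step. This is a sketch: the `TreeNode` class mirrors the usual LeetCode definition, and `build_by_slicing` is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    # Mirrors the usual LeetCode binary tree node.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_by_slicing(nums: List[int]) -> Optional[TreeNode]:
    # Hypothetical helper: applies the three rules directly on list slices.
    if not nums:
        return None  # empty subarray -> empty subtree
    i = nums.index(max(nums))  # position of the maximum (values are distinct)
    root = TreeNode(nums[i])
    root.left = build_by_slicing(nums[:i])       # elements left of the maximum
    root.right = build_by_slicing(nums[i + 1:])  # elements right of the maximum
    return root


root = build_by_slicing([3, 2, 1, 6, 0, 5])
print(root.val, root.left.val, root.right.val)  # 6 3 5
```

Each slice copies its elements into a new list, so every call pays for a copy on top of the scan for the maximum.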

Key Insight

At every level, the position of the maximum in the current subarray determines both the node and how the remaining elements are divided below it.

Every recursive step splits the array into:

| Part | Meaning |
|---|---|
| Elements before the maximum | Left subtree |
| Maximum element | Current root |
| Elements after the maximum | Right subtree |

So the recursive function only needs to know the current range of indices.

Instead of slicing the array at every step, we can pass two indices, left and right, that mark the inclusive bounds of the current subarray.

This avoids copying elements into a new list on every call.

Recursive Definition

Suppose the recursive function is:

build(left, right)

This function builds the maximum binary tree from:

nums[left:right+1]

The steps are:

  1. If the range is empty, return None.
  2. Find the index of the maximum value.
  3. Create a node.
  4. Recursively build the left subtree.
  5. Recursively build the right subtree.
  6. Return the node.
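The inclusive bounds of build(left, right) map onto Python slices as follows. `child_ranges` and `as_slice` are hypothetical helpers used only to illustrate the convention on the example array; in particular they show that an empty range (left > right) corresponds to an empty slice.

```python
from typing import List, Tuple


def child_ranges(left: int, right: int, max_index: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    # Hypothetical helper: the two inclusive ranges passed to the recursive calls.
    return (left, max_index - 1), (max_index + 1, right)


def as_slice(nums: List[int], left: int, right: int) -> List[int]:
    # build(left, right) covers nums[left:right + 1]; it is empty exactly when left > right.
    return nums[left:right + 1]


nums = [3, 2, 1, 6, 0, 5]
(l1, r1), (l2, r2) = child_ranges(0, 5, 3)  # maximum 6 is at index 3
print(as_slice(nums, l1, r1), as_slice(nums, l2, r2))  # [3, 2, 1] [0, 5]

(l3, r3), _ = child_ranges(0, 2, 0)  # maximum 3 of [3, 2, 1] is at index 0
print((l3, r3), l3 > r3, as_slice(nums, l3, r3))  # (0, -1) True []
```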

Algorithm

Start with the full array:

build(0, len(nums) - 1)

For every recursive call:

  1. If left > right, the subarray is empty.
  2. Scan the range to find the maximum value index.
  3. Create a node with that value.
  4. Build the left subtree from the left side.
  5. Build the right subtree from the right side.
  6. Return the node.

Correctness

The algorithm follows the exact definition of a maximum binary tree.

For any subarray:

  1. The algorithm finds the maximum value.
  2. It creates the root node using that value.
  3. It recursively constructs the left subtree from all elements before the maximum.
  4. It recursively constructs the right subtree from all elements after the maximum.

These are exactly the rules required by the problem.

The recursion eventually stops because each recursive call uses a strictly smaller subarray. When the subarray becomes empty, the function returns None.

Therefore, every subtree is constructed correctly, and the final tree is the required maximum binary tree.
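The two rules can also be checked mechanically. Because the left subtree holds exactly the elements before the maximum and the right subtree exactly the elements after it, an in-order traversal of the result reproduces nums. Because every root is the maximum of its range, each node is larger than its children. With distinct values these two properties identify the tree uniquely (it is a max-oriented Cartesian tree). The sketch below checks both on random inputs; `build` restates the recursion from this article in compact form, and `inorder` and `is_heap` are hypothetical helper names.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build(nums: List[int], left: int, right: int) -> Optional[TreeNode]:
    # Same recursion as described above: the maximum becomes the root, split around it.
    if left > right:
        return None
    m = max(range(left, right + 1), key=lambda i: nums[i])
    return TreeNode(nums[m], build(nums, left, m - 1), build(nums, m + 1, right))


def inorder(root: Optional[TreeNode]) -> List[int]:
    # Left subtree, node, right subtree: should reproduce the original order.
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


def is_heap(root: Optional[TreeNode]) -> bool:
    # Every node must be larger than each child; by transitivity, larger than its whole subtree.
    if root is None:
        return True
    for child in (root.left, root.right):
        if child is not None and child.val >= root.val:
            return False
    return is_heap(root.left) and is_heap(root.right)


for _ in range(200):
    nums = random.sample(range(100), random.randint(1, 30))  # distinct values
    root = build(nums, 0, len(nums) - 1)
    assert inorder(root) == nums and is_heap(root)
print("both properties hold")
```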

Complexity

| Metric | Value | Why |
|---|---|---|
| Time | O(n^2) worst case | Each call scans its entire subarray to find the maximum |
| Space | O(n) | Recursion depth equals the tree height, which can reach n |

The worst case happens when the array is sorted, in either ascending or descending order.

For example:

[1, 2, 3, 4, 5]

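Each call finds the maximum at one end of its range, so the next call still covers all but one element. The scans touch 5 + 4 + 3 + 2 + 1 = 15 elements here, and n(n+1)/2 = O(n^2) in general.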
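The quadratic bound can be checked by counting. The hypothetical helper `scanned` follows the same recursion and adds up the size of every range that the maximum search walks over.

```python
from typing import List


def scanned(nums: List[int], left: int, right: int) -> int:
    # Hypothetical helper: total number of elements examined by all maximum searches.
    if left > right:
        return 0
    m = max(range(left, right + 1), key=lambda i: nums[i])
    size = right - left + 1  # this call looks at every element of its range once
    return size + scanned(nums, left, m - 1) + scanned(nums, m + 1, right)


for n in (5, 10, 100):
    nums = list(range(1, n + 1))  # ascending input: the worst case
    # Prints n, the counted total, and n(n+1)/2; the last two agree.
    print(n, scanned(nums, 0, n - 1), n * (n + 1) // 2)
```

Descending input gives the same total, since the maximum then sits at the left end of every range.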

Implementation

from typing import List, Optional

# Definition for a binary tree node. LeetCode supplies this class;
# it is defined here so the file also runs on its own.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def constructMaximumBinaryTree(
        self,
        nums: List[int],
    ) -> Optional[TreeNode]:

        def build(left: int, right: int) -> Optional[TreeNode]:
            if left > right:
                return None

            max_index = left

            for i in range(left + 1, right + 1):
                if nums[i] > nums[max_index]:
                    max_index = i

            root = TreeNode(nums[max_index])

            root.left = build(left, max_index - 1)
            root.right = build(max_index + 1, right)

            return root

        return build(0, len(nums) - 1)

Code Explanation

The recursive function:

build(left, right)

constructs the tree for the subarray:

nums[left:right+1]

If the range is empty:

if left > right:
    return None

then there are no nodes in this subtree.

We begin by assuming the leftmost value is the maximum:

max_index = left

Then we scan the range:

for i in range(left + 1, right + 1):

and update the maximum index whenever we find a larger value:

if nums[i] > nums[max_index]:
    max_index = i

After finding the maximum value, we create the root node:

root = TreeNode(nums[max_index])

Then we recursively construct both sides.

The left subtree uses everything before the maximum:

root.left = build(left, max_index - 1)

The right subtree uses everything after the maximum:

root.right = build(max_index + 1, right)

Finally, we return the constructed subtree root.
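Rescanning each range is what makes sorted inputs quadratic. A commonly used alternative, not the approach developed in this article, builds the same tree in one left-to-right pass with a monotonic stack that holds the current right spine of the tree. Each value is pushed and popped at most once, so the pass runs in O(n) time, and it uses no recursion. A sketch; `construct_with_stack` is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def construct_with_stack(nums: List[int]) -> Optional[TreeNode]:
    # The stack holds the current right spine; values decrease from bottom to top.
    stack: List[TreeNode] = []
    for x in nums:
        node = TreeNode(x)
        last = None
        # Smaller values just left of x belong in x's left subtree;
        # the last one popped is the largest of them, i.e. that subtree's root.
        while stack and stack[-1].val < x:
            last = stack.pop()
        node.left = last
        # For now x hangs as the right child of the nearest larger value on its left;
        # a later, larger value may pop x and re-parent it.
        if stack:
            stack[-1].right = node
        stack.append(node)
    # The bottom of the stack is the overall maximum, i.e. the root.
    return stack[0] if stack else None


root = construct_with_stack([3, 2, 1, 6, 0, 5])
print(root.val, root.left.val, root.right.val, root.right.left.val)  # 6 3 5 0
```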

Testing

To test tree problems locally, it is useful to serialize the tree.

def serialize(root):
    if root is None:
        return None

    return [
        root.val,
        serialize(root.left),
        serialize(root.right),
    ]

Example tests:

def run_tests():
    s = Solution()

    root = s.constructMaximumBinaryTree([3, 2, 1, 6, 0, 5])

    assert serialize(root) == [
        6,
        [
            3,
            None,
            [
                2,
                None,
                [
                    1,
                    None,
                    None,
                ],
            ],
        ],
        [
            5,
            [
                0,
                None,
                None,
            ],
            None,
        ],
    ]

    root = s.constructMaximumBinaryTree([1])

    assert serialize(root) == [
        1,
        None,
        None,
    ]

    root = s.constructMaximumBinaryTree([5, 4, 3])

    assert serialize(root) == [
        5,
        None,
        [
            4,
            None,
            [
                3,
                None,
                None,
            ],
        ],
    ]

    root = s.constructMaximumBinaryTree([1, 2, 3])

    assert serialize(root) == [
        3,
        [
            2,
            [
                1,
                None,
                None,
            ],
            None,
        ],
        None,
    ]

    print("all tests passed")


run_tests()

Test meaning:

| Test | Why |
|---|---|
| [3,2,1,6,0,5] | Standard example with both left and right recursion |
| [1] | Smallest valid tree |
| [5,4,3] | Descending order creates a right-skewed tree |
| [1,2,3] | Ascending order creates a left-skewed tree |