LeetCode 663: Equal Tree Partition

A clear explanation of checking whether a binary tree can be split into two equal-sum trees by removing one edge.

Problem Restatement

We are given the root of a binary tree.

We need to decide whether we can remove exactly one edge so that the tree is split into two smaller trees with equal sums.

Each tree sum is the sum of all node values inside that tree.

Return True if such an edge exists.

Otherwise, return False.

Input and Output

| Item | Meaning |
| --- | --- |
| Input | The root of a binary tree |
| Output | True if one edge can split the tree into equal sums |
| Operation | Remove exactly one edge |
| Required result | The two resulting tree sums must be equal |

Example function shape:

def checkEqualTree(root: Optional[TreeNode]) -> bool:
    ...

Examples

Consider this tree:

      5
     / \
    10  10
       /  \
      2    3

The total sum is:

5 + 10 + 10 + 2 + 3 = 30

If we remove the edge between the root 5 and its left child 10, the left subtree has sum 10.

The remaining tree has sum:

30 - 10 = 20

That does not work.

But if we remove the edge above the right child 10, that subtree has sum:

10 + 2 + 3 = 15

The remaining tree also has sum:

30 - 15 = 15

So the answer is:

True

Another example:

      1
     / \
    2   10
       /  \
      2    20

The total sum is:

35

Since the total sum is odd, it cannot be split into two equal integer sums.

So the answer is:

False

First Thought: Try Removing Every Edge

A direct solution is to try every edge.

For each edge:

  1. Remove the edge.
  2. Compute the sum of the detached subtree.
  3. Compute the sum of the remaining tree.
  4. Check whether the two sums are equal.

This works logically, but it repeats many subtree-sum computations.

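If the tree has n nodes, it has n - 1 edges, and summing the detached subtree from scratch for each edge can take O(n) time. In the worst case, such as a tree shaped like a long chain, that adds up to O(n^2) work.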
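To make the cost concrete, here is a minimal sketch of the try-every-edge idea. The names `tree_sum` and `check_equal_tree_brute` are illustrative, and the `TreeNode` class is a local stand-in for the one LeetCode provides.

```python
from typing import Optional


class TreeNode:
    # Local stand-in for LeetCode's TreeNode definition.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    """Sum of all values in the subtree rooted at node (0 for None)."""
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def check_equal_tree_brute(root: Optional[TreeNode]) -> bool:
    """Try every edge, recomputing the detached subtree's sum each time."""
    if root is None:
        return False

    total = tree_sum(root)
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is None:
                continue
            # Removing the edge node -> child detaches the subtree at child.
            detached = tree_sum(child)
            if detached == total - detached:
                return True
            stack.append(child)
    return False
```

Every call to `tree_sum(child)` walks the whole detached subtree again, which is exactly the repeated work the next section removes.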

Key Insight

Removing one edge separates exactly one subtree from the rest of the tree.

Suppose a detached subtree has sum:

s

The remaining tree has sum:

total - s

We want:

s = total - s

So:

2 * s = total

That means the target subtree sum must be:

total / 2

Therefore:

  1. Compute the total sum of the whole tree.
  2. If the total sum is odd, return False.
  3. Check whether there is a proper subtree whose sum is total // 2.

The word proper matters. We cannot remove an edge above the root, so the whole tree itself cannot be the selected subtree.
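A small sketch shows why excluding the whole tree matters when the total is 0. It works on plain lists of subtree sums rather than real trees; the helper name `half_sum_exists` and the convention that the whole-tree sum comes last are assumptions made for this illustration.

```python
from typing import List


def half_sum_exists(postorder_sums: List[int], exclude_root: bool) -> bool:
    """postorder_sums lists every subtree sum, with the whole-tree sum last."""
    total = postorder_sums[-1]
    if total % 2 != 0:
        return False
    candidates = postorder_sums[:-1] if exclude_root else postorder_sums
    return total // 2 in candidates


# Single node with value 0: the only subtree is the whole tree.
print(half_sum_exists([0], exclude_root=False))  # True, but there is no edge to remove
print(half_sum_exists([0], exclude_root=True))   # False, the correct answer

# Root 0 with one child 0: the child is a proper subtree with sum 0.
print(half_sum_exists([0, 0], exclude_root=True))  # True
```

Without the exclusion, any tree whose total is 0 would be accepted, because the whole tree trivially has sum total // 2.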

Subtree Sums

A subtree sum is naturally computed with postorder DFS.

For a node:

subtree_sum = node.val + left_sum + right_sum

We can compute every subtree sum and store it.

Then after computing the total sum, we check whether total // 2 appears among the subtree sums, excluding the whole tree sum.
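As an illustration, the sketch below collects the subtree sums of the first example tree in postorder. The function name `collect_subtree_sums` is hypothetical, and `TreeNode` is a local stand-in for LeetCode's class.

```python
from typing import List, Optional


class TreeNode:
    # Local stand-in for LeetCode's TreeNode definition.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def collect_subtree_sums(root: Optional[TreeNode]) -> List[int]:
    """Return every subtree sum in postorder; the whole-tree sum comes last."""
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        left_sum = dfs(node.left)
        right_sum = dfs(node.right)
        subtree_sum = node.val + left_sum + right_sum
        sums.append(subtree_sum)
        return subtree_sum

    dfs(root)
    return sums


#       5
#      / \
#     10  10
#        /  \
#       2    3
root = TreeNode(5, TreeNode(10), TreeNode(10, TreeNode(2), TreeNode(3)))
print(collect_subtree_sums(root))  # [10, 2, 3, 15, 30]
```

The total is 30, and 15 appears among the sums before the final entry, which matches the edge found by hand in the first example.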

Algorithm

Use DFS to compute subtree sums.

One clean implementation stores every subtree sum and checks the list afterwards:

  1. Run DFS and store every subtree sum in a list.
  2. The last computed sum is the total sum.
  3. Remove the total sum from consideration.
  4. If the total is even and total // 2 appears in the remaining sums, return True.
  5. Otherwise, return False.
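For comparison, here is a sketch of an alternative that does not store a list: it computes the total in a first pass, then checks each non-root subtree sum against the target in a second pass. This is one possible variant rather than necessarily the approach intended above; `check_equal_tree_two_pass` is an illustrative name and `TreeNode` is a local stand-in.

```python
from typing import Optional


class TreeNode:
    # Local stand-in for LeetCode's TreeNode definition.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_two_pass(root: Optional[TreeNode]) -> bool:
    if root is None:
        return False

    def tree_sum(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        return node.val + tree_sum(node.left) + tree_sum(node.right)

    # Pass 1: total sum of the whole tree.
    total = tree_sum(root)
    if total % 2 != 0:
        return False
    target = total // 2

    found = False

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal found
        if node is None:
            return 0
        subtotal = node.val + dfs(node.left) + dfs(node.right)
        # Only non-root subtrees can be detached by removing an edge.
        if node is not root and subtotal == target:
            found = True
        return subtotal

    # Pass 2: compare each proper subtree sum against the target.
    dfs(root)
    return found
```

This variant visits each node twice but needs only the recursion stack as extra space, whereas the list-based version visits each node once and keeps O(n) sums.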

Correctness

Every possible edge removal detaches exactly one subtree from the original tree.

For each non-root node, removing the edge from its parent detaches the subtree rooted at that node.

The DFS computes the sum of every subtree in the tree. Therefore, it computes the sum of every subtree that could be detached by removing one edge.

If removing an edge creates two trees with equal sums, then the detached subtree has sum total // 2. Since that detached subtree is rooted at some non-root node, its sum appears in the stored subtree sums after excluding the whole tree.

Conversely, if there is a non-root subtree with sum total // 2, then removing the edge above that subtree creates one tree with sum total // 2 and another tree with sum total - total // 2, which is also total // 2.

Thus, the algorithm returns True exactly when a valid equal partition exists.

Complexity

| Metric | Value | Why |
| --- | --- | --- |
| Time | O(n) | Each node is visited once |
| Space | O(n) | We store subtree sums and use recursion stack |

The recursion stack is O(h), where h is the tree height. In the worst case, h = n.
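When h approaches n, a recursive DFS in Python can exceed the interpreter's recursion limit (1000 by default). The sketch below is an iterative postorder variant that uses an explicit stack instead. The function name `check_equal_tree_iterative` is hypothetical, and `TreeNode` is a local stand-in for LeetCode's class.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    # Local stand-in for LeetCode's TreeNode definition.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_iterative(root: Optional[TreeNode]) -> bool:
    if root is None:
        return False

    # Maps each finished node to its subtree sum; None stands for an empty child.
    subtree_sum: Dict[Optional[TreeNode], int] = {None: 0}
    sums: List[int] = []  # subtree sums in postorder, whole-tree sum last

    # Each entry is (node, children_done).
    stack: List[Tuple[Optional[TreeNode], bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if node is None:
            continue
        if children_done:
            subtotal = node.val + subtree_sum[node.left] + subtree_sum[node.right]
            subtree_sum[node] = subtotal
            sums.append(subtotal)
        else:
            # Revisit the node after both children have been processed.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))

    total = sums.pop()  # drop the whole-tree sum; the root cannot be detached
    return total % 2 == 0 and total // 2 in sums
```

The asymptotic bounds are unchanged at O(n) time and O(n) space; the only difference is that the stack lives on the heap rather than in the call stack.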

Implementation

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def checkEqualTree(self, root: Optional[TreeNode]) -> bool:
        sums = []

        def dfs(node: Optional[TreeNode]) -> int:
            if node is None:
                return 0

            subtotal = (
                node.val
                + dfs(node.left)
                + dfs(node.right)
            )

            sums.append(subtotal)
            return subtotal

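        total = dfs(root)

        # The whole-tree sum is appended last; drop it, since no edge
        # detaches the root. The guard avoids an IndexError on an empty tree.
        if sums:
            sums.pop()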

        if total % 2 != 0:
            return False

        return total // 2 in sums

Code Explanation

We store all subtree sums in a list:

sums = []

The DFS returns the sum of the subtree rooted at node.

def dfs(node: Optional[TreeNode]) -> int:

For an empty child, the sum is 0:

if node is None:
    return 0

For a real node, we compute:

subtotal = (
    node.val
    + dfs(node.left)
    + dfs(node.right)
)

Then we store that subtree sum:

sums.append(subtotal)

After DFS finishes:

total = dfs(root)

the last value in sums is the sum of the whole tree.

We remove it:

sums.pop()

because removing one edge cannot detach the entire tree.

If the total sum is odd, two equal integer sums are impossible:

if total % 2 != 0:
    return False

Otherwise, we check whether a proper subtree has half the total sum:

return total // 2 in sums

Testing

def run_tests():
    s = Solution()

    # Tree:
    #       5
    #      / \
    #     10  10
    #        /  \
    #       2    3
    root = TreeNode(5)
    root.left = TreeNode(10)
    root.right = TreeNode(10, TreeNode(2), TreeNode(3))

    assert s.checkEqualTree(root) is True

    # Tree:
    #       1
    #      / \
    #     2   10
    #        /  \
    #       2    20
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(10, TreeNode(2), TreeNode(20))

    assert s.checkEqualTree(root) is False

    # Total is 0, and a proper subtree also has sum 0.
    root = TreeNode(0, TreeNode(0), None)

    assert s.checkEqualTree(root) is True

    # Single node cannot be split by removing an edge.
    root = TreeNode(0)

    assert s.checkEqualTree(root) is False

    # Negative values can still form an equal partition.
    root = TreeNode(1, TreeNode(-1), TreeNode(0))

    assert s.checkEqualTree(root) is True

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
| --- | --- |
| Standard valid split | Confirms a subtree with half the total is found |
| Odd total sum | Confirms quick rejection |
| Zero total with zero subtree | Checks the important zero-sum case |
| Single node | Cannot remove any edge |
| Negative values | Confirms the sum logic does not assume positivity |