# LeetCode 663: Equal Tree Partition

## Problem Restatement

We are given the root of a binary tree.

We need to decide whether we can remove exactly one edge so that the tree is split into two smaller trees with equal sums.

Each tree sum is the sum of all node values inside that tree.

Return `True` if such an edge exists.

Otherwise, return `False`.

## Input and Output

| Item | Meaning |
|---|---|
| Input | The root of a binary tree |
| Output | `True` if removing one edge splits the tree into two trees with equal sums, otherwise `False` |
| Operation | Remove exactly one edge |
| Required result | The two resulting tree sums must be equal |

Example function shape:

```python
def checkEqualTree(root: Optional[TreeNode]) -> bool:
    ...
```

## Examples

Consider this tree:

```text
      5
     / \
    10  10
       /  \
      2    3
```

The total sum is:

```text
5 + 10 + 10 + 2 + 3 = 30
```

If we remove the edge between the root `5` and its left child `10`, the left subtree has sum `10`.

The remaining tree has sum:

```text
30 - 10 = 20
```

The sums `10` and `20` differ, so this edge does not work.

But if we remove the edge above the right child `10`, that subtree has sum:

```text
10 + 2 + 3 = 15
```

The remaining tree also has sum:

```text
30 - 15 = 15
```

So the answer is:

```python
True
```

Another example:

```text
      1
     / \
    2   10
       /  \
      2    20
```

The total sum is:

```text
1 + 2 + 10 + 2 + 20 = 35
```

Since the total sum is odd, it cannot be split into two equal integer sums.

So the answer is:

```python
False
```

## First Thought: Try Removing Every Edge

A direct solution is to try every edge.

For each edge:

1. Remove the edge.
2. Compute the sum of the detached subtree.
3. Compute the sum of the remaining tree.
4. Check whether the two sums are equal.

This is logically correct, but it repeats many subtree-sum computations.

If the tree has `n` nodes, there are `n - 1` edges, and recomputing one subtree sum can take `O(n)` time. On a skewed tree the total cost grows to `O(n^2)`.
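For comparison only, here is a minimal sketch of this brute-force idea. The names `TreeNode`, `tree_sum`, and `check_equal_tree_brute` are illustrative assumptions; the sketch recomputes the detached subtree's sum from scratch for every edge, which is exactly the repeated work described above.

```python
from typing import Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    """Sum of all values in the tree rooted at node (0 if empty)."""
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def check_equal_tree_brute(root: Optional[TreeNode]) -> bool:
    """Try every parent-child edge and recompute the detached sum each time."""
    if root is None:
        return False

    total = tree_sum(root)
    stack = [root]
    while stack:
        parent = stack.pop()
        for child in (parent.left, parent.right):
            if child is None:
                continue
            detached = tree_sum(child)  # fresh traversal of this subtree
            if detached == total - detached:
                return True
            stack.append(child)
    return False


example = TreeNode(5, TreeNode(10), TreeNode(10, TreeNode(2), TreeNode(3)))
print(check_equal_tree_brute(example))  # True
```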

## Key Insight

Removing one edge separates exactly one subtree from the rest of the tree.

Suppose a detached subtree has sum:

```text
s
```

The remaining tree has sum:

```text
total - s
```

We want:

```text
s = total - s
```

So:

```text
2 * s = total
```

That means the target subtree sum must be:

```text
total / 2
```

Therefore:

1. Compute the total sum of the whole tree.
2. If the total sum is odd, return `False`.
3. Check whether there is a proper subtree whose sum is `total // 2`.

The word "proper" matters: we cannot remove an edge above the root, so the whole tree itself cannot be the selected subtree.
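The three steps can be sketched directly as a two-pass check. This is only one possible arrangement, and the names `TreeNode`, `tree_sum`, and `has_half_subtree` are hypothetical. Starting the search at the root's children is what restricts it to proper subtrees.

```python
from typing import Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)


def has_half_subtree(root: Optional[TreeNode]) -> bool:
    if root is None:
        return False

    # Step 1: total sum of the whole tree.
    total = tree_sum(root)

    # Step 2: an odd total cannot be halved.
    if total % 2 != 0:
        return False
    target = total // 2

    # Step 3: look for a proper subtree with the target sum.
    found = False

    def dfs(node: Optional[TreeNode]) -> int:
        nonlocal found
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        if s == target:
            found = True
        return s

    # Start below the root so the whole tree is never a candidate.
    dfs(root.left)
    dfs(root.right)
    return found
```

For a single node with value `0`, the total is `0` and the target is `0`, but no proper subtree exists, so the function returns `False`.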

## Subtree Sums

A subtree sum is naturally computed with postorder DFS.

For a node:

```text
subtree_sum = node.val + left_sum + right_sum
```

We can compute every subtree sum and store it.

Then after computing the total sum, we check whether `total // 2` appears among the subtree sums, excluding the whole tree sum.
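As a small illustration, the sketch below collects subtree sums in postorder for the first example tree. The helper name `subtree_sums` and the local `TreeNode` class are assumptions made for this example.

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def subtree_sums(root: Optional[TreeNode]) -> List[int]:
    """Return every subtree sum in postorder (children before parent)."""
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        sums.append(s)
        return s

    dfs(root)
    return sums


root = TreeNode(5, TreeNode(10), TreeNode(10, TreeNode(2), TreeNode(3)))
print(subtree_sums(root))  # [10, 2, 3, 15, 30]
```

The whole-tree sum `30` comes last because the root is finished last in postorder. Among the remaining sums, `15` equals `30 // 2`.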

## Algorithm

Use DFS to compute subtree sums.

A clean way to do this is:

1. Run DFS and store every subtree sum in a list.
2. The last computed sum is the total sum.
3. Remove the total sum from consideration.
4. If the total is even and `total // 2` appears in the remaining sums, return `True`.
5. Otherwise, return `False`.

## Correctness

Every possible edge removal detaches exactly one subtree from the original tree.

For each non-root node, removing the edge from its parent detaches the subtree rooted at that node.

The DFS computes the sum of every subtree in the tree. Therefore, it computes the sum of every subtree that could be detached by removing one edge.

If removing an edge creates two trees with equal sums, then `total` is even and the detached subtree has sum `total // 2`. Since that detached subtree is rooted at some non-root node, its sum appears in the stored subtree sums after excluding the whole tree.

Conversely, if `total` is even and there is a non-root subtree with sum `total // 2`, then removing the edge above that subtree creates one tree with sum `total // 2` and another tree with sum `total - total // 2`, which is also `total // 2` because `total` is even.

Thus, the algorithm returns `True` exactly when a valid equal partition exists.
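The argument can also be checked empirically. The sketch below compares the subtree-sum approach against a brute-force check on small random trees with negative, zero, and positive values. All names here (`fast`, `brute`, `random_tree`, `cross_check`) are hypothetical helpers, and passing trials are supporting evidence rather than a proof.

```python
import random
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def fast(root: TreeNode) -> bool:
    """Subtree-sum approach: look for total // 2 among proper subtrees."""
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        sums.append(s)
        return s

    total = dfs(root)
    sums.pop()  # exclude the whole tree
    return total % 2 == 0 and total // 2 in sums


def brute(root: TreeNode) -> bool:
    """Try every edge and recompute the detached sum from scratch."""
    def tree_sum(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        return node.val + tree_sum(node.left) + tree_sum(node.right)

    total = tree_sum(root)
    stack = [root]
    while stack:
        parent = stack.pop()
        for child in (parent.left, parent.right):
            if child is not None:
                if 2 * tree_sum(child) == total:
                    return True
                stack.append(child)
    return False


def random_tree(rng: random.Random, size: int) -> TreeNode:
    """Build a tree by attaching each new node to a random free child slot."""
    root = TreeNode(rng.randint(-3, 3))
    nodes = [root]
    for _ in range(size - 1):
        child = TreeNode(rng.randint(-3, 3))
        while True:
            parent = rng.choice(nodes)
            side = rng.choice(("left", "right"))
            if getattr(parent, side) is None:
                setattr(parent, side, child)
                break
        nodes.append(child)
    return root


def cross_check(trials: int = 500, seed: int = 0) -> bool:
    """Return True if both approaches agree on every random tree."""
    rng = random.Random(seed)
    for _ in range(trials):
        tree = random_tree(rng, rng.randint(1, 8))
        if fast(tree) != brute(tree):
            return False
    return True
```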

## Complexity

| Metric | Value | Why |
|---|---|---|
| Time | `O(n)` | Each node is visited once |
| Space | `O(n)` | We store up to `n` subtree sums and use the recursion stack |

The recursion stack is `O(h)`, where `h` is the tree height. In the worst case, a completely skewed tree, `h = n`.
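Because `h` can reach `n`, a recursive DFS in Python may exceed the interpreter's recursion limit (1000 frames by default in CPython) on a deep, skewed tree. One possible workaround, sketched here under the assumption that nodes are hashable by identity, is to compute the same subtree sums iteratively. The function name `check_equal_tree_iterative` is illustrative.

```python
from typing import Dict, List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def check_equal_tree_iterative(root: Optional[TreeNode]) -> bool:
    if root is None:
        return False

    # Preorder walk with an explicit stack: every parent appears in
    # `order` before its children, and order[0] is the root.
    order: List[TreeNode] = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    # Processing in reverse guarantees children are summed before parents.
    subtotal: Dict[TreeNode, int] = {}
    for node in reversed(order):
        s = node.val
        if node.left is not None:
            s += subtotal[node.left]
        if node.right is not None:
            s += subtotal[node.right]
        subtotal[node] = s

    total = subtotal[root]
    if total % 2 != 0:
        return False

    target = total // 2
    # Skip order[0], the root: the whole tree cannot be detached.
    return any(subtotal[node] == target for node in order[1:])
```

The time and space bounds stay `O(n)`; only the call stack is replaced by explicit lists and a dictionary.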

## Implementation

```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def checkEqualTree(self, root: Optional[TreeNode]) -> bool:
        sums = []

        def dfs(node: Optional[TreeNode]) -> int:
            if node is None:
                return 0

            subtotal = (
                node.val
                + dfs(node.left)
                + dfs(node.right)
            )

            sums.append(subtotal)
            return subtotal

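        total = dfs(root)

        # Postorder appends the root's sum last. Drop it: no edge sits above
        # the root, so the whole tree cannot be the detached part.
        # Note: pop() raises IndexError if root is None (empty tree).
        sums.pop()

        # An odd total cannot be split into two equal integer sums.
        if total % 2 != 0:
            return False

        return total // 2 in sums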
```

## Code Explanation

We store all subtree sums in a list:

```python
sums = []
```

The DFS returns the sum of the subtree rooted at `node`.

```python
def dfs(node: Optional[TreeNode]) -> int:
```

For an empty child, the sum is `0`:

```python
if node is None:
    return 0
```

For a real node, we compute:

```python
subtotal = (
    node.val
    + dfs(node.left)
    + dfs(node.right)
)
```

Then we store that subtree sum:

```python
sums.append(subtotal)
```

After DFS finishes:

```python
total = dfs(root)
```

the last value in `sums` is the sum of the whole tree.

We remove it:

```python
sums.pop()
```

because removing one edge cannot detach the entire tree.

If the total sum is odd, two equal integer sums are impossible:

```python
if total % 2 != 0:
    return False
```

Otherwise, we check whether a proper subtree has half the total sum:

```python
return total // 2 in sums
```
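To see why dropping the root's sum matters, the sketch below adds a hypothetical `exclude_root` flag to the same logic. The flag and the function name `has_half_sum` exist only for this demonstration and assume a non-empty tree. With the flag off, a single node with value `0` is wrongly accepted, because the whole tree's own sum `0` matches `0 // 2`.

```python
from typing import List, Optional


class TreeNode:
    # Minimal stand-in for LeetCode's node class.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def has_half_sum(root: TreeNode, exclude_root: bool = True) -> bool:
    """Same check as the solution, with the root exclusion made optional."""
    sums: List[int] = []

    def dfs(node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        s = node.val + dfs(node.left) + dfs(node.right)
        sums.append(s)
        return s

    total = dfs(root)
    if exclude_root:
        sums.pop()  # the root's sum is always appended last
    return total % 2 == 0 and total // 2 in sums


single = TreeNode(0)
print(has_half_sum(single))                      # False (correct: no edge exists)
print(has_half_sum(single, exclude_root=False))  # True (wrong: counts the whole tree)
```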

## Testing

```python
def run_tests():
    s = Solution()

    # Tree:
    #       5
    #      / \
    #     10  10
    #        /  \
    #       2    3
    root = TreeNode(5)
    root.left = TreeNode(10)
    root.right = TreeNode(10, TreeNode(2), TreeNode(3))

    assert s.checkEqualTree(root) is True

    # Tree:
    #       1
    #      / \
    #     2   10
    #        /  \
    #       2    20
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(10, TreeNode(2), TreeNode(20))

    assert s.checkEqualTree(root) is False

    # Total is 0, and a proper subtree also has sum 0.
    root = TreeNode(0, TreeNode(0), None)

    assert s.checkEqualTree(root) is True

    # Single node cannot be split by removing an edge.
    root = TreeNode(0)

    assert s.checkEqualTree(root) is False

    # Negative values can still form an equal partition.
    root = TreeNode(1, TreeNode(-1), TreeNode(0))

    assert s.checkEqualTree(root) is True

    print("all tests passed")

run_tests()
```

Test meaning:

| Test | Why |
|---|---|
| Standard valid split | Confirms a subtree with half the total is found |
| Odd total sum | Confirms quick rejection |
| Zero total with zero subtree | Confirms a zero total is matched by a proper subtree with sum `0` |
| Single node | No edge exists to remove, so the root's own sum must not count |
| Negative values | Confirms the sum logic does not assume positivity |

