
LeetCode 814: Binary Tree Pruning

A postorder DFS solution for removing every subtree of a binary tree that does not contain a 1.

Problem Restatement

We are given the root of a binary tree.

Every node has value either 0 or 1.

We need to remove every subtree that does not contain a 1.

A subtree means a node together with all of its descendants. If a subtree has only 0 values, that entire subtree should be removed.

Return the root of the pruned tree. The official problem phrases this as returning the same tree with every subtree that does not contain a 1 removed. If the whole tree contains no 1, the answer is an empty tree.

Input and Output

Item | Meaning
--- | ---
Input | Root of a binary tree
Output | Root of the pruned binary tree
Node values | Only 0 or 1
Remove condition | A subtree contains no 1

Examples

Example 1:

root = [1, None, 0, 0, 1]

The right child of the root is 0. That node has two children: a left child with value 0 and a right child with value 1.

The left child subtree is only 0, so it is removed.

The right child subtree contains 1, so it stays.

The result is:

[1, None, 0, None, 1]
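Example 1 can be reproduced by wiring the nodes by hand. This is a minimal sketch, assuming a bare TreeNode class like the one in the Testing section and a standalone prune_tree function (a hypothetical name that mirrors Solution.pruneTree below):

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def prune_tree(root):
    # Postorder: prune both children first, then decide on this node.
    if root is None:
        return None
    root.left = prune_tree(root.left)
    root.right = prune_tree(root.right)
    if root.val == 0 and root.left is None and root.right is None:
        return None
    return root

# Example 1: [1, None, 0, 0, 1]
root = TreeNode(1, None, TreeNode(0, TreeNode(0), TreeNode(1)))
pruned = prune_tree(root)

print(pruned.left)             # None
print(pruned.right.val)        # 0
print(pruned.right.left)       # None (the all-zero leaf was removed)
print(pruned.right.right.val)  # 1
```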

Example 2:

root = [1, 0, 1, 0, 0, 0, 1]

Every subtree containing only 0 is removed.

The result is:

[1, None, 1, None, 1]

First Thought: Check Every Subtree

A direct approach is to visit each node and check whether its whole subtree contains a 1.

If it does not, remove that subtree.

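This works, but it repeats work: the same subtree can be rescanned once for each of its ancestors. On a skewed tree this costs O(n^2) time.

We need a bottom-up method that decides each node using results already computed for its children.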
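For contrast, here is a minimal sketch of the check-every-subtree approach. The names contains_one and prune_naive are hypothetical, and the global visits counter exists only to measure the repeated work. The test tree is a right-skewed chain of 100 nodes whose only 1 sits at the bottom, so every subtree check has to scan all the way down.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

visits = 0  # number of non-empty nodes touched by contains_one

def contains_one(node):
    # Full rescan of the subtree rooted at node.
    global visits
    if node is None:
        return False
    visits += 1
    return node.val == 1 or contains_one(node.left) or contains_one(node.right)

def prune_naive(root):
    # Top-down: drop the whole subtree if it has no 1, otherwise recurse.
    if root is None or not contains_one(root):
        return None
    root.left = prune_naive(root.left)
    root.right = prune_naive(root.right)
    return root

n = 100
root = TreeNode(1)  # deepest node holds the only 1
for _ in range(n - 1):
    root = TreeNode(0, None, root)

prune_naive(root)
print(visits)  # 5050, i.e. n * (n + 1) / 2
```

The postorder solution below touches each of those 100 nodes once instead.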

Key Insight

A node should be removed only after we know what happened to its children.

So we use postorder DFS:

  1. Prune the left subtree.
  2. Prune the right subtree.
  3. Decide whether the current node should stay.

After both children are pruned, the current node should be removed exactly when:

root.val == 0
root.left is None
root.right is None

That condition means this subtree contains no 1.

Algorithm

Define a recursive function:

pruneTree(root)

If root is None, return None.

Otherwise:

  1. Recursively prune the left child.
  2. Recursively prune the right child.
  3. Assign the pruned children back to root.left and root.right.
  4. If root.val == 0 and both children are now None, return None.
  5. Otherwise, return root.
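The same children-before-parent idea can also be written without recursion, which sidesteps Python's default recursion limit on very deep trees. This is a minimal sketch and not the article's method: prune_tree_iterative and is_zero_leaf are hypothetical names. Instead of recursing, it records nodes with an explicit stack and then processes them in reverse, which guarantees every child is handled before its parent.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_zero_leaf(node):
    # A node with value 0 and no remaining children.
    return node.val == 0 and node.left is None and node.right is None

def prune_tree_iterative(root):
    if root is None:
        return None

    # Stack-based walk: each child is appended to `order` after its parent.
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)

    # Walking `order` backwards reaches every child before its parent, so a
    # child is already fully pruned when its parent decides to detach it.
    for node in reversed(order):
        if node.left is not None and is_zero_leaf(node.left):
            node.left = None
        if node.right is not None and is_zero_leaf(node.right):
            node.right = None

    # The root has no parent to detach it, so check it separately.
    return None if is_zero_leaf(root) else root

# Example 2: [1, 0, 1, 0, 0, 0, 1]
root = TreeNode(1,
                TreeNode(0, TreeNode(0), TreeNode(0)),
                TreeNode(1, TreeNode(0), TreeNode(1)))
pruned = prune_tree_iterative(root)
print(pruned.left is None, pruned.right.left is None, pruned.right.right.val)
# True True 1
```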

Correctness

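The algorithm processes children before their parent, so we can argue by induction on subtree height. Assume each recursive call returns its child's subtree with every all-zero subtree removed, or None if that child's whole subtree contains no 1.

For a leaf node, if its value is 0, it contains no 1, so the algorithm returns None. If its value is 1, it is kept.

For an internal node, the induction hypothesis says the recursive calls correctly prune its left and right subtrees. After that, any remaining child subtree contains at least one 1.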

If the current node has value 0 and both children are gone, then the whole subtree rooted at this node contains no 1, so it must be removed.

If the current node has value 1, the subtree contains a 1, so it must stay.

If the current node has value 0 but at least one child remains, then some descendant contains a 1, so this subtree must stay.

Therefore, each subtree is removed exactly when it contains no 1.
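The argument can also be spot-checked empirically. This is a minimal sketch, assuming randomly generated trees; random_tree, all_nodes, contains_one, and check are hypothetical helpers. Before pruning, it records exactly which nodes have a 1 somewhere in their subtree. The claim above says those nodes, and only those, should survive.

```python
import random

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def prune_tree(root):
    # The postorder algorithm from this article, as a free function.
    if root is None:
        return None
    root.left = prune_tree(root.left)
    root.right = prune_tree(root.right)
    if root.val == 0 and root.left is None and root.right is None:
        return None
    return root

def contains_one(node):
    # Brute-force reference: scan the whole subtree.
    if node is None:
        return False
    return node.val == 1 or contains_one(node.left) or contains_one(node.right)

def all_nodes(root):
    # Every node reachable from root, collected with an explicit stack.
    found = []
    stack = [] if root is None else [root]
    while stack:
        node = stack.pop()
        found.append(node)
        for child in (node.left, node.right):
            if child is not None:
                stack.append(child)
    return found

def random_tree(rng, depth):
    # Random 0/1 tree; each position is left empty with probability 0.3.
    if depth == 0 or rng.random() < 0.3:
        return None
    return TreeNode(rng.randint(0, 1),
                    random_tree(rng, depth - 1),
                    random_tree(rng, depth - 1))

def check(trials=500, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng, 6)
        before = all_nodes(root)  # keeps every original node alive, so ids stay unique
        expected = {id(n) for n in before if contains_one(n)}
        actual = {id(n) for n in all_nodes(prune_tree(root))}
        assert actual == expected
    return True

print(check())  # True
```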

Complexity

Let n be the number of nodes.

Metric | Value | Why
--- | --- | ---
Time | O(n) | Each node is visited once
Space | O(h) | Recursion stack height is the tree height

In the worst case, h = n for a skewed tree. For a balanced tree, h = O(log n).
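To see the O(h) space bound concretely, here is a minimal sketch that instruments the same recursion with a depth counter. The helpers prune_with_depth, skewed_chain, perfect, and max_stack_depth are hypothetical. Both trees below have 255 nodes. The counter also includes the calls made on empty children, so it reports h + 1, where h is the number of nodes on the longest root-to-leaf path.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def prune_with_depth(root, depth, stats):
    # Same pruning logic, plus tracking of the deepest active call.
    stats["max_depth"] = max(stats["max_depth"], depth)
    if root is None:
        return None
    root.left = prune_with_depth(root.left, depth + 1, stats)
    root.right = prune_with_depth(root.right, depth + 1, stats)
    if root.val == 0 and root.left is None and root.right is None:
        return None
    return root

def skewed_chain(n):
    # Right-skewed chain of n nodes, all 1, so nothing gets pruned.
    root = None
    for _ in range(n):
        root = TreeNode(1, None, root)
    return root

def perfect(levels):
    # Perfect tree with 2**levels - 1 nodes, all 1.
    if levels == 0:
        return None
    return TreeNode(1, perfect(levels - 1), perfect(levels - 1))

def max_stack_depth(root):
    stats = {"max_depth": 0}
    prune_with_depth(root, 1, stats)
    return stats["max_depth"]

print(max_stack_depth(skewed_chain(255)))  # 256
print(max_stack_depth(perfect(8)))         # 9
```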

Implementation

from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        if root is None:
            return None

        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)

        if root.val == 0 and root.left is None and root.right is None:
            return None

        return root
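The implementation above mutates the input tree in place, which matches the problem's request to return the same tree. If a caller instead needs the original tree left intact, a copying variant is a small change. This is a sketch under that assumption; prune_copy is a hypothetical name.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def prune_copy(root):
    # Builds fresh nodes for the kept part; the input tree is never modified.
    if root is None:
        return None
    left = prune_copy(root.left)
    right = prune_copy(root.right)
    if root.val == 0 and left is None and right is None:
        return None
    return TreeNode(root.val, left, right)

root = TreeNode(1, None, TreeNode(0, TreeNode(0), TreeNode(1)))
pruned = prune_copy(root)
print(root.right.left.val)  # 0 (original untouched)
print(pruned.right.left)    # None
```

The trade-off is O(n) extra memory for the new nodes in exchange for leaving the input unchanged.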

Code Explanation

The base case handles an empty subtree:

if root is None:
    return None

Then we prune the children first:

root.left = self.pruneTree(root.left)
root.right = self.pruneTree(root.right)

This is postorder traversal. The parent is handled after its descendants.

After pruning, we check whether the current node has become a zero leaf:

if root.val == 0 and root.left is None and root.right is None:
    return None

A zero leaf has no 1 in its subtree, so it is removed.

Otherwise, the subtree contains a 1, so we keep it:

return root
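An equivalent formulation, common in other write-ups but not the one used above, has a helper report whether a subtree contains a 1 and detach any child whose subtree does not. This is a minimal sketch; has_one is a hypothetical helper name. Both children must be evaluated before combining the results, because short-circuiting with `or` would skip pruning the right side.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def has_one(node):
    # True if the subtree at node contains a 1; detaches all-zero children.
    if node is None:
        return False
    left_has = has_one(node.left)    # evaluate both sides first
    right_has = has_one(node.right)
    if not left_has:
        node.left = None
    if not right_has:
        node.right = None
    return node.val == 1 or left_has or right_has

def prune_tree(root):
    return root if has_one(root) else None

# Example 1: [1, None, 0, 0, 1]
root = TreeNode(1, None, TreeNode(0, TreeNode(0), TreeNode(1)))
pruned = prune_tree(root)
print(pruned.right.left, pruned.right.right.val)  # None 1
```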

Testing

from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values):
    if not values or values[0] is None:
        return None

    nodes = [None if v is None else TreeNode(v) for v in values]
    child = 1

    for node in nodes:
        if node is None:
            continue

        if child < len(nodes):
            node.left = nodes[child]
            child += 1

        if child < len(nodes):
            node.right = nodes[child]
            child += 1

    return nodes[0]

def serialize(root):
    if root is None:
        return []

    queue = [root]
    result = []

    while queue:
        node = queue.pop(0)

        if node is None:
            result.append(None)
            continue

        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)

    while result and result[-1] is None:
        result.pop()

    return result

def run_tests():
    s = Solution()

    root = build_tree([1, None, 0, 0, 1])
    assert serialize(s.pruneTree(root)) == [1, None, 0, None, 1]

    root = build_tree([1, 0, 1, 0, 0, 0, 1])
    assert serialize(s.pruneTree(root)) == [1, None, 1, None, 1]

    root = build_tree([0])
    assert serialize(s.pruneTree(root)) == []

    root = build_tree([1])
    assert serialize(s.pruneTree(root)) == [1]

    root = build_tree([0, 0, 0])
    assert serialize(s.pruneTree(root)) == []

    print("all tests passed")

run_tests()

Test | Why
--- | ---
[1,None,0,0,1] | Removes only the all-zero child subtree
[1,0,1,0,0,0,1] | Removes several zero-only subtrees
[0] | Entire tree is removed
[1] | Single node containing 1 stays
[0,0,0] | All-zero tree is removed