A postorder DFS solution for removing every binary tree subtree that does not contain a 1.
Problem Restatement
We are given the root of a binary tree.
Every node has value either 0 or 1.
We need to remove every subtree that does not contain a 1.
A subtree means a node together with all of its descendants. If a subtree has only 0 values, that entire subtree should be removed.
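Return the root of the pruned tree. The official problem asks us to return the same tree with those subtrees removed, so the solution detaches nodes in place rather than building a copy. If no node contains a 1, the whole tree is removed and the answer is None.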
Input and Output
| Item | Meaning |
|---|---|
| Input | Root of a binary tree |
| Output | Root of the pruned binary tree |
| Node values | Only 0 or 1 |
| Remove condition | A subtree contains no 1 |
Examples
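Example 1:
```python
root = [1, None, 0, 0, 1]
```
The right child of the root is 0.
That node has two children: a left child 0 and a right child 1.
The left child subtree is only 0, so it is removed.
The right child subtree contains 1, so it stays.
The result is:
```python
[1, None, 0, None, 1]
```
Example 2:
```python
root = [1, 0, 1, 0, 0, 0, 1]
```
Every subtree containing only 0 is removed.
The root's left child is 0 and both of its children are 0, so that whole subtree is removed.
The root's right child is 1. It loses its left child 0 and keeps its right child 1.
The result is:
```python
[1, None, 1, None, 1]
```
First Thought: Check Every Subtree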
A direct approach is to visit each node and check whether its whole subtree contains a 1.
If it does not, remove that subtree.
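This works, but it repeats work. A subtree can be checked many times from different ancestors: in the worst case every node is rescanned by the check at each of its ancestors, which costs O(n·h) time for a tree with n nodes and height h, or O(n^2) for a skewed tree.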
We need a bottom-up method.
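
For contrast, here is a minimal sketch of that direct approach, assuming a plain TreeNode class. The names contains_one and prune_top_down are hypothetical and used only for this illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def contains_one(node: Optional[TreeNode]) -> bool:
    """Scan the subtree rooted at node; may visit every node in it."""
    if node is None:
        return False
    return node.val == 1 or contains_one(node.left) or contains_one(node.right)


def prune_top_down(node: Optional[TreeNode]) -> Optional[TreeNode]:
    """Naive top-down pruning: each call rescans its whole subtree."""
    if node is None or not contains_one(node):
        return None
    # This subtree contains a 1, so the node stays; recurse into children,
    # which will rescan their own subtrees again.
    node.left = prune_top_down(node.left)
    node.right = prune_top_down(node.right)
    return node
```

The result is correct, but a chain of zeros with a single 1 at the bottom is rescanned from every ancestor, which is where the quadratic worst case comes from.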
Key Insight
A node should be removed only after we know what happened to its children.
So we use postorder DFS:
- Prune the left subtree.
- Prune the right subtree.
- Decide whether the current node should stay.
After both children are pruned, the current node should be removed exactly when all three conditions hold:
```python
root.val == 0
root.left is None
root.right is None
```
Together, these conditions mean this subtree contains no 1.
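
To see the postorder order on Example 1, the sketch below records each keep/remove decision as it is made. It assumes a minimal TreeNode class; prune_with_trace and its path labels ("L"/"R" steps from the root) are hypothetical and exist only for this trace.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_with_trace(node: Optional[TreeNode], path: str,
                     log: List[Tuple[str, int, str]]) -> Optional[TreeNode]:
    """Postorder prune that appends (path, value, decision) to log."""
    if node is None:
        return None
    # Children are decided first...
    node.left = prune_with_trace(node.left, path + "L", log)
    node.right = prune_with_trace(node.right, path + "R", log)
    # ...so this test sees the already-pruned children.
    removed = node.val == 0 and node.left is None and node.right is None
    log.append((path or "root", node.val, "remove" if removed else "keep"))
    return None if removed else node


# Example 1: [1, None, 0, 0, 1]
root = TreeNode(1, None, TreeNode(0, TreeNode(0), TreeNode(1)))
log: List[Tuple[str, int, str]] = []
prune_with_trace(root, "", log)
for entry in log:
    print(entry)
```

This prints ('RL', 0, 'remove'), ('RR', 1, 'keep'), ('R', 0, 'keep'), ('root', 1, 'keep') in that order: the right child 0 is decided only after its children, and it survives because its right child 1 survived.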
Algorithm
Define a recursive function:
```python
pruneTree(root)
```
If root is None, return None.
Otherwise:
- Recursively prune the left child.
- Recursively prune the right child.
- Assign the pruned children back to root.left and root.right.
- If root.val == 0 and both children are now None, return None.
- Otherwise, return root.
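
The same postorder rule can also be phrased with a helper that reports whether a subtree still contains a 1 and lets the parent detach children that do not. This is an equivalent reformulation, not the solution used in this article; prune_and_report and prune_tree are hypothetical names, and TreeNode is a minimal stand-in class.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_and_report(node: Optional[TreeNode]) -> bool:
    """Prune node's children in place; return True if node's subtree has a 1.

    The caller is responsible for detaching node when this returns False.
    """
    if node is None:
        return False
    left_has_one = prune_and_report(node.left)
    right_has_one = prune_and_report(node.right)
    # Detach child subtrees that turned out to contain no 1.
    if not left_has_one:
        node.left = None
    if not right_has_one:
        node.right = None
    return node.val == 1 or left_has_one or right_has_one


def prune_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # The root has no parent to detach it, so handle it here.
    return root if prune_and_report(root) else None
```

The version in this article folds the flag into the return value (None versus the node), which removes the need for this separate wrapper.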
Correctness
The algorithm processes children before their parent.
For a leaf node, if its value is 0, it contains no 1, so the algorithm returns None. If its value is 1, it is kept.
For an internal node, assume by induction on subtree height that the recursive calls correctly prune its left and right subtrees. After that, any remaining child subtree contains at least one 1.
If the current node has value 0 and both children are gone, then the whole subtree rooted at this node contains no 1, so it must be removed.
If the current node has value 1, the subtree contains a 1, so it must stay.
If the current node has value 0 but at least one child remains, then some descendant contains a 1, so this subtree must stay.
Therefore, each subtree is removed exactly when it contains no 1.
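
The argument can also be spot-checked empirically, although a random test is evidence rather than proof. The sketch below compares the postorder rule (restated as a free function prune) against expected_shape, a direct and deliberately slow transcription of the definition. The helpers random_tree, expected_shape, shape and check, the depth limit of 6, and the 0.3 empty-position probability are all arbitrary assumptions made for this check.

```python
import random
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune(root: Optional[TreeNode]) -> Optional[TreeNode]:
    # Same postorder rule as the solution in this article.
    if root is None:
        return None
    root.left = prune(root.left)
    root.right = prune(root.right)
    if root.val == 0 and root.left is None and root.right is None:
        return None
    return root


def random_tree(depth: int, rng: random.Random) -> Optional[TreeNode]:
    # Each position is empty with probability 0.3 (arbitrary choice).
    if depth == 0 or rng.random() < 0.3:
        return None
    return TreeNode(rng.randint(0, 1),
                    random_tree(depth - 1, rng),
                    random_tree(depth - 1, rng))


def contains_one(node: Optional[TreeNode]) -> bool:
    if node is None:
        return False
    return node.val == 1 or contains_one(node.left) or contains_one(node.right)


def expected_shape(node: Optional[TreeNode]):
    """Specification: drop every subtree with no 1, straight from the definition."""
    if node is None or not contains_one(node):
        return None
    return (node.val, expected_shape(node.left), expected_shape(node.right))


def shape(node: Optional[TreeNode]):
    if node is None:
        return None
    return (node.val, shape(node.left), shape(node.right))


def check(trials: int = 500, seed: int = 0) -> int:
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(6, rng)
        want = expected_shape(root)  # computed before prune mutates the tree
        assert shape(prune(root)) == want
    return trials


print(check())  # prints 500
```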
Complexity
Let n be the number of nodes.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each node is visited once |
| Space | O(h) | Recursion stack height is the tree height |
In the worst case, h = n for a skewed tree; for a balanced tree, h = O(log n).
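
Because the recursive version uses the call stack, a very deep skewed tree can exceed CPython's default recursion limit (1000). One option, sketched below as an alternative to the article's recursive solution, is the same postorder rule with an explicit stack. prune_iterative and is_zero_leaf are hypothetical names; the extra space is still O(h), but it lives on the heap.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_zero_leaf(node: Optional[TreeNode]) -> bool:
    return (node is not None and node.val == 0
            and node.left is None and node.right is None)


def prune_iterative(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Postorder pruning with an explicit stack instead of recursion."""
    if root is None:
        return None
    # Each node is pushed twice: once to schedule its children,
    # and once more (children_done=True) to finish it after them.
    stack: List[Tuple[TreeNode, bool]] = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            # Both children are finished: each is either pruned down to a
            # subtree containing a 1, or has become a zero leaf to detach.
            if is_zero_leaf(node.left):
                node.left = None
            if is_zero_leaf(node.right):
                node.right = None
    # The root has no parent to detach it.
    return None if is_zero_leaf(root) else root


# A left-skewed chain: 9,999 zeros above a single 1 (depth 10,000).
node = TreeNode(1)
for _ in range(9_999):
    node = TreeNode(0, node)
result = prune_iterative(node)
length = 0
cur = result
while cur is not None:
    length += 1
    cur = cur.left
print(length)  # prints 10000: every 0 on the chain has the 1 below it
```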
Implementation
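```python
from typing import Optional

# Definition for a binary tree node (provided by the judge;
# uncomment it to run this block on its own).
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def pruneTree(self, root: Optional[TreeNode]) -> Optional[TreeNode]:
        if root is None:
            return None
        # Postorder: prune both subtrees before deciding about this node.
        root.left = self.pruneTree(root.left)
        root.right = self.pruneTree(root.right)
        # A 0 node with no surviving children heads a subtree with no 1.
        if root.val == 0 and root.left is None and root.right is None:
            return None
        return root
```
Code Explanation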
The base case handles an empty subtree:
```python
if root is None:
    return None
```
Then we prune the children first:
```python
root.left = self.pruneTree(root.left)
root.right = self.pruneTree(root.right)
```
This is postorder traversal: the parent is handled after its descendants.
After pruning, we check whether the current node has become a zero leaf:
```python
if root.val == 0 and root.left is None and root.right is None:
    return None
```
A zero leaf has no 1 in its subtree, so it is removed.
Otherwise, the subtree contains a 1, so we keep it:
```python
return root
```
Testing
```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Paste the Solution class from the Implementation section here, after TreeNode.


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from level-order values; None marks a missing child."""
    if not values or values[0] is None:
        return None
    nodes = [None if v is None else TreeNode(v) for v in values]
    child = 1
    for node in nodes:
        # Missing nodes have no children in this encoding, so skip them.
        if node is None:
            continue
        if child < len(nodes):
            node.left = nodes[child]
            child += 1
        if child < len(nodes):
            node.right = nodes[child]
            child += 1
    return nodes[0]


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Level-order values in the same encoding, with trailing Nones trimmed."""
    if root is None:
        return []
    queue = deque([root])
    result = []
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


def run_tests():
    s = Solution()
    root = build_tree([1, None, 0, 0, 1])
    assert serialize(s.pruneTree(root)) == [1, None, 0, None, 1]
    root = build_tree([1, 0, 1, 0, 0, 0, 1])
    assert serialize(s.pruneTree(root)) == [1, None, 1, None, 1]
    root = build_tree([0])
    assert serialize(s.pruneTree(root)) == []
    root = build_tree([1])
    assert serialize(s.pruneTree(root)) == [1]
    root = build_tree([0, 0, 0])
    assert serialize(s.pruneTree(root)) == []
    print("all tests passed")


run_tests()
```

| Test | Why |
|---|---|
| [1,None,0,0,1] | Removes only the all-zero child subtree |
| [1,0,1,0,0,0,1] | Removes several zero-only subtrees |
| [0] | A single 0 node is removed, leaving an empty tree |
| [1] | A single node containing 1 stays |
| [0,0,0] | An all-zero tree with children is removed entirely |