A clear explanation of flattening a binary tree into a linked list in preorder traversal order using recursive depth-first search.
Problem Restatement
We are given the root of a binary tree.
We need to flatten the tree into a linked list in-place.
The linked list must follow the same order as preorder traversal:
```
root -> left -> right
```

After flattening:

- Every node’s left child must become `None`.
- Every node’s right child points to the next node in preorder order.
The official problem specifically requires the flattened structure to use the right pointers like a linked list. (leetcode.com)
For this tree:

```
    1
   / \
  2   5
 / \   \
3   4   6
```

The preorder traversal is:

```
1 -> 2 -> 3 -> 4 -> 5 -> 6
```

So the flattened tree becomes:

```
1
 \
  2
   \
    3
     \
      4
       \
        5
         \
          6
```

Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root node of a binary tree |
| Output | No return value |
| Modification | The tree must be changed in-place |
| Final structure | Uses only right pointers |
| Traversal order | Must follow preorder traversal |
LeetCode gives the TreeNode class:
```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
```

The function shape is:

```python
class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        ...
```

Examples
Consider:

```
    1
   / \
  2   5
 / \   \
3   4   6
```

The preorder traversal order is:

```
1 -> 2 -> 3 -> 4 -> 5 -> 6
```

So after flattening:

```
1
 \
  2
   \
    3
     \
      4
       \
        5
         \
          6
```

Every left pointer becomes `None`.

For a single-node tree:

```
1
```

The tree already satisfies the required structure.

For an empty tree:

```python
root = None
```

Nothing needs to be changed.
First Thought: Build the Preorder Traversal
The required linked list order exactly matches preorder traversal:

```
root -> left -> right
```

So one idea is:

- Store all nodes in preorder order inside a list.
- Connect each node’s `right` pointer to the next node.
- Set every `left` pointer to `None`.

This works, but the node list costs O(n) extra memory. The problem asks us to rewire the tree in-place, so we would rather avoid the auxiliary list.
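For comparison, here is a minimal sketch of the list-based idea, assuming the LeetCode-style `TreeNode` class. The names `flatten_with_list` and `right_chain_values` are hypothetical helpers for illustration only. The key detail is that every node is recorded before any pointer changes, so the rewiring step cannot lose part of the tree.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_with_list(root: Optional[TreeNode]) -> None:
    # Step 1: record every node in preorder (root, left subtree, right subtree).
    nodes: List[TreeNode] = []

    def preorder(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        nodes.append(node)
        preorder(node.left)
        preorder(node.right)

    preorder(root)

    # Step 2: rewire. All nodes were recorded before any pointer changed,
    # so overwriting left/right here cannot lose part of the tree.
    for i, node in enumerate(nodes):
        node.left = None
        node.right = nodes[i + 1] if i + 1 < len(nodes) else None


def right_chain_values(node: Optional[TreeNode]) -> List[int]:
    # Read the flattened list by following right pointers.
    values = []
    while node is not None:
        values.append(node.val)
        node = node.right
    return values


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
flatten_with_list(root)
print(right_chain_values(root))  # [1, 2, 3, 4, 5, 6]
```

The `nodes` list grows to n entries, which is exactly the extra memory the in-place approach avoids.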
Key Insight
For every node:
- Flatten the left subtree.
- Flatten the right subtree.
- Insert the flattened left subtree between the current node and the flattened right subtree.
Suppose we have:

```
  1
 / \
2   5
```

After flattening the subtrees:

```
2 -> 3 -> 4
5 -> 6
```

We rearrange the pointers into:

```
1 -> 2 -> 3 -> 4 -> 5 -> 6
```

To do this:

- Save the original right subtree.
- Move the left subtree to the right.
- Set the left pointer to `None`.
- Find the tail of the new right chain.
- Attach the original right subtree at the tail.
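The five pointer steps can be tried in isolation on the example above. This sketch assumes both children of node 1 are already flattened chains (`2 -> 3 -> 4` and `5 -> 6`); `chain`, `splice_left_into_right`, and `right_chain_values` are hypothetical helper names used only here.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def chain(*values: int) -> Optional[TreeNode]:
    # Build a right-pointer chain, e.g. chain(2, 3, 4) gives 2 -> 3 -> 4.
    head: Optional[TreeNode] = None
    for v in reversed(values):
        head = TreeNode(v, None, head)
    return head


def splice_left_into_right(node: TreeNode) -> None:
    # Assumes node.left and node.right are already flattened chains.
    original_right = node.right   # 1. save the original right subtree
    node.right = node.left        # 2. move the left chain to the right
    node.left = None              # 3. clear the left pointer
    tail = node                   # 4. walk to the tail of the new right chain
    while tail.right is not None:
        tail = tail.right
    tail.right = original_right   # 5. attach the original right chain


def right_chain_values(node: Optional[TreeNode]) -> List[int]:
    values = []
    while node is not None:
        values.append(node.val)
        node = node.right
    return values


root = TreeNode(1, chain(2, 3, 4), chain(5, 6))
splice_left_into_right(root)
print(right_chain_values(root))  # [1, 2, 3, 4, 5, 6]
```

Note that the tail walk starts at `node` itself, so the splice still works when the left chain is empty.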
Algorithm
For each node:

- Recursively flatten the left subtree.
- Recursively flatten the right subtree.
- Save the original right subtree.
- Move the left subtree to the right side.
- Set `left = None`.
- Find the rightmost node of the moved subtree.
- Attach the original right subtree there.
If the current node has no left subtree, nothing needs to be rearranged.
Correctness
The recursive calls flatten the left and right subtrees into preorder-linked structures.
For the current node, preorder traversal requires this order:

```
current node
left subtree
right subtree
```

After the recursive calls:
- The left subtree is already flattened in preorder order.
- The right subtree is already flattened in preorder order.
The algorithm moves the flattened left subtree to the current node’s right pointer. Then it attaches the flattened right subtree at the end of that chain.
Therefore, the resulting structure follows:

```
current node
then flattened left subtree
then flattened right subtree
```

which exactly matches preorder traversal.
The algorithm also sets every left pointer to None, satisfying the linked-list requirement.
Since the process is applied recursively to every node, the entire tree becomes a valid flattened preorder linked list.
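The argument can also be spot-checked empirically. The sketch below records the preorder of random trees, flattens them with the same recursive algorithm (written as a plain function rather than a `Solution` method), and checks that the right-pointer chain reproduces that preorder with every `left` pointer cleared. The helpers `preorder_values`, `random_tree`, and `check` are hypothetical, and a passing run is evidence, not a proof.

```python
import random
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten(root: Optional[TreeNode]) -> None:
    # The same recursive algorithm as in the Implementation section.
    if root is None:
        return
    flatten(root.left)
    flatten(root.right)
    left_subtree, right_subtree = root.left, root.right
    root.left = None
    root.right = left_subtree
    current = root
    while current.right is not None:
        current = current.right
    current.right = right_subtree


def preorder_values(node: Optional[TreeNode]) -> List[int]:
    # Reference preorder: root, then left subtree, then right subtree.
    if node is None:
        return []
    return [node.val] + preorder_values(node.left) + preorder_values(node.right)


def random_tree(size: int, rng: random.Random) -> Optional[TreeNode]:
    # Random shape: split the remaining nodes between the two subtrees.
    if size == 0:
        return None
    left_size = rng.randint(0, size - 1)
    return TreeNode(
        rng.randint(0, 99),
        random_tree(left_size, rng),
        random_tree(size - 1 - left_size, rng),
    )


def check(trials: int = 200, seed: int = 0) -> int:
    # Flatten random trees and compare the right-pointer chain to the
    # preorder recorded before flattening. Returns the number of trials run.
    rng = random.Random(seed)
    for _ in range(trials):
        root = random_tree(rng.randint(0, 30), rng)
        expected = preorder_values(root)
        flatten(root)
        actual = []
        node = root
        while node is not None:
            assert node.left is None
            actual.append(node.val)
            node = node.right
        assert actual == expected
    return trials


print(check())
```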
Complexity
| Metric | Value | Why |
|---|---|---|
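| Time | O(n^2) worst case, O(n log n) for a balanced tree | Each node is visited once, but the tail scan at a node walks the whole flattened left subtree, so the total scan work is the sum of all left-subtree sizes |
| Space | O(h) | Recursion stack depth, where h can reach n in a skewed tree |

Here:

- `n` is the number of nodes.
- `h` is the tree height.

The worst case occurs when the tree is left-skewed: the node at depth d rescans a left chain of n - 1 - d nodes, giving (n - 1) + (n - 2) + ... + 0 = n(n - 1)/2 steps. A right-skewed tree has no left subtrees to scan, so it runs in O(n).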
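To see where the quadratic term comes from, this sketch instruments the same algorithm to count how many times the tail-scan loop advances. `flatten_counting`, `left_skewed`, and `right_skewed` are hypothetical helpers for illustration; the sizes are kept small so the recursion stays well under Python's default recursion limit.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_counting(root: Optional[TreeNode]) -> int:
    # Same algorithm as the Implementation section, but returns how many
    # times the tail-scan loop advances (current = current.right).
    if root is None:
        return 0
    steps = flatten_counting(root.left) + flatten_counting(root.right)
    left_subtree, right_subtree = root.left, root.right
    root.left = None
    root.right = left_subtree
    current = root
    while current.right is not None:
        current = current.right
        steps += 1
    current.right = right_subtree
    return steps


def left_skewed(n: int) -> Optional[TreeNode]:
    # 1 -> 2 -> ... -> n linked through left pointers.
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, root, None)
    return root


def right_skewed(n: int) -> Optional[TreeNode]:
    # 1 -> 2 -> ... -> n linked through right pointers.
    root = None
    for v in range(n, 0, -1):
        root = TreeNode(v, None, root)
    return root


for n in (100, 200):
    print(n, flatten_counting(left_skewed(n)), flatten_counting(right_skewed(n)))
# 100 4950 0
# 200 19900 0
```

Doubling n roughly quadruples the left-skewed count, matching n(n - 1)/2, while the right-skewed tree never scans.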
Implementation
```python
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right


class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        if root is None:
            return

        # Flatten both subtrees first.
        self.flatten(root.left)
        self.flatten(root.right)

        # Move the flattened left chain to the right side.
        left_subtree = root.left
        right_subtree = root.right
        root.left = None
        root.right = left_subtree

        # Walk to the tail of the chain and attach the old right chain there.
        current = root
        while current.right is not None:
            current = current.right
        current.right = right_subtree
```

Code Explanation
Handle the empty tree:

```python
if root is None:
    return
```

Flatten the left subtree first:

```python
self.flatten(root.left)
```

Flatten the right subtree:

```python
self.flatten(root.right)
```

Save both subtrees before changing pointers:

```python
left_subtree = root.left
right_subtree = root.right
```

Move the left subtree to the right:

```python
root.left = None
root.right = left_subtree
```

Now the flattened left subtree appears immediately after the current node.

Find the tail of this chain. If the left subtree was empty, the loop does not run and `current` stays at `root`:

```python
current = root
while current.right is not None:
    current = current.right
```

Attach the original right subtree:

```python
current.right = right_subtree
```

The final structure follows preorder order.
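The tail scan walks the entire flattened left subtree at every node, which is what makes the worst case quadratic. A common variation, not the method used in this article, has the recursion return the tail of each flattened subtree so no rescanning is needed. This is a sketch under that assumption; `flatten_and_get_tail` and `right_chain_values` are hypothetical helper names.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def flatten_and_get_tail(node: Optional[TreeNode]) -> Optional[TreeNode]:
    # Flattens the subtree rooted at `node` in-place and returns its last
    # node in preorder, or None for an empty subtree.
    if node is None:
        return None
    left_tail = flatten_and_get_tail(node.left)
    right_tail = flatten_and_get_tail(node.right)
    if left_tail is not None:
        # Splice: node -> flattened left chain -> flattened right chain.
        left_tail.right = node.right
        node.right = node.left
        node.left = None
    # The last node is the right tail if the right side is non-empty,
    # otherwise the left tail, otherwise `node` itself.
    if right_tail is not None:
        return right_tail
    if left_tail is not None:
        return left_tail
    return node


class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        flatten_and_get_tail(root)


def right_chain_values(node: Optional[TreeNode]) -> List[int]:
    values = []
    while node is not None:
        values.append(node.val)
        node = node.right
    return values


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)), TreeNode(5, None, TreeNode(6)))
Solution().flatten(root)
print(right_chain_values(root))  # [1, 2, 3, 4, 5, 6]
```

Each node now does constant work outside its recursive calls, so time drops to O(n). The recursion stack is still O(h).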
Testing
```python
from typing import Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def flatten(self, root: Optional[TreeNode]) -> None:
        if root is None:
            return

        self.flatten(root.left)
        self.flatten(root.right)

        left_subtree = root.left
        right_subtree = root.right

        root.left = None
        root.right = left_subtree

        current = root
        while current.right is not None:
            current = current.right

        current.right = right_subtree


def flattened_values(root):
    values = []

    while root is not None:
        values.append(root.val)
        assert root.left is None
        root = root.right

    return values


def run_tests():
    s = Solution()

    root1 = TreeNode(
        1,
        TreeNode(2, TreeNode(3), TreeNode(4)),
        TreeNode(5, None, TreeNode(6)),
    )
    s.flatten(root1)
    assert flattened_values(root1) == [1, 2, 3, 4, 5, 6]

    root2 = TreeNode(1)
    s.flatten(root2)
    assert flattened_values(root2) == [1]

    root3 = None
    s.flatten(root3)
    assert root3 is None

    root4 = TreeNode(
        1,
        TreeNode(2, TreeNode(3)),
        None,
    )
    s.flatten(root4)
    assert flattened_values(root4) == [1, 2, 3]

    root5 = TreeNode(
        1,
        None,
        TreeNode(2, None, TreeNode(3)),
    )
    s.flatten(root5)
    assert flattened_values(root5) == [1, 2, 3]

    print("all tests passed")


run_tests()
```

Test meaning:
| Test | Why |
|---|---|
| Standard mixed tree | Confirms preorder flattening |
| Single node | Minimum non-empty tree |
| Empty tree | Confirms base case |
| Left-skewed tree | Confirms left subtree movement |
| Right-skewed tree | Confirms existing linked structure remains valid |