A clear explanation of finding all binary tree nodes at distance k from a target node by treating the tree as an undirected graph.
Problem Restatement
We are given the root of a binary tree, a target node, and an integer k.
We need to return the values of all nodes whose distance from the target node is exactly k.
Distance means the number of edges on the path between two nodes.
The answer can be returned in any order.
Input and Output
| Item | Meaning |
|---|---|
| Input | root, the root of the binary tree |
| Input | target, the target TreeNode |
| Input | k, the required distance |
| Output | A list of node values at distance exactly k from target |
Function shape:

    class Solution:
        def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
            ...

Examples
Example 1:
    root = [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]
    target = 5
    k = 2

The tree is:

            3
          /   \
         5     1
        / \   / \
       6   2 0   8
          / \
         7   4

Nodes at distance 2 from node 5 are:
| Node | Path | Distance |
|---|---|---|
| 7 | 5 -> 2 -> 7 | 2 |
| 4 | 5 -> 2 -> 4 | 2 |
| 1 | 5 -> 3 -> 1 | 2 |
So one valid answer is:
    [7, 4, 1]

The order does not matter.
Example 2:
    root = [1]
    target = 1
    k = 0

The target node itself is at distance 0.
So the answer is:
    [1]

First Thought: Search Downward From Target
If we only needed nodes below the target, we could run DFS from the target and collect nodes at depth k.
But the answer may also include nodes above the target, or nodes in another subtree.
In the first example, node 1 is distance 2 from target 5:
    5 -> 3 -> 1

That path goes upward to the parent before going downward.
A normal binary tree node has links to its children, but not to its parent.
So we need a way to move upward.
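To make the limitation concrete, here is a minimal sketch of the downward-only search on the tree from Example 1. The helper name `nodes_below_at_depth` is made up for this illustration, and the `TreeNode` class is a local stand-in for the one the judge normally provides.

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def nodes_below_at_depth(node, k):
    """Collect values exactly k edges below `node`, following child links only."""
    if node is None:
        return []
    if k == 0:
        return [node.val]
    return (nodes_below_at_depth(node.left, k - 1)
            + nodes_below_at_depth(node.right, k - 1))


# Rebuild the tree from Example 1 by hand.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

found = nodes_below_at_depth(root.left, 2)
print(found)  # [7, 4]: node 1 is missing because reaching it means moving up to 3
```

The search finds 7 and 4 but never reaches node 1, which is exactly the gap the parent links are meant to close.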
Key Insight
Treat the binary tree as an undirected graph.
Each tree edge connects:
    parent <-> child

Then the problem becomes:
Starting from target, find all graph nodes at distance exactly k.
To move upward, first build a map from each node to its parent.
Then run BFS from the target.
From any node, we can move to:
- Its left child
- Its right child
- Its parent
We also need a visited set so we do not move back and forth forever.
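The graph view can also be made literal. The sketch below is an alternative to the parent map, not the approach this article implements: it builds an explicit adjacency list in which every tree edge appears in both directions. The function name `build_adjacency` and the local `TreeNode` class are assumptions for illustration.

```python
from collections import defaultdict


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_adjacency(root):
    """Return a dict mapping each node to the nodes it shares an edge with."""
    graph = defaultdict(list)
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                # Record the edge in both directions: parent <-> child.
                graph[node].append(child)
                graph[child].append(node)
                stack.append(child)
    return graph


# A small fragment of Example 1: 3 has children 5 and 1, and 5 has left child 6.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left = TreeNode(6)

graph = build_adjacency(root)
print(sorted(n.val for n in graph[root.left]))  # [3, 6]
```

Node 5 now lists both its parent 3 and its child 6 as neighbors. The parent map gives the same information with less memory, since the child links already exist on the nodes.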
Algorithm
First, build the parent map.
Use DFS from root.
For every node:
- If it has a left child, record that child’s parent.
- If it has a right child, record that child’s parent.
- Continue DFS.
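The parent-map step can also be written without recursion. This is a sketch under the assumption that the tree may be deep enough to hit Python's recursion limit (for example, a long chain of single children); `build_parent_map` is a name chosen here, and `TreeNode` is a local stand-in.

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_parent_map(root):
    """Map every node to its parent using an explicit stack; the root maps to None."""
    if root is None:
        return {}
    parent = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)
    return parent


# A left-leaning chain of 5000 nodes, far deeper than a balanced tree would be.
chain_root = TreeNode(0)
tail = chain_root
for value in range(1, 5000):
    tail.left = TreeNode(value)
    tail = tail.left

chain_parents = build_parent_map(chain_root)
print(len(chain_parents))        # 5000
print(chain_parents[tail].val)   # 4998
```

For trees of moderate depth the recursive helper is just as good; the explicit stack only matters when the depth can approach the recursion limit.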
Then run BFS from target.
Maintain:
    queue = deque([(target, 0)])
    visited = {target}
    answer = []

For each node and distance removed from the queue:
- If distance equals k, add the node value to answer.
- Otherwise, visit its neighbors:
  - left child
  - right child
  - parent
- Mark each unvisited neighbor as visited and add it to the queue with distance + 1.
Return answer.
Correctness
The parent map gives every node access to its parent. Together with existing left and right child pointers, this represents every tree edge as an undirected connection.
Therefore, BFS from the target explores exactly the nodes reachable by paths in the tree.
BFS processes nodes in increasing distance from the target. When a node is removed from the queue with distance d, that distance is the shortest number of edges from the target to that node.
The algorithm adds a node to the answer exactly when its BFS distance is k.
Since a tree has a unique simple path between any two nodes, and the visited set prevents revisiting nodes through the reverse edge, every node is processed at most once, and the distance recorded for it is the length of that unique path.
Therefore, the returned list contains exactly all nodes at distance k from the target.
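One way to gain extra confidence in this argument is to compare the BFS result against an independent, slower computation. The sketch below is not part of the solution: it measures the distance between two nodes from their root-to-node paths, counting the edges from the target up to the lowest common ancestor plus the edges down to the other node. The names `root_paths` and `brute_force_distance_k` are made up for this illustration.

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def root_paths(root):
    """Map each node to the list of nodes on the path from the root to it."""
    paths = {}
    stack = [(root, [root])] if root is not None else []
    while stack:
        node, path = stack.pop()
        paths[node] = path
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, path + [child]))
    return paths


def brute_force_distance_k(root, target, k):
    """Return values of all nodes exactly k edges from target, in O(n^2) worst case."""
    paths = root_paths(root)
    target_path = paths[target]
    result = []
    for node, path in paths.items():
        # Count the shared prefix: the common ancestors, ending at the LCA.
        common = 0
        for a, b in zip(target_path, path):
            if a is not b:
                break
            common += 1
        # Edges from target up to the LCA, plus edges from the LCA down to node.
        distance = (len(target_path) - common) + (len(path) - common)
        if distance == k:
            result.append(node.val)
    return result


# The tree from Example 1.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

print(sorted(brute_force_distance_k(root, root.left, 2)))  # [1, 4, 7]
```

Running both versions on small random trees and comparing the sorted outputs is a cheap sanity check for the BFS implementation.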
Complexity
Let n be the number of nodes in the tree.
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | We build the parent map once and BFS visits each node at most once |
| Space | O(n) | The parent map, visited set, and queue can store up to n nodes |
Implementation
    from collections import deque
    from typing import Optional

    # Definition for a binary tree node.
    # class TreeNode:
    #     def __init__(self, x):
    #         self.val = x
    #         self.left = None
    #         self.right = None

    class Solution:
        def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
            # Maps each node to its parent; the root maps to None.
            parent = {}

            def build_parent(node: Optional[TreeNode], par: Optional[TreeNode]) -> None:
                if node is None:
                    return
                parent[node] = par
                build_parent(node.left, node)
                build_parent(node.right, node)

            build_parent(root, None)

            # BFS outward from the target over child and parent links.
            queue = deque([(target, 0)])
            visited = {target}
            answer = []

            while queue:
                node, distance = queue.popleft()

                if distance == k:
                    answer.append(node.val)
                    continue  # Nodes farther than k are never needed.

                for neighbor in (node.left, node.right, parent[node]):
                    if neighbor is not None and neighbor not in visited:
                        visited.add(neighbor)
                        queue.append((neighbor, distance + 1))

            return answer

Code Explanation
We first build a dictionary:

    parent = {}

This maps each node to its parent.
The helper function stores the parent for each node:

    def build_parent(node, par):
        if node is None:
            return
        parent[node] = par

Then it continues into both children:

        build_parent(node.left, node)
        build_parent(node.right, node)

After the parent map is ready, we start BFS from the target:

    queue = deque([(target, 0)])
    visited = {target}
    answer = []

Each queue item stores a node and its distance from the target.
When we pop a node:

    node, distance = queue.popleft()

If its distance is exactly k, we record it:

    if distance == k:
        answer.append(node.val)
        continue

The continue is useful because we do not need to explore beyond distance k.
Otherwise, we inspect all possible neighbors:

    for neighbor in (node.left, node.right, parent[node]):

This includes both children and the parent.
If the neighbor exists and was not visited, add it to the BFS queue:

        if neighbor is not None and neighbor not in visited:
            visited.add(neighbor)
            queue.append((neighbor, distance + 1))

Finally, return all collected values:

    return answer

Testing
    # Stand-in for the TreeNode class that LeetCode provides. When running
    # everything as one file, define it before Solution, whose type hints refer to it.
    class TreeNode:
        def __init__(self, x):
            self.val = x
            self.left = None
            self.right = None

    def test_distance_k():
        s = Solution()

        # Standard example: target 5, k = 2.
        root = TreeNode(3)
        root.left = TreeNode(5)
        root.right = TreeNode(1)
        root.left.left = TreeNode(6)
        root.left.right = TreeNode(2)
        root.right.left = TreeNode(0)
        root.right.right = TreeNode(8)
        root.left.right.left = TreeNode(7)
        root.left.right.right = TreeNode(4)
        result = s.distanceK(root, root.left, 2)
        assert sorted(result) == [1, 4, 7]

        # Single node with k = 0.
        single = TreeNode(1)
        assert s.distanceK(single, single, 0) == [1]

        # Target is the root, then target is a leaf.
        root = TreeNode(1)
        root.left = TreeNode(2)
        root.right = TreeNode(3)
        assert sorted(s.distanceK(root, root, 1)) == [2, 3]
        assert s.distanceK(root, root.left, 2) == [3]

        print("all tests passed")

    test_distance_k()

Test meaning:
| Test | Why |
|---|---|
| Standard example | Checks upward and downward traversal |
| Single node with k = 0 | Checks target itself |
| Target is root | Only child paths are needed |
| Target is leaf | Requires going upward then downward |
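The examples write trees as level-order lists such as [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]. A small helper can turn that notation into linked nodes so test trees do not have to be wired by hand. This is a sketch that assumes the usual judge convention: None marks a missing child, and the children of a missing node are omitted from the list. The name `build_tree` is hypothetical.

```python
from collections import deque


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values):
    """Build a tree from a level-order list in which None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


example_root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(example_root.left.val)              # 5
print(example_root.left.right.left.val)   # 7
print(example_root.left.right.right.val)  # 4
```

The target still has to be passed as a node rather than a value, so a test would pick it out of the built tree, for example `example_root.left` for target 5.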