
LeetCode 863: All Nodes Distance K in Binary Tree

A clear explanation of finding all binary tree nodes at distance k from a target node by treating the tree as an undirected graph.

Problem Restatement

We are given the root of a binary tree, a target node, and an integer k.

We need to return the values of all nodes whose distance from the target node is exactly k.

Distance means the number of edges on the path between two nodes.

The answer can be returned in any order.

Input and Output

| Item | Meaning |
|---|---|
| Input | root, the root of the binary tree |
| Input | target, the target TreeNode |
| Input | k, the required distance |
| Output | A list of node values at distance exactly k from target |

Function shape:

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
        ...

Examples

Example 1:

root = [3, 5, 1, 6, 2, 0, 8, None, None, 7, 4]
target = 5
k = 2

The tree is:

        3
       / \
      5   1
     / \ / \
    6  2 0  8
      / \
     7   4

Nodes at distance 2 from node 5 are:

| Node | Path | Distance |
|---|---|---|
| 7 | 5 -> 2 -> 7 | 2 |
| 4 | 5 -> 2 -> 4 | 2 |
| 1 | 5 -> 3 -> 1 | 2 |

So one valid answer is:

[7, 4, 1]

The order does not matter.
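The list notation above is LeetCode's level-order serialization, where None marks a missing child. To try the examples locally, a minimal decoder might look like the sketch below. build_tree and find are hypothetical helpers, not part of the problem's API, and find assumes node values are unique.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Decode a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two list entries are this node's left and right children.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def find(root: Optional[TreeNode], value: int) -> Optional[TreeNode]:
    """Return the node holding value, assuming values are unique."""
    if root is None or root.val == value:
        return root
    return find(root.left, value) or find(root.right, value)


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
target = find(root, 5)
print(target.left.val, target.right.val)              # 6 2
print(target.right.left.val, target.right.right.val)  # 7 4
```

The decoder consumes two list entries per dequeued node, which is why the two None values after 8 make node 6 a leaf and 7 and 4 end up as children of node 2.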

Example 2:

root = [1]
target = 1
k = 0

The target node itself is at distance 0.

So the answer is:

[1]

First Thought: Search Downward From Target

If we only needed nodes below the target, we could run DFS from the target and collect nodes at depth k.
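A minimal sketch of that downward-only search, using a hypothetical nodes_below helper and the Example 1 tree built by hand. It finds 7 and 4, but it has no way to reach node 1, which is only reachable through the target's parent.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def nodes_below(node: Optional[TreeNode], k: int) -> List[int]:
    """Collect values exactly k edges below node, following child links only."""
    if node is None:
        return []
    if k == 0:
        return [node.val]
    return nodes_below(node.left, k - 1) + nodes_below(node.right, k - 1)


# Example 1 tree, built by hand.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

print(nodes_below(root.left, 2))  # [7, 4]  (node 1 is missing)
```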

But the answer may also include nodes above the target, or nodes in another subtree.

In the first example, node 1 is at distance 2 from target 5:

5 -> 3 -> 1

That path goes upward to the parent before going downward.

A normal binary tree node has links to its children, but not to its parent.

So we need a way to move upward.

Key Insight

Treat the binary tree as an undirected graph.

Each tree edge connects:

parent <-> child

Then the problem becomes:

Starting from target, find all graph nodes at distance exactly k.
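To make the graph view concrete, the sketch below turns the tree into an explicit undirected adjacency list. to_undirected_graph is a hypothetical helper. It keys the graph by node value, which assumes values are unique; the parent-map approach used in this article keys by node object and needs no such assumption.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def to_undirected_graph(root: Optional[TreeNode]) -> Dict[int, List[int]]:
    """Map each node value to the values of its neighbors (children and parent)."""
    graph: Dict[int, List[int]] = {}
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        graph.setdefault(node.val, [])
        for child in (node.left, node.right):
            if child is not None:
                # Each tree edge becomes two entries: parent -> child and child -> parent.
                graph[node.val].append(child.val)
                graph.setdefault(child.val, []).append(node.val)
                stack.append(child)
    return graph


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

graph = to_undirected_graph(root)
print(sorted(graph[5]))  # [2, 3, 6]  (node 5 can now move up to 3)
print(sorted(graph[2]))  # [4, 5, 7]
```

Once the edges go both ways, "distance k from target" is an ordinary shortest-path question on an unweighted graph.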

To move upward, first build a map from each node to its parent.

Then run BFS from the target.

From any node, we can move to:

  1. Its left child
  2. Its right child
  3. Its parent

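We also need a visited set. Without it, BFS would step from a child straight back to its parent (or from a parent back to a child), revisit nodes, and report some of them, including the target itself, at the wrong distance.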
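To see why the visited set matters, the sketch below runs the BFS described in this article without it on Example 1 (target 5, k = 2). bfs_without_visited is a hypothetical, deliberately broken variant. It still terminates because nodes at distance k are not expanded, but it reports the target itself three times, once through each of its neighbors.

```python
from collections import deque
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def bfs_without_visited(root: TreeNode, target: TreeNode, k: int) -> List[int]:
    """Deliberately broken: BFS over children and parent with no visited set."""
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)

    queue = deque([(target, 0)])
    answer = []
    while queue:
        node, distance = queue.popleft()
        if distance == k:
            answer.append(node.val)
            continue
        for neighbor in (node.left, node.right, parent[node]):
            if neighbor is not None:  # no visited check, so BFS can step straight back
                queue.append((neighbor, distance + 1))
    return answer


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

print(bfs_without_visited(root, root.left, 2))  # [5, 7, 4, 5, 5, 1]
```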

Algorithm

First, build the parent map.

Use DFS from root.

For every node:

  1. If it has a left child, record that child’s parent.
  2. If it has a right child, record that child’s parent.
  3. Continue DFS.
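The same parent map can also be built without recursion. This sketch, with the hypothetical name build_parent_iterative, uses an explicit stack, which sidesteps Python's recursion limit if the tree is very deep (for example, a long chain of left children). For ordinary trees the recursive DFS is perfectly fine.

```python
from typing import Dict, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_parent_iterative(root: Optional[TreeNode]) -> Dict[TreeNode, Optional[TreeNode]]:
    """Map every node to its parent (the root maps to None) without recursion."""
    parent: Dict[TreeNode, Optional[TreeNode]] = {}
    if root is None:
        return parent
    parent[root] = None
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node  # record the upward edge child -> node
                stack.append(child)
    return parent


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

parent = build_parent_iterative(root)
print(parent[root])                      # None
print(parent[root.left.right.left].val)  # 2  (node 7's parent)

# A 5000-node chain of left children: deeper than CPython's default
# recursion limit of 1000, but fine for the explicit stack.
chain = TreeNode(0)
tail = chain
for i in range(1, 5000):
    tail.left = TreeNode(i)
    tail = tail.left
print(len(build_parent_iterative(chain)))  # 5000
```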

Then run BFS from target.

Maintain:

queue = deque([(target, 0)])
visited = {target}
answer = []

For each node and distance:

  1. If distance equals k, add the node value to answer.
  2. Otherwise, visit its neighbors:
    1. left child
    2. right child
    3. parent
  3. Add each unvisited neighbor to the queue with distance + 1, marking it visited as it is enqueued.

Return answer.
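Tracing these steps on Example 1 (target 5, k = 2) makes the BFS order visible. The sketch below is a self-contained copy of the algorithm that also records each (value, distance) pair as it is popped. distance_k_with_trace is a hypothetical name, and the parent map here is built with an explicit stack rather than recursion.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def distance_k_with_trace(
    root: TreeNode, target: TreeNode, k: int
) -> Tuple[List[int], List[Tuple[int, int]]]:
    # Step 1: parent map.
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)

    # Step 2: BFS from target, recording the pop order.
    queue = deque([(target, 0)])
    visited = {target}
    answer: List[int] = []
    trace: List[Tuple[int, int]] = []
    while queue:
        node, distance = queue.popleft()
        trace.append((node.val, distance))
        if distance == k:
            answer.append(node.val)
            continue
        for neighbor in (node.left, node.right, parent[node]):
            if neighbor is not None and neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, distance + 1))
    return answer, trace


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

answer, trace = distance_k_with_trace(root, root.left, 2)
print(trace)   # [(5, 0), (6, 1), (2, 1), (3, 1), (7, 2), (4, 2), (1, 2)]
print(answer)  # [7, 4, 1]
```

The distances in the trace never decrease, and node 1 is reached through 3, the target's parent.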

Correctness

The parent map gives every node access to its parent. Together with existing left and right child pointers, this represents every tree edge as an undirected connection.

Therefore, BFS from the target moves only along real tree edges, and because a tree is connected, it can reach every node.

BFS processes nodes in increasing distance from the target. When a node is removed from the queue with distance d, that distance is the shortest number of edges from the target to that node.

The algorithm adds a node to the answer exactly when its BFS distance is k.

Because a tree has exactly one simple path between any two nodes, this shortest distance equals the length of the tree path from the target. The visited set ensures each node is enqueued at most once, so each node is reported at most once.

Therefore, the returned list contains exactly all nodes at distance k from the target.
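An independent check of this argument is a brute force that never runs BFS. For every node, it computes the distance to the target as (edges from the target up to their lowest common ancestor) + (edges from that ancestor down to the node), using root-to-node paths. root_paths and brute_force_distance_k are hypothetical helpers; the approach costs O(n * h) for a tree of height h, so it is meant only for testing.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def root_paths(root: Optional[TreeNode]) -> Dict[TreeNode, List[TreeNode]]:
    """Map each node to the list of nodes from the root down to it (inclusive)."""
    paths: Dict[TreeNode, List[TreeNode]] = {}
    stack = [(root, [root])] if root is not None else []
    while stack:
        node, path = stack.pop()
        paths[node] = path
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, path + [child]))
    return paths


def brute_force_distance_k(root: TreeNode, target: TreeNode, k: int) -> List[int]:
    paths = root_paths(root)
    target_path = paths[target]
    answer = []
    for node, path in paths.items():
        # Length of the shared prefix = depth of the lowest common ancestor + 1.
        common = 0
        while (common < min(len(path), len(target_path))
               and path[common] is target_path[common]):
            common += 1
        # Edges from target up to the LCA plus edges from the LCA down to node.
        distance = (len(target_path) - common) + (len(path) - common)
        if distance == k:
            answer.append(node.val)
    return answer


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

print(sorted(brute_force_distance_k(root, root.left, 2)))  # [1, 4, 7]
```

Comparing this against the BFS answer on many small trees is a cheap way to gain confidence in the correctness argument.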

Complexity

Let n be the number of nodes in the tree.

| Metric | Value | Why |
|---|---|---|
| Time | O(n) | We build the parent map once, and BFS visits each node at most once |
| Space | O(n) | The parent map, visited set, and queue can each store up to n nodes |

Implementation

from __future__ import annotations  # defer TreeNode annotations, so TreeNode may be defined later

from collections import deque
from typing import Optional

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None

class Solution:
    def distanceK(self, root: TreeNode, target: TreeNode, k: int) -> list[int]:
        parent = {}  # maps each node to its parent; the root maps to None

        def build_parent(node: Optional[TreeNode], par: Optional[TreeNode]) -> None:
            if node is None:
                return

            parent[node] = par

            build_parent(node.left, node)
            build_parent(node.right, node)

        build_parent(root, None)

        queue = deque([(target, 0)])  # each entry is (node, distance from target)
        visited = {target}
        answer = []

        while queue:
            node, distance = queue.popleft()

            if distance == k:
                answer.append(node.val)
                continue

            # Neighbors in the undirected view: both children and the parent (None for the root).
            for neighbor in (node.left, node.right, parent[node]):
                if neighbor is not None and neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, distance + 1))

        return answer

Code Explanation

We first build a dictionary:

parent = {}

This maps each node to its parent.

The helper function stores the parent for each node:

def build_parent(node, par):
    if node is None:
        return

    parent[node] = par

Then it continues into both children:

build_parent(node.left, node)
build_parent(node.right, node)

After the parent map is ready, we start BFS from the target:

queue = deque([(target, 0)])
visited = {target}
answer = []

Each queue item stores a node and its distance from the target.

When we pop a node:

node, distance = queue.popleft()

If its distance is exactly k, we record it:

if distance == k:
    answer.append(node.val)
    continue

The continue matters because a node at distance k never needs to be expanded: its unvisited neighbors would be at distance k + 1, which can never be part of the answer.

Otherwise, we inspect all possible neighbors:

for neighbor in (node.left, node.right, parent[node]):

This includes both children and the parent.

If the neighbor exists and was not visited, add it to the BFS queue:

if neighbor is not None and neighbor not in visited:
    visited.add(neighbor)
    queue.append((neighbor, distance + 1))

Finally, return all collected values:

return answer
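The continue also implies a shortcut. Once the first node at distance k is popped, every node still in the queue is at distance k too, because nodes at distance k never enqueue anything. A level-by-level variant makes this explicit by expanding the frontier exactly k times and returning whatever is left. distance_k_by_levels is a hypothetical name for this sketch; it returns the same set of values as the queue-based version.

```python
from typing import Dict, List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def distance_k_by_levels(root: TreeNode, target: TreeNode, k: int) -> List[int]:
    # Parent map, built iteratively (same information as build_parent).
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)

    frontier = [target]  # all nodes at the current distance from target
    visited = {target}
    for _ in range(k):
        next_frontier = []
        for node in frontier:
            for neighbor in (node.left, node.right, parent[node]):
                if neighbor is not None and neighbor not in visited:
                    visited.add(neighbor)
                    next_frontier.append(neighbor)
        if not next_frontier:
            return []  # the tree ran out of nodes before distance k
        frontier = next_frontier
    return [node.val for node in frontier]


# Example 1 tree.
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

single = TreeNode(1)

print(distance_k_by_levels(root, root.left, 2))  # [7, 4, 1]
print(distance_k_by_levels(single, single, 0))   # [1]
print(distance_k_by_levels(root, root.left, 5))  # []
```

The trade-off is mostly stylistic: the frontier version avoids storing a distance with every queue entry, while the queue version handles all distances with a single loop.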

Testing

class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

def test_distance_k():
    s = Solution()

    root = TreeNode(3)
    root.left = TreeNode(5)
    root.right = TreeNode(1)
    root.left.left = TreeNode(6)
    root.left.right = TreeNode(2)
    root.right.left = TreeNode(0)
    root.right.right = TreeNode(8)
    root.left.right.left = TreeNode(7)
    root.left.right.right = TreeNode(4)

    result = s.distanceK(root, root.left, 2)
    assert sorted(result) == [1, 4, 7]

    single = TreeNode(1)
    assert s.distanceK(single, single, 0) == [1]

    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    assert sorted(s.distanceK(root, root, 1)) == [2, 3]

    assert s.distanceK(root, root.left, 2) == [3]

    print("all tests passed")

test_distance_k()

Test meaning:

| Test | Why |
|---|---|
| Standard example | Checks upward and downward traversal |
| Single node with k = 0 | Checks the target itself |
| Target is root | Only downward (child) paths are needed |
| Target is leaf | Requires going upward, then downward |