A clear explanation of the Sum of Distances in Tree problem using tree DP, subtree sizes, and rerooting.
Problem Restatement
We are given an undirected connected tree with n nodes labeled from 0 to n - 1.
The array edges has n - 1 edges, where each edge connects two nodes.
We need to return an array answer of length n, where:
answer[i] is the sum of distances from node i to every other node in the tree.
Input and Output
| Item | Meaning |
|---|---|
| Input | n and edges |
| Output | Array answer of length n |
| Distance | Number of edges on the path between two nodes |
| Graph type | Undirected connected tree |
| Nodes | Labeled from 0 to n - 1 |
Function shape:

    class Solution:
        def sumOfDistancesInTree(self, n: int, edges: list[list[int]]) -> list[int]:
            ...

Examples
Example 1:

    n = 6
    edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]

The tree looks like this:

      0
     / \
    1   2
       /|\
      3 4 5

For node 0, the distances are:
| To node | Distance |
|---|---|
| 1 | 1 |
| 2 | 1 |
| 3 | 2 |
| 4 | 2 |
| 5 | 2 |
So:

    answer[0] = 1 + 1 + 2 + 2 + 2 = 8

The full result is:

    [8, 12, 6, 10, 10, 10]

Example 2:

    n = 1
    edges = []

There is only one node, so there are no other nodes to reach.
The answer is:

    [0]

Example 3:

    n = 2
    edges = [[1, 0]]

Each node is distance 1 from the other.
The answer is:

    [1, 1]

First Thought: BFS or DFS From Every Node
A direct method is to run BFS or DFS from every node.
For each starting node:
- Traverse the tree.
- Compute distance to every other node.
- Store the sum.

    from collections import deque

    class Solution:
        def sumOfDistancesInTree(self, n: int, edges: list[list[int]]) -> list[int]:
            # Adjacency list; each undirected edge is stored in both directions.
            graph = [[] for _ in range(n)]
            for a, b in edges:
                graph[a].append(b)
                graph[b].append(a)

            answer = []
            for start in range(n):
                # One BFS per start node, summing the distance to every node.
                seen = [False] * n
                queue = deque([(start, 0)])
                seen[start] = True
                total = 0
                while queue:
                    node, dist = queue.popleft()
                    total += dist
                    for nei in graph[node]:
                        if not seen[nei]:
                            seen[nei] = True
                            queue.append((nei, dist + 1))
                answer.append(total)
            return answer

This is correct, but too slow.
Problem With Brute Force
A tree has n nodes and n - 1 edges.
One BFS or DFS costs:

    O(n)

Doing that from every node costs:

    O(n^2)

For large n, this is too slow.
We need to reuse work between neighboring roots.
Key Insight
Root the tree at node 0.
During the first DFS, compute two things:
| Value | Meaning |
|---|---|
| count[node] | Number of nodes in node’s subtree, including node itself |
| answer[0] | Sum of distances from root 0 to all nodes |
Then use a second DFS to reroot the answer.
Suppose we know answer[parent], and we want answer[child].
When moving the root from parent to child:
| Group | Distance change |
|---|---|
| Nodes inside child’s subtree | Become 1 closer |
| Nodes outside child’s subtree | Become 1 farther |
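
The split in the table can be checked numerically. The following is a minimal sketch, not part of the solution: it uses the Example 1 tree, moves the root across the edge from parent 0 to child 2, and relies on a helper named `distances_from` that is introduced here only for illustration.

```python
from collections import deque

# Example 1 tree; the helper name distances_from is illustrative only.
n = 6
edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]

graph = [[] for _ in range(n)]
for a, b in edges:
    graph[a].append(b)
    graph[b].append(a)


def distances_from(start: int) -> list[int]:
    """BFS distances (in edges) from start to every node."""
    dist = [-1] * n
    dist[start] = 0
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nei in graph[node]:
            if dist[nei] == -1:
                dist[nei] = dist[node] + 1
                queue.append(nei)
    return dist


parent, child = 0, 2
from_parent = distances_from(parent)
from_child = distances_from(child)

# Nodes whose distance shrinks by 1 when the root moves to child.
closer = [v for v in range(n) if from_child[v] == from_parent[v] - 1]
# Nodes whose distance grows by 1.
farther = [v for v in range(n) if from_child[v] == from_parent[v] + 1]

print(closer)   # the nodes in the subtree of 2
print(farther)  # the nodes outside it
print(sum(from_child) - sum(from_parent))  # equals len(farther) - len(closer)
```

Every node lands in exactly one of the two lists, which is why the change in the total depends only on the size of the child's subtree.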
If count[child] nodes become closer, the total decreases by:

    count[child]

If n - count[child] nodes become farther, the total increases by:

    n - count[child]

So:

    answer[child] = answer[parent] - count[child] + (n - count[child])

which is:

    answer[child] = answer[parent] + n - 2 * count[child]

Algorithm
Build an adjacency list.
First DFS from node 0:
- Track depth from root.
- Add each depth to answer[0].
- Compute subtree sizes in count.
Second DFS from node 0:
- For each child of the current node, compute answer[child] = answer[node] + n - 2 * count[child].
- Recurse into the child.
Return answer.
Walkthrough
Use:

    n = 6
    edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]

Root the tree at 0:

      0
     / \
    1   2
       /|\
      3 4 5

After the first DFS:
| Node | Subtree size |
|---|---|
| 0 | 6 |
| 1 | 1 |
| 2 | 4 |
| 3 | 1 |
| 4 | 1 |
| 5 | 1 |
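
The sizes in the table can be reproduced with a short sketch of the size computation alone. It is illustrative only: the helper name `fill_sizes` is an assumption made here, and it leaves out the depth accumulation that the first DFS also performs.

```python
# Example tree from the walkthrough, rooted at node 0.
n = 6
edges = [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]

graph = [[] for _ in range(n)]
for a, b in edges:
    graph[a].append(b)
    graph[b].append(a)

count = [1] * n  # every node counts itself


def fill_sizes(node: int, parent: int) -> None:
    """Post-order pass: a node's size is 1 plus the sizes of its children."""
    for nei in graph[node]:
        if nei != parent:
            fill_sizes(nei, node)
            count[node] += count[nei]


fill_sizes(0, -1)
for node, size in enumerate(count):
    print(node, size)
```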
Also:

    answer[0] = 8

Now reroot.
From 0 to 1:

    answer[1] = answer[0] + n - 2 * count[1]
    answer[1] = 8 + 6 - 2 * 1
    answer[1] = 12

From 0 to 2:

    answer[2] = answer[0] + n - 2 * count[2]
    answer[2] = 8 + 6 - 2 * 4
    answer[2] = 6

From 2 to 3:

    answer[3] = answer[2] + n - 2 * count[3]
    answer[3] = 6 + 6 - 2 * 1
    answer[3] = 10

Nodes 4 and 5 get the same value:

    10

Final answer:

    [8, 12, 6, 10, 10, 10]

Correctness
The first DFS computes correct subtree sizes because each node starts with size 1, then adds the sizes of all child subtrees. Since the graph is a tree, each child subtree is disjoint, so this sum gives exactly the number of nodes in the subtree.
The first DFS also computes answer[0] correctly because it visits every node once and adds its depth from root 0. In a rooted tree, the depth of a node is exactly its distance from root 0.
Now consider an edge from parent to child, where the tree is rooted at 0.
When changing the root from parent to child, every node in child’s subtree becomes one edge closer to the root. There are count[child] such nodes.
Every other node becomes one edge farther from the root. There are n - count[child] such nodes.
Therefore:

    answer[child] = answer[parent] - count[child] + (n - count[child])

The second DFS applies this formula across every tree edge from parent to child. Since every node except 0 has exactly one parent in the rooted tree, every answer[node] is computed exactly once from a correct parent value.
Therefore, all returned distance sums are correct.
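
The argument above can also be spot-checked empirically. The sketch below compares the two-pass method against a per-node BFS on small random trees. The function names `brute_force_sums`, `reroot_sums`, `build_graph`, and `random_tree` are hypothetical helpers written for this check, and the seed and size limits are arbitrary choices. Agreement on random inputs supports the proof but does not replace it.

```python
import random
from collections import deque


def build_graph(n, edges):
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)
    return graph


def brute_force_sums(n, edges):
    """Reference answer: one BFS per node, O(n^2) overall."""
    graph = build_graph(n, edges)
    result = []
    for start in range(n):
        dist = [-1] * n
        dist[start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nei in graph[node]:
                if dist[nei] == -1:
                    dist[nei] = dist[node] + 1
                    queue.append(nei)
        result.append(sum(dist))
    return result


def reroot_sums(n, edges):
    """Two-pass rerooting: subtree sizes first, then the rerooting formula."""
    graph = build_graph(n, edges)
    answer = [0] * n
    count = [1] * n

    def postorder(node, parent, depth):
        answer[0] += depth
        for nei in graph[node]:
            if nei != parent:
                postorder(nei, node, depth + 1)
                count[node] += count[nei]

    def preorder(node, parent):
        for nei in graph[node]:
            if nei != parent:
                answer[nei] = answer[node] + n - 2 * count[nei]
                preorder(nei, node)

    postorder(0, -1, 0)
    preorder(0, -1)
    return answer


def random_tree(n, rng):
    """Attach each node i >= 1 to a random earlier node; the result is a tree."""
    return [[rng.randrange(i), i] for i in range(1, n)]


rng = random.Random(0)  # fixed seed so any failure is reproducible
for _ in range(200):
    n = rng.randint(1, 30)
    edges = random_tree(n, rng)
    assert reroot_sums(n, edges) == brute_force_sums(n, edges), (n, edges)
print("rerooting matches brute force on 200 random trees")
```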
Complexity
| Metric | Value | Why |
|---|---|---|
| Time | O(n) | Each edge is processed a constant number of times |
| Space | O(n) | Adjacency list, answer array, count array, and recursion stack |
Implementation

    class Solution:
        def sumOfDistancesInTree(self, n: int, edges: list[list[int]]) -> list[int]:
            # Adjacency list; each undirected edge is stored in both directions.
            graph = [[] for _ in range(n)]
            for a, b in edges:
                graph[a].append(b)
                graph[b].append(a)

            answer = [0] * n
            count = [1] * n  # every node counts itself

            # First DFS: accumulate answer[0] and subtree sizes.
            # Recursion depth can reach n on a path-shaped tree.
            def postorder(node: int, parent: int, depth: int) -> None:
                answer[0] += depth
                for nei in graph[node]:
                    if nei == parent:
                        continue
                    postorder(nei, node, depth + 1)
                    count[node] += count[nei]

            # Second DFS: derive each child's answer from its parent's.
            def preorder(node: int, parent: int) -> None:
                for nei in graph[node]:
                    if nei == parent:
                        continue
                    answer[nei] = answer[node] + n - 2 * count[nei]
                    preorder(nei, node)

            postorder(0, -1, 0)
            preorder(0, -1)
            return answer

Code Explanation
First, build the graph:

    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

The tree is undirected, so each edge is stored both ways.
Then initialize:

    answer = [0] * n
    count = [1] * n

Each node counts itself, so every subtree size starts at 1.
The first DFS computes subtree sizes and answer[0]:

    def postorder(node: int, parent: int, depth: int) -> None:
        answer[0] += depth

The depth is the distance from node 0.
After visiting a child, we add the child’s subtree size:

    count[node] += count[nei]

The second DFS computes the rest of the answers:

    answer[nei] = answer[node] + n - 2 * count[nei]

This is the rerooting formula.
Finally:

    return answer

returns the sum of distances for every node.
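
One practical caveat with recursive DFS in Python: on a long path-shaped tree the recursion depth grows with n and can exceed CPython's default recursion limit (1000 in standard builds). The following is a sketch of an iterative variant of the same two passes, not the article's implementation. The function name `sum_of_distances_iterative` and the `order`, `parent`, and `depth` arrays are assumptions introduced for this example.

```python
def sum_of_distances_iterative(n: int, edges: list[list[int]]) -> list[int]:
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # Stack-based traversal from root 0. In `order`, every parent
    # appears before all of its children.
    parent = [-1] * n
    depth = [0] * n
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nei in graph[node]:
            if nei != parent[node]:
                parent[nei] = node
                depth[nei] = depth[node] + 1
                stack.append(nei)

    answer = [0] * n
    count = [1] * n
    answer[0] = sum(depth)  # sum of distances from root 0

    # Children before parents: accumulate subtree sizes.
    for node in reversed(order):
        if parent[node] != -1:
            count[parent[node]] += count[node]

    # Parents before children: apply the rerooting formula.
    for node in order:
        if parent[node] != -1:
            answer[node] = answer[parent[node]] + n - 2 * count[node]

    return answer


print(sum_of_distances_iterative(6, [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]]))
```

Reversing the visit order replaces the post-order recursion, and walking it forward replaces the pre-order recursion, so time and space stay at O(n).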
Testing

    def run_tests():
        s = Solution()
        assert s.sumOfDistancesInTree(
            6,
            [[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]],
        ) == [8, 12, 6, 10, 10, 10]
        assert s.sumOfDistancesInTree(1, []) == [0]
        assert s.sumOfDistancesInTree(
            2,
            [[1, 0]],
        ) == [1, 1]
        assert s.sumOfDistancesInTree(
            4,
            [[0, 1], [1, 2], [2, 3]],
        ) == [6, 4, 4, 6]
        assert s.sumOfDistancesInTree(
            5,
            [[0, 1], [0, 2], [0, 3], [0, 4]],
        ) == [4, 7, 7, 7, 7]
        print("all tests passed")

    run_tests()

Test meaning:
| Test | Why |
|---|---|
| Standard example | Confirms rerooting on branching tree |
| One node | Confirms empty edge list |
| Two nodes | Confirms smallest nontrivial tree |
| Chain tree | Confirms distance sums along a path |
| Star tree | Confirms many leaves around one center |