
LeetCode 341: Flatten Nested List Iterator

A clear explanation of Flatten Nested List Iterator using lazy stack-based flattening.

Problem Restatement

We are given a nested list of integers.

Each element is either:

| Type | Meaning |
| --- | --- |
| Integer | A single integer |
| List | A nested list whose elements may also be integers or lists |

We need to implement an iterator that returns the integers in flattened left-to-right order.

The class must support:

| Method | Meaning |
| --- | --- |
| NestedIterator(nestedList) | Initializes the iterator |
| next() | Returns the next integer |
| hasNext() | Returns true if there are more integers |

LeetCode tests the iterator by repeatedly calling hasNext() and next() until no values remain. The NestedInteger interface provides isInteger(), getInteger(), and getList().

Input and Output

| Item | Meaning |
| --- | --- |
| Input | nestedList, a list of NestedInteger objects |
| Output | Iterator behavior, not a direct returned list |
| next() | Returns the next flattened integer |
| hasNext() | Returns whether another integer exists |
| Traversal order | Left to right, depth first |

Class shape:

class NestedIterator:
    def __init__(self, nestedList: list[NestedInteger]):
        ...

    def next(self) -> int:
        ...

    def hasNext(self) -> bool:
        ...

Examples

Example 1:

Input: nestedList = [[1,1],2,[1,1]]
Flattened order: [1,1,2,1,1]

Calls return:

next() -> 1
next() -> 1
next() -> 2
next() -> 1
next() -> 1

Example 2:

Input: nestedList = [1,[4,[6]]]
Flattened order: [1,4,6]

Calls return:

next() -> 1
next() -> 4
next() -> 6

First Thought: Flatten Everything First

A simple solution is to run DFS in the constructor.

When we see an integer, append it to an array.

When we see a list, recursively flatten that list.

Then next() and hasNext() are simple array operations.

class NestedIterator:
    def __init__(self, nestedList: list[NestedInteger]):
        self.values = []
        self.index = 0

        def dfs(items):
            for item in items:
                if item.isInteger():
                    self.values.append(item.getInteger())
                else:
                    dfs(item.getList())

        dfs(nestedList)

    def next(self) -> int:
        value = self.values[self.index]
        self.index += 1
        return value

    def hasNext(self) -> bool:
        return self.index < len(self.values)

This works and is easy to understand.

But it flattens the whole structure immediately. If the nested list is large and the caller only asks for the first few values, this does unnecessary work.

A lazy iterator avoids that cost by expanding only as much of the structure as it needs to produce the next value.
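To make the difference concrete, here is a small illustrative sketch. It uses plain Python lists (ints and lists) as a stand-in for NestedInteger, and the "touch" counting scheme and function names are hypothetical, chosen only for this comparison. It counts how many items each strategy examines before the first value is available.

```python
# Illustrative sketch on plain Python lists: ints and lists stand in for NestedInteger.
# "Touches" is a hypothetical metric: how many items each strategy examines.

def eager_touches(nested):
    """Count items an up-front DFS examines: every item at every level."""
    count = 0

    def dfs(items):
        nonlocal count
        for item in items:
            count += 1
            if isinstance(item, list):
                dfs(item)

    dfs(nested)
    return count


def lazy_touches_until_first(nested):
    """Count stack tops examined by a reverse-push stack until an integer surfaces."""
    stack = nested[::-1]
    count = 0
    while stack:
        top = stack[-1]
        count += 1
        if isinstance(top, int):
            return count
        stack.pop()
        stack.extend(reversed(top))  # expand the list, leftmost child on top
    return count


nested = [1, list(range(1000))]
print(eager_touches(nested))             # 1002
print(lazy_touches_until_first(nested))  # 1
```

The eager DFS must walk all 1002 items before next() can return anything, while the stack only looks at the leftmost item.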

Key Insight

Use a stack.

The stack stores the remaining NestedInteger objects we still need to process.

To preserve left-to-right order, push the initial list in reverse order.

For example:

nestedList = [[1,1], 2, [1,1]]

Push in reverse:

top -> [1,1], 2, [1,1]

Now the leftmost item is on top.
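A Python list used as a stack pops from the end, so the reverse push is what makes the leftmost item come out first. A minimal sketch (the helper name is hypothetical):

```python
def pop_order(items, reverse_first):
    """Load items onto a list-as-stack, then pop until empty."""
    stack = items[::-1] if reverse_first else list(items)
    order = []
    while stack:
        order.append(stack.pop())  # pop() removes from the end, i.e. the top
    return order


print(pop_order(["a", "b", "c"], reverse_first=True))   # ['a', 'b', 'c']
print(pop_order(["a", "b", "c"], reverse_first=False))  # ['c', 'b', 'a']
```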

The job of hasNext() is to make sure the stack top is an integer.

If the top is a list, pop it and push its contents in reverse order.

Repeat until either:

  1. The stack is empty.
  2. The top is an integer.

Then:

| Method | Behavior |
| --- | --- |
| hasNext() | Exposes the next integer on top of the stack and returns True, or returns False if none remain |
| next() | Pops and returns that integer |

Algorithm

Constructor:

  1. Store nestedList on a stack in reverse order.

hasNext():

  1. While the stack is not empty:
    1. Look at the top item.
    2. If it is an integer, return True.
    3. Otherwise, pop the list.
    4. Push its elements in reverse order.
  2. Return False.

next():

  1. Call hasNext() to ensure the top is an integer.
  2. Pop the top item.
  3. Return its integer value.
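The steps above can be sketched directly on plain Python lists, which is handy for experimenting without the NestedInteger interface. This is a hypothetical variant, not the LeetCode class: `PlainNestedIterator` is an illustrative name, and `isinstance(x, int)` plays the role of `isInteger()`.

```python
class PlainNestedIterator:
    """Hypothetical plain-list version of the algorithm: ints and lists
    stand in for NestedInteger objects."""

    def __init__(self, nested_list):
        # Constructor step 1: store the top-level list in reverse order.
        self.stack = nested_list[::-1]

    def hasNext(self):
        # hasNext() step 1: loop while the stack is not empty.
        while self.stack:
            top = self.stack[-1]              # step 1.1: look at the top
            if isinstance(top, int):
                return True                   # step 1.2: integer on top
            self.stack.pop()                  # step 1.3: pop the list
            self.stack.extend(reversed(top))  # step 1.4: push children in reverse
        return False                          # step 2: nothing left

    def next(self):
        self.hasNext()           # next() step 1: make sure an integer is on top
        return self.stack.pop()  # steps 2 and 3: pop it and return it


it = PlainNestedIterator([[1, 1], 2, [1, 1]])
values = []
while it.hasNext():
    values.append(it.next())
print(values)  # [1, 1, 2, 1, 1]
```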

Walkthrough

Use:

nestedList = [1, [4, [6]]]

Initial stack:

top -> 1, [4,[6]]

Call hasNext().

Top is 1, an integer.

Return True.

Call next().

Pop 1.

Now stack:

top -> [4,[6]]

Call hasNext().

Top is a list, so pop it and push its contents in reverse order:

top -> 4, [6]

Top is now 4, an integer.

Return True.

Call next().

Pop 4.

Now stack:

top -> [6]

Call hasNext().

Top is a list, so pop it and push its contents:

top -> 6

Top is 6, an integer.

Return True.

Call next().

Pop 6.

Stack becomes empty.

Now hasNext() returns False.

The output is:

[1, 4, 6]
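The walkthrough can be replayed mechanically. This sketch again assumes plain Python lists in place of NestedInteger, and `trace` is a hypothetical helper: after each hasNext() step it records the stack in the same top-first "top -> ..." view used above.

```python
def trace(nested):
    """Replay hasNext()/next() on plain lists, recording the stack
    top-first after each hasNext() step."""
    stack = nested[::-1]
    snapshots, output = [], []
    while True:
        # hasNext(): expand lists until an integer is on top or the stack is empty
        while stack and not isinstance(stack[-1], int):
            stack.extend(reversed(stack.pop()))
        snapshots.append(stack[::-1])  # reversed copy: top of stack first
        if not stack:
            return snapshots, output
        output.append(stack.pop())  # next()


snapshots, output = trace([1, [4, [6]]])
for snap in snapshots:
    print("top ->", snap)
print(output)
# top -> [1, [4, [6]]]
# top -> [4, [6]]
# top -> [6]
# top -> []
# [1, 4, 6]
```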

Correctness

The stack always represents the remaining unvisited part of the nested list in left-to-right order, with the next item at the top.

At initialization, we push the outer list in reverse order. Therefore, the first original element becomes the top of the stack.

When hasNext() sees an integer at the top, that integer is exactly the next flattened value. Returning True is correct.

When hasNext() sees a list at the top, that list must be expanded before any later item can be visited. The algorithm pops the list and pushes its contents in reverse order. This makes the first element of that list become the new top, preserving left-to-right traversal order.

This expansion repeats until the top is an integer or the stack is empty.

Therefore, every call to next() returns the next integer in the flattened order. Every integer is pushed, exposed, and popped exactly once. Lists are only used to reveal their contents. So the iterator produces exactly the correct flattened sequence.
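One way to gain extra confidence in this argument is a randomized check against the recursive DFS from the pre-flattening approach. The sketch below is an assumption-laden test harness, not part of the solution: it uses plain Python lists, the helper names are hypothetical, and `stack_flatten` merges hasNext() and next() into a single pop-then-check loop, which visits items in the same order.

```python
import random


def recursive_flatten(items):
    """Reference answer: plain recursive DFS, like the pre-flattening approach."""
    out = []
    for item in items:
        if isinstance(item, int):
            out.append(item)
        else:
            out.extend(recursive_flatten(item))
    return out


def stack_flatten(items):
    """Reverse-push stack; pop-then-check merges hasNext() and next()."""
    stack = items[::-1]
    out = []
    while stack:
        top = stack.pop()
        if isinstance(top, int):
            out.append(top)
        else:
            stack.extend(reversed(top))  # leftmost child ends up on top
    return out


def random_nested(rng, depth):
    """Random nested list of small ints, including empty sublists."""
    result = []
    for _ in range(rng.randint(0, 4)):
        if depth > 0 and rng.random() < 0.4:
            result.append(random_nested(rng, depth - 1))
        else:
            result.append(rng.randint(-9, 9))
    return result


rng = random.Random(341)  # fixed seed so the check is reproducible
for _ in range(1000):
    case = random_nested(rng, depth=4)
    assert stack_flatten(case) == recursive_flatten(case), case
print("stack order matches recursive DFS")
```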

Complexity

Let N be the total number of NestedInteger objects, including integers and lists.

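| Operation | Time | Why |
| --- | --- | --- |
| Constructor | O(L) | Copies the top-level list, of length L, in reverse |
| hasNext() | Amortized O(1) | Across all calls, each list is expanded at most once, so total work is O(N); an integer already on top is checked in O(1) |
| next() | Amortized O(1) | Calls hasNext(), then pops one prepared integer |
| Full iteration | O(N) | Every nested object is pushed and popped exactly once |

| Space | Value | Why |
| --- | --- | --- |
| Extra space | O(N) worst case | The stack holds the pending siblings of every list on the current path, for example all L top-level items right after construction |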

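Compared with pre-flattening, the lazy stack approach does not build a separate flat array of every integer up front, although in the worst case (for example, a long flat top-level list) the stack itself still holds O(N) objects.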
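The O(N) total and the worst-case stack size can be checked by instrumenting the loop. This is a sketch on plain Python lists (standing in for NestedInteger), with hypothetical helper names; it counts every push and tracks the largest stack size seen.

```python
def count_objects(items):
    """N: number of integers and lists at every level (excluding the outer list)."""
    return sum(1 + (count_objects(x) if isinstance(x, list) else 0) for x in items)


def instrumented_run(nested):
    """Run the reverse-push stack to completion, counting pushes and max stack size."""
    stack = nested[::-1]
    pushes = len(stack)    # constructor pushes the L top-level items
    max_size = len(stack)
    output = []
    while stack:
        top = stack.pop()
        if isinstance(top, int):
            output.append(top)
        else:
            children = list(reversed(top))
            stack.extend(children)  # each child is pushed exactly once
            pushes += len(children)
            max_size = max(max_size, len(stack))
    return output, pushes, max_size


for case in ([[1, 1], 2, [1, 1]], list(range(100))):
    output, pushes, max_size = instrumented_run(case)
    print(count_objects(case), pushes, max_size)
# 7 7 4
# 100 100 100
```

Total pushes equal N in both cases, and the flat list of 100 integers shows the stack reaching O(N) without any nesting.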

Implementation

# """
# This is the interface that allows for creating nested lists.
# You should not implement it, or speculate about its implementation.
# """
# class NestedInteger:
#     def isInteger(self) -> bool:
#         ...
#
#     def getInteger(self) -> int:
#         ...
#
#     def getList(self) -> list["NestedInteger"]:
#         ...

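class NestedIterator:
    def __init__(self, nestedList: list[NestedInteger]):
        # Reverse so the leftmost element ends up on top of the stack.
        self.stack = nestedList[::-1]

    def next(self) -> int:
        # Make sure an integer is on top, then consume it.
        self.hasNext()
        return self.stack.pop().getInteger()

    def hasNext(self) -> bool:
        # Expand lists on top until an integer surfaces or nothing is left.
        while self.stack:
            top = self.stack[-1]

            if top.isInteger():
                return True

            # Replace the list with its children, leftmost child on top.
            self.stack.pop()

            for item in reversed(top.getList()):
                self.stack.append(item)

        return False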

Code Explanation

The constructor stores the top-level list in reverse order:

self.stack = nestedList[::-1]

This makes the first item appear on top of the stack.

In hasNext(), inspect the top:

top = self.stack[-1]

If it is an integer, the iterator is ready:

if top.isInteger():
    return True

If it is a list, expand it:

self.stack.pop()

for item in reversed(top.getList()):
    self.stack.append(item)

The reverse push preserves left-to-right order.

In next():

self.hasNext()
return self.stack.pop().getInteger()

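Calling hasNext() first ensures the stack top is an integer, even if the caller did not call hasNext() itself. If no integers remain, pop() raises IndexError; LeetCode's driver only calls next() after hasNext() returns True.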
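In idiomatic Python, the same lazy stack can also be written as a generator, which raises StopIteration when exhausted instead of failing on an empty pop(). This is an alternative sketch, not the LeetCode interface: it assumes plain Python lists in place of NestedInteger, and `iter_flat` is a hypothetical name.

```python
from typing import Iterator


def iter_flat(nested: list) -> Iterator[int]:
    """Lazily yield integers from plain nested lists, left to right,
    using the same reverse-push stack as NestedIterator."""
    stack = nested[::-1]
    while stack:
        top = stack.pop()
        if isinstance(top, int):
            yield top
        else:
            stack.extend(reversed(top))  # leftmost child on top


it = iter_flat([1, [4, [6]]])
print(next(it))                           # 1
print(list(it))                           # [4, 6]
print(next(iter_flat([[], [[]]]), None))  # None
```

Work happens only when the consumer asks for the next value, so the laziness of the stack approach carries over unchanged.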

Testing

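LeetCode provides the NestedInteger interface. For local tests, we can mock it. In a single file, place the mock above the NestedIterator class: the list[NestedInteger] annotation is evaluated when the class is defined (and subscripting list this way requires Python 3.9+).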

class NestedInteger:
    def __init__(self, value=None):
        self.value = value
        self.items = None if isinstance(value, int) else []

    def isInteger(self):
        return self.items is None

    def getInteger(self):
        return self.value

    def getList(self):
        return self.items

    def add(self, elem):
        if self.items is None:
            self.items = []
            self.value = None
        self.items.append(elem)

def build(value):
    if isinstance(value, int):
        return NestedInteger(value)

    node = NestedInteger()
    for item in value:
        node.add(build(item))
    return node

def collect(nested):
    iterator = NestedIterator(build(nested).getList())
    result = []

    while iterator.hasNext():
        result.append(iterator.next())

    return result

def run_tests():
    assert collect([[1, 1], 2, [1, 1]]) == [1, 1, 2, 1, 1]
    assert collect([1, [4, [6]]]) == [1, 4, 6]
    assert collect([]) == []
    assert collect([[], [1], []]) == [1]
    assert collect([[[]], [[2]], 3]) == [2, 3]
    assert collect([0, [-1, [2]]]) == [0, -1, 2]

    print("all tests passed")

run_tests()

Test meaning:

| Test | Why |
| --- | --- |
| [[1,1],2,[1,1]] | Standard nested list |
| [1,[4,[6]]] | Deep nesting |
| [] | Empty list |
| [[],[1],[]] | Empty lists around a value |
| [[[]],[[2]],3] | Multiple nested empty levels |
| [0,[-1,[2]]] | Zero and negative values |