# Kernel Fusion

Kernel fusion combines several small operations into one larger executable unit.

In tensor programs, a kernel is a unit of work launched on a device such as a CPU, GPU, or TPU. A simple expression may produce many kernels if each primitive operation is executed separately.

```text
y = exp(x)
z = y + 1
w = log(z)
```

A naive runtime may execute three kernels:

```text
kernel 1: y = exp(x)
kernel 2: z = y + 1
kernel 3: w = log(z)
```

A fused runtime can execute one kernel:

```text
kernel: w = log(exp(x) + 1)
```

The arithmetic is essentially the same. The memory behavior is very different.

### Why Fusion Matters

Modern accelerators are fast at arithmetic but limited by memory movement and launch overhead.

If each operation writes a full tensor to memory and the next operation reads it back, the program performs unnecessary global memory traffic.

For the expression:

```text
w = log(exp(x) + 1)
```

the unfused version writes and reads intermediate tensors:

```text
x -> read
y -> write
y -> read
z -> write
z -> read
w -> write
```

The fused version can keep intermediate values in registers:

```text
x -> read
compute exp(x), add 1, log
w -> write
```

This reduces memory bandwidth use, temporary allocation, and kernel launch overhead.
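The contrast can be sketched in plain Python (a minimal illustration; a real runtime generates device kernels, not Python loops):

```python
import numpy as np

def unfused(x):
    # Three separate "kernels": each step writes a full intermediate array.
    y = np.exp(x)
    z = y + 1.0
    return np.log(z)

def fused(x):
    # One pass: intermediates stay in (the moral equivalent of) registers,
    # and only the final result is written out.
    out = np.empty_like(x)
    for i in range(x.size):
        out.flat[i] = np.log(np.exp(x.flat[i]) + 1.0)
    return out

x = np.array([0.0, 1.0, -1.0])
assert np.allclose(unfused(x), fused(x))
```

Both versions do the same arithmetic per element; the fused one touches global memory once on the way in and once on the way out.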

### Fusion and AD

Automatic differentiation often creates many small operations.

A forward-mode transformation may turn:

```text
z = x * y
```

into:

```text
z  = x * y
dz = dx * y + x * dy
```

That derivative expression contains several primitive operations. Without fusion, the transformed program may allocate temporary tensors for each multiply and add.

Reverse mode has the same problem. A backward rule such as:

```text
bar_x += y * bar_z
bar_y += x * bar_z
```

may become multiple kernels unless fused.

AD therefore increases the importance of fusion. Differentiation expands the computation graph. Fusion contracts it back into efficient execution units.
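A minimal forward-mode dual-number class (a hypothetical sketch, not any particular library's API) makes the expansion concrete: one user-visible multiply becomes three multiplies and an add.

```python
class Dual:
    """Value paired with its directional derivative (tangent)."""
    def __init__(self, val, dot):
        self.val, self.dot = val, dot

    def __mul__(self, other):
        # z = x * y expands into: z = x*y, dz = dx*y + x*dy
        # (three primitive multiplies and one add per user multiply).
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

x = Dual(3.0, 1.0)  # seed dx = 1
y = Dual(4.0, 0.0)  # seed dy = 0
z = x * y
assert z.val == 12.0 and z.dot == 4.0  # dz/dx at (3, 4) is y = 4
```

On tensors, each of those primitives would be a separate kernel unless the compiler fuses them.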

### Elementwise Fusion

The simplest case is elementwise fusion.

Operations such as:

```text
add
mul
neg
exp
log
sin
cos
relu
sigmoid
tanh
```

operate independently on each element.

A chain of elementwise operations can usually be fused safely:

```text
y = tanh(a * x + b)
```

Instead of producing intermediates for `a * x` and then `a * x + b`, one kernel computes the final result element by element.

Elementwise fusion is common because it has simple semantics and predictable memory access.
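A hand-written version of the fused kernel might look like this minimal sketch, with an explicit loop standing in for generated device code:

```python
import numpy as np

def fused_affine_tanh(a, x, b):
    # One elementwise pass: no materialized a*x or a*x + b arrays.
    out = np.empty_like(x)
    for i in range(x.size):
        out.flat[i] = np.tanh(a * x.flat[i] + b)
    return out

x = np.linspace(-1.0, 1.0, 5)
assert np.allclose(fused_affine_tanh(2.0, x, 0.5), np.tanh(2.0 * x + 0.5))
```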

### Producer-Consumer Fusion

Fusion is often expressed as a producer-consumer relation.

If operation `A` produces a value used only by operation `B`, then `A` may be fused into `B`.

```text
A -> B
```

Example:

```text
t = x * y
z = t + b
```

If `t` has no other users, the compiler can fuse:

```text
z = x * y + b
```

If `t` has many users, fusion becomes less obvious:

```text
t = x * y
z1 = t + b
z2 = t - c
```

The compiler may duplicate the computation of `t`, or it may materialize `t` once. The choice depends on cost.
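A fusion pass might gate producer-consumer fusion on a use count. This toy legality check (all names hypothetical) sketches the idea:

```python
def can_fuse(producer, use_counts):
    # Fuse a producer into its consumer only when the consumer is the
    # producer's sole user; otherwise duplication or materialization
    # must be weighed by a cost model.
    return use_counts[producer] == 1

# t = x * y with a single consumer: safe to fuse.
assert can_fuse("t", {"t": 1})
# t feeds both z1 and z2: fusing into each would duplicate the multiply.
assert not can_fuse("t", {"t": 2})
```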

### Fusion and Materialization

Fusion avoids materializing intermediate tensors. Materialization means writing an intermediate value to memory as a standalone array.

Avoiding materialization is usually good, but not always.

Materialization may be useful when:

```text
the value is reused many times
the producer is expensive
the value is needed by the backward pass
the value crosses a device boundary
the value is required for debugging or checkpointing
```

A fusion pass must decide which intermediates disappear and which remain as explicit buffers.

This decision interacts directly with reverse-mode memory planning.

### Fusion in the Forward Pass

Consider:

```text
a = W @ x
b = a + bias
c = relu(b)
```

A compiler may fuse the bias add and ReLU:

```text
c = relu((W @ x) + bias)
```

The matrix multiplication itself may remain a separate kernel because it has a specialized implementation.

The fused boundary becomes:

```text
matmul kernel
fused elementwise kernel
```

This is common. Large structured operations such as matrix multiplication, convolution, FFT, and linear solve often remain separate. Surrounding elementwise operations may fuse around them.
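The resulting split can be sketched as follows, with the matmul left to a tuned library call and the elementwise epilogue fused (a minimal illustration, not generated code):

```python
import numpy as np

def dense_layer(W, x, bias):
    a = W @ x                          # structured kernel: stays separate
    return np.maximum(a + bias, 0.0)   # fused epilogue: bias add + ReLU

W = np.array([[1.0, -1.0],
              [0.5,  0.5]])
x = np.array([2.0, 1.0])
bias = np.array([-2.0, 0.0])
out = dense_layer(W, x, bias)
assert np.allclose(out, [0.0, 1.5])
```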

### Fusion in the Backward Pass

The backward pass has its own fusion opportunities.

For ReLU:

```text
c = relu(b)
```

the backward rule is:

```text
bar_b = bar_c * (b > 0)
```

If `b` came from bias addition:

```text
b = a + bias
```

then the backward pass includes:

```text
bar_a += bar_b
bar_bias += reduce_sum(bar_b)
```

A compiler may fuse the mask computation with the propagation to `bar_a`, while keeping the reduction for `bar_bias` separate.

This illustrates a common pattern:

```text
elementwise backward operations fuse easily
reductions and contractions define fusion boundaries
```
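The pattern above can be sketched directly, with `bar_` names following the text's adjoint convention (a minimal illustration):

```python
import numpy as np

def fused_relu_bias_backward(bar_c, b):
    # Elementwise part fuses into one pass: ReLU mask and propagation.
    bar_b = bar_c * (b > 0)
    bar_a = bar_b
    # Reduction part stays a separate kernel: sum over the batch axis.
    bar_bias = bar_b.sum(axis=0)
    return bar_a, bar_bias

b = np.array([[1.0, -1.0],
              [2.0,  3.0]])
bar_c = np.ones_like(b)
bar_a, bar_bias = fused_relu_bias_backward(bar_c, b)
assert np.allclose(bar_bias, [2.0, 1.0])  # gradient only where b > 0
```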

### Reductions as Fusion Boundaries

Reductions combine many input elements into fewer output elements.

Examples:

```text
sum
mean
max
norm
softmax denominator
gradient accumulation
```

Reductions have different parallel structure from elementwise maps. They often require synchronization, shared memory, or tree reductions.

Fusion across reductions is possible but more constrained.

For example:

```text
y = sum(exp(x))
```

The `exp` can often be fused into the reduction:

```text
sum_exp_kernel:
    accumulate exp(x[i])
```

But if the result is used later for normalization:

```text
s = sum(exp(x))
p = exp(x) / s
```

a naive implementation may compute `exp(x)` twice or store it. A better fused softmax kernel computes the expression using a numerically stable multi-stage algorithm.
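One common stable formulation subtracts the maximum before exponentiating; this sketch shows the multi-stage pattern, not an optimized kernel:

```python
import numpy as np

def stable_softmax(x):
    m = x.max()           # stage 1: reduction, for numerical stability
    e = np.exp(x - m)     # stage 2: shifted exponent, computed once
    return e / e.sum()    # stage 3: reduction plus normalization

# Large inputs would overflow naive exp(x); the shifted form is fine.
p = stable_softmax(np.array([1000.0, 1000.0]))
assert np.allclose(p, [0.5, 0.5])
```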

### Pattern Fusion

Some fusion is not local producer-consumer fusion. It is pattern-based.

A compiler may recognize:

```text
exp
reduce_sum
divide
```

as softmax.

It may replace the subgraph with a specialized softmax kernel.

Similarly, it may recognize:

```text
matmul
add bias
activation
```

as a fused dense layer.

Pattern fusion matters for AD because derivative graphs often contain recognizable patterns too:

```text
softmax
cross_entropy
```

The combined pair has a simpler and more stable derivative than the two operations differentiated independently.
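For example, the well-known combined gradient of cross-entropy applied to softmax is simply `p - y`, which is both cheaper and more stable than composing the two derivatives (minimal sketch):

```python
import numpy as np

def softmax_xent_grad(logits, onehot):
    # Combined derivative of cross_entropy(softmax(logits)) w.r.t. logits.
    e = np.exp(logits - logits.max())
    p = e / e.sum()
    return p - onehot

g = softmax_xent_grad(np.array([0.0, 0.0]), np.array([1.0, 0.0]))
assert np.allclose(g, [-0.5, 0.5])
```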

### Fusion and Numerical Stability

Fusion can change floating-point behavior.

The expression:

```text
(a + b) + c
```

may produce a different result from:

```text
a + (b + c)
```

because floating-point addition is not associative.

Fusion may also change when intermediate values are rounded, whether operations use fused multiply-add, and whether temporary values are stored in lower precision.

For most ML workloads, these differences are acceptable. For scientific computing, finance, or reproducibility-sensitive workloads, the compiler may need strict modes that limit fusion.
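The non-associativity is easy to demonstrate with standard IEEE double precision:

```python
# Reordering a sum, as fusion may do, changes the rounded result.
a, b, c = 1e16, -1e16, 1.0
assert (a + b) + c == 1.0   # exact cancellation first, then add 1
assert a + (b + c) == 0.0   # the 1.0 is absorbed: b + c rounds to -1e16
```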

### Fusion and Saved Values

Reverse mode needs some primal values during backward.

Suppose the forward graph contains:

```text
a = exp(x)
b = a + 1
y = log(b)
```

The backward pass may need `a` or `b`.

If fusion removes `a` and `b` as materialized tensors, the system must choose one of three strategies:

```text
save selected intermediates anyway
recompute them during backward
rewrite backward rules to use available outputs
```

For example, if `y = log(b)`, then `b = exp(y)`. A backward rule may use `y` to recover `b`, but this may be less stable or more expensive.

Fusion therefore cannot be designed independently from AD. It must know which values are needed later.
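The third strategy can be sketched for `y = log(b)`: since `dy/db = 1/b` and `b = exp(y)`, a backward rule can be rewritten to use only the output (the helper name here is hypothetical):

```python
import numpy as np

def log_backward_from_output(bar_y, y):
    # d/db log(b) = 1/b, and b = exp(y), so 1/b = exp(-y):
    # the saved input b is not needed at all.
    return bar_y * np.exp(-y)

b = np.array([2.0, 3.0])
y = np.log(b)
assert np.allclose(log_backward_from_output(1.0, y), 1.0 / b)
```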

### Fusion and Checkpointing

Fusion changes checkpoint boundaries.

If a long chain is fused into one kernel, the compiler may save only the input and output of the fused region. During backward, it may recompute internal intermediates inside the fused backward kernel.

This can reduce memory. It can also increase compute.

For large models, the planner may choose:

```text
fuse aggressively inside checkpoint segments
materialize values at checkpoint boundaries
recompute fused internals during backward
```

This creates a three-way tradeoff among memory, arithmetic, and compile complexity.

### Fusion and Custom Kernels

A fused region often becomes a generated custom kernel.

The compiler emits code such as:

```text
for each element i:
    tmp1 = x[i] * scale
    tmp2 = tmp1 + bias[i]
    out[i] = max(tmp2, 0)
```

For GPUs, this becomes a CUDA, HIP, Triton, MLIR, XLA, or vendor-specific kernel.

The compiler must choose:

```text
thread layout
vectorization strategy
memory coalescing
shared memory use
register allocation
tiling
unrolling
```

Fusion improves memory traffic, but large fused kernels may use too many registers or reduce occupancy. Bigger kernels are not always faster.

### When Fusion Hurts

Fusion can hurt performance when it creates a kernel that is too large or poorly balanced.

Possible problems include:

| Problem | Effect |
|---|---|
| Register pressure | Fewer active threads |
| Code size growth | Longer compile time |
| Duplicated computation | More arithmetic |
| Reduced library use | Loses tuned vendor kernels |
| Poor cache locality | Worse memory behavior |
| Harder scheduling | Less overlap between operations |
| Numerical changes | Different floating-point results |

A practical compiler uses cost models and heuristics rather than fusing everything.

### Fusion Boundaries

Common fusion boundaries include:

```text
matrix multiplication
convolution
large reductions
sort
top-k
gather and scatter
random sampling
I/O
communication
device transfer
opaque custom calls
effectful operations
```

These operations may still participate in specialized fusion, but generic fusion usually stops at them.

For example, an all-reduce in distributed training is a communication boundary. The compiler may overlap computation with communication, but it cannot simply merge the communication into an arbitrary elementwise kernel.

### Fusion in Graph IRs

Graph IRs make fusion natural.

The compiler can inspect a subgraph:

```text
x -> mul -> add -> relu -> y
```

and replace it with a fused node:

```text
x -> fused_mul_add_relu -> y
```

The fused node may later lower to a generated kernel.

After AD transformation, the graph may contain both primal and derivative nodes. Fusion may operate separately on forward and backward graphs, or jointly on the combined training graph.
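A toy pattern replacement on a linear chain (a hypothetical IR representation) sketches the rewrite:

```python
# Replace a run of fusible elementwise ops with a single fused node.
FUSIBLE = {"mul", "add", "relu", "exp", "log", "tanh"}

def fuse_chain(ops):
    if len(ops) > 1 and all(op in FUSIBLE for op in ops):
        return ["fused_" + "_".join(ops)]
    return ops

assert fuse_chain(["mul", "add", "relu"]) == ["fused_mul_add_relu"]
# A boundary op such as sort stops generic fusion.
assert fuse_chain(["mul", "sort"]) == ["mul", "sort"]
```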

### Fusion in SSA IRs

SSA IRs support fusion through loop transformations.

Example:

```text
for i:
    t[i] = x[i] * scale

for i:
    y[i] = t[i] + bias[i]
```

Loop fusion combines the loops:

```text
for i:
    t = x[i] * scale
    y[i] = t + bias[i]
```

The temporary array `t` disappears.

At lower levels, kernel fusion is often loop fusion plus memory planning plus code generation.

### AD-Aware Fusion

An AD-aware fusion pass knows about derivative structure.

It can fuse:

```text
primal operation
saved residual computation
local backward rule
adjoint accumulation
```

when profitable.

For example, a fused activation kernel may produce both:

```text
output activation
compact mask for backward
```

Instead of saving the full input tensor, it may save a bitmask.

For ReLU, the forward pass can store one bit per element indicating whether the input was positive. The backward pass uses that mask. This trades a full activation tensor for compact metadata.
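A sketch using NumPy's bit packing illustrates the memory trade (one bit per element instead of a full tensor); it is an illustration, not a real kernel:

```python
import numpy as np

def relu_forward(x):
    mask = x > 0
    # Save a packed bitmask (1 bit/element) instead of the full input.
    return np.maximum(x, 0.0), np.packbits(mask)

def relu_backward(bar_y, packed, n):
    mask = np.unpackbits(packed, count=n).astype(bool)
    return np.where(mask, bar_y, 0.0)

x = np.array([1.0, -2.0, 3.0])
y, packed = relu_forward(x)
assert np.allclose(relu_backward(np.ones(3), packed, 3), [1.0, 0.0, 1.0])
```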

### Fusion and Layout

Fusion also interacts with memory layout.

If one operation prefers row-major layout and another prefers tiled layout, fusing them may require a layout decision.

A compiler may propagate layout through a fused region to avoid transposes.

Example:

```text
transpose
matmul
add
relu
```

The compiler may fold the transpose into the matmul layout choice instead of materializing a transposed tensor.

Layout-aware fusion can have large performance effects.

### Fusion and Batching

Batching transformations, such as vectorization over examples, can expose new fusion opportunities.

A scalar expression applied over a batch becomes an elementwise tensor expression. This can be fused into a single batched kernel.

Likewise, AD plus batching may produce large repeated derivative computations. Fusion helps collapse those repetitions into efficient kernels.

The ordering of transformations matters:

```text
fuse after batching
fuse after AD
fuse after layout selection
```

Each order exposes different opportunities.

### Fusion and Compilation Time

Fusion increases compiler work.

The compiler must find candidate regions, check legality, estimate cost, generate code, and optimize the generated kernel.

For interactive workloads, compilation latency matters. A system may choose conservative fusion to keep compile time low.

For long-running training jobs or simulations, aggressive fusion may pay off.

This creates another staging decision:

```text
fast compile, acceptable runtime
or
slow compile, better runtime
```

### Design Questions

A fusion system for AD should answer:

```text
Which operations are fusible?
Which values must remain materialized for backward?
Can recomputation replace saved intermediates?
How are reductions handled?
How are effectful operations handled?
How does fusion affect numerical reproducibility?
What cost model prevents harmful fusion?
How does fusion interact with layout and device placement?
How are fused derivative kernels generated?
```

These questions belong in the compiler design, not in user code.

### Summary

Kernel fusion combines multiple operations into larger execution units. In AD systems, it is especially important because differentiation expands programs with many additional primitive operations.

Fusion reduces memory traffic, temporary allocation, and kernel launch overhead. It enables derivative code to run closer to the cost of the original primal computation.

The main challenge is legality and profitability. Fusion must respect control flow, effects, saved values, layout, numerical behavior, and backend constraints. A good AD compiler fuses aggressively where memory traffic dominates, but preserves boundaries where materialization, stability, or specialized kernels are better.

