Kernel fusion combines several small operations into one larger executable unit.
In tensor programs, a kernel is a unit of work launched on a device such as a CPU, GPU, or TPU. A simple expression may produce many kernels if each primitive operation is executed separately.
y = exp(x)
z = y + 1
w = log(z)
A naive runtime may execute three kernels:
kernel 1: y = exp(x)
kernel 2: z = y + 1
kernel 3: w = log(z)
A fused runtime can execute one kernel:
kernel: w = log(exp(x) + 1)
The arithmetic is the same. The memory behavior is very different.
Why Fusion Matters
Modern accelerators are fast at arithmetic but limited by memory movement and launch overhead.
If each operation writes a full tensor to memory and the next operation reads it back, the program performs unnecessary global memory traffic.
For the expression:
w = log(exp(x) + 1)
the unfused version writes and reads intermediate tensors:
x -> read
y -> write
y -> read
z -> write
z -> read
w -> write
The fused version can keep intermediate values in registers:
x -> read
compute exp(x), add 1, log
w -> write
This reduces memory bandwidth use, temporary allocation, and kernel launch overhead.
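As a rough model of the difference, here is a minimal NumPy sketch (illustrative only, not tied to any particular framework) contrasting the unfused version, which materializes two temporary arrays, with a fused single pass that keeps intermediates in scalar registers:

```python
import numpy as np

def unfused(x):
    # Each step materializes a full temporary array in memory.
    y = np.exp(x)      # kernel 1: write y
    z = y + 1.0        # kernel 2: read y, write z
    return np.log(z)   # kernel 3: read z, write w

def fused(x):
    # One pass over x; intermediates live in registers, not arrays.
    w = np.empty_like(x)
    for i in range(x.size):
        w.flat[i] = np.log(np.exp(x.flat[i]) + 1.0)
    return w

x = np.linspace(-2.0, 2.0, 8)
assert np.allclose(unfused(x), fused(x))
```

The explicit Python loop only models what a generated kernel does; a real compiler emits it as native device code.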
Fusion and AD
Automatic differentiation often creates many small operations.
A forward-mode transformation may turn:
z = x * y
into:
z = x * y
dz = dx * y + x * dy
That derivative expression contains several primitive operations. Without fusion, the transformed program may allocate temporary tensors for each multiply and add.
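As a sketch of what fusion buys here, the following compares an unfused tangent computation, which allocates a temporary for each primitive, with a fused single pass that computes the primal and tangent together (plain NumPy, names are illustrative):

```python
import numpy as np

def jvp_unfused(x, y, dx, dy):
    # Each primitive allocates its own temporary tensor.
    t1 = dx * y
    t2 = x * dy
    z = x * y
    dz = t1 + t2
    return z, dz

def jvp_fused(x, y, dx, dy):
    # One pass computes primal and tangent together, no temporaries.
    z = np.empty_like(x)
    dz = np.empty_like(x)
    for i in range(x.size):
        xi, yi = x.flat[i], y.flat[i]
        z.flat[i] = xi * yi
        dz.flat[i] = dx.flat[i] * yi + xi * dy.flat[i]
    return z, dz
```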
Reverse mode has the same problem. A backward rule such as:
bar_x += y * bar_z
bar_y += x * bar_z
may become multiple kernels unless fused.
AD therefore increases the importance of fusion. Differentiation expands the computation graph. Fusion contracts it back into efficient execution units.
Elementwise Fusion
The simplest case is elementwise fusion.
Operations such as:
add
mul
neg
exp
log
sin
cos
relu
sigmoid
tanh
operate independently on each element.
A chain of elementwise operations can usually be fused safely:
y = tanh(a * x + b)
Instead of producing intermediates for a * x, then a * x + b, then tanh, one kernel computes the final result element by element.
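For instance, in JAX, wrapping the chain in jax.jit hands it to XLA, which typically fuses a purely elementwise chain like this into a single generated kernel (a sketch; the exact fusion decisions depend on the backend):

```python
import jax
import jax.numpy as jnp

@jax.jit
def affine_tanh(a, x, b):
    # a * x, + b, and tanh are all elementwise; XLA commonly
    # fuses this chain into one generated kernel.
    return jnp.tanh(a * x + b)

x = jnp.linspace(-1.0, 1.0, 1024)
y = affine_tanh(2.0, x, 0.5)
```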
Elementwise fusion is common because it has simple semantics and predictable memory access.
Producer-Consumer Fusion
Fusion is often expressed as a producer-consumer relation.
If operation A produces a value used only by operation B, then A may be fused into B.
A -> B
Example:
t = x * y
z = t + b
If t has no other users, the compiler can fuse:
z = x * y + b
If t has many users, fusion becomes less obvious:
t = x * y
z1 = t + b
z2 = t - c
The compiler may duplicate the computation of t, or it may materialize t once. The choice depends on cost.
Fusion and Materialization
Fusion avoids materializing intermediate tensors. Materialization means writing an intermediate value to memory as a standalone array.
Avoiding materialization is usually good, but not always.
Materialization may be useful when:
the value is reused many times
the producer is expensive
the value is needed by the backward pass
the value crosses a device boundary
the value is required for debugging or checkpointing
A fusion pass must decide which intermediates disappear and which remain as explicit buffers.
This decision interacts directly with reverse-mode memory planning.
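A toy sketch of such a decision, with an invented Value record and purely illustrative thresholds:

```python
from dataclasses import dataclass

@dataclass
class Value:
    num_users: int          # how many consumers read this value
    producer_cost: float    # estimated cost to recompute it
    needed_in_backward: bool
    crosses_device: bool

def should_materialize(v: Value, recompute_threshold: float = 100.0) -> bool:
    # Keep the value as an explicit buffer if it is widely reused,
    # expensive to recompute, needed by the backward pass,
    # or leaves the device.
    if v.crosses_device or v.needed_in_backward:
        return True
    if v.num_users > 1 and v.producer_cost > recompute_threshold:
        return True
    return False
```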
Fusion in the Forward Pass
Consider:
a = W @ x
b = a + bias
c = relu(b)
A compiler may fuse the bias add and ReLU:
c = relu((W @ x) + bias)
The matrix multiplication itself may remain a separate kernel because it has a specialized implementation.
The fused boundary becomes:
matmul kernel
fused elementwise kernel
This is common. Large structured operations such as matrix multiplication, convolution, FFT, and linear solve often remain separate. Surrounding elementwise operations may fuse around them.
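A minimal sketch of that boundary, leaving the matmul to a tuned library call and writing the elementwise epilogue as one fused pass (plain NumPy, for illustration only):

```python
import numpy as np

def dense_relu(W, x, bias):
    # Boundary 1: the matmul stays a specialized kernel (BLAS via NumPy).
    a = W @ x
    # Boundary 2: bias add and ReLU fuse into one elementwise pass.
    c = np.empty_like(a)
    for i in range(a.size):
        t = a[i] + bias[i]
        c[i] = t if t > 0.0 else 0.0
    return c

W = np.random.randn(4, 3)
x = np.random.randn(3)
bias = np.zeros(4)
print(dense_relu(W, x, bias))
```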
Fusion in the Backward Pass
The backward pass has its own fusion opportunities.
For ReLU:
c = relu(b)
the backward rule is:
bar_b = bar_c * (b > 0)
If b came from bias addition:
b = a + bias
then the backward pass includes:
bar_a += bar_b
bar_bias += reduce_sum(bar_b)
A compiler may fuse the mask computation with the propagation to bar_a, while keeping the reduction for bar_bias separate.
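A sketch of that split, in which bar_b never exists as a separate tensor: the mask and the propagation to bar_a form one elementwise pass, and the bias gradient remains a separate reduction (assuming the batch dimension is axis 0):

```python
import numpy as np

def backward_bias_relu(b, bar_c):
    # Fused elementwise pass: mask and propagation to bar_a,
    # without materializing bar_b as its own tensor.
    bar_a = np.where(b > 0, bar_c, 0.0)
    # Separate reduction kernel for the bias gradient.
    bar_bias = bar_a.sum(axis=0)
    return bar_a, bar_bias
```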
This illustrates a common pattern:
elementwise backward operations fuse easily
reductions and contractions define fusion boundaries
Reductions as Fusion Boundaries
Reductions combine many input elements into fewer output elements.
Examples:
sum
mean
max
norm
softmax denominator
gradient accumulation
Reductions have different parallel structure from elementwise maps. They often require synchronization, shared memory, or tree reductions.
Fusion across reductions is possible but more constrained.
For example:
y = sum(exp(x))
The exp can often be fused into the reduction:
sum_exp_kernel:
accumulate exp(x[i])
But if the result is used later for normalization:
s = sum(exp(x))
p = exp(x) / s
a naive implementation may compute exp(x) twice or store it. A better fused softmax kernel computes the expression using a numerically stable multi-stage algorithm.
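A sketch of the standard numerically stable variant, written as the stages a fused softmax kernel typically performs (max reduction, shifted exponentiation and sum, elementwise normalization):

```python
import numpy as np

def softmax_stable(x):
    # Pass 1: reduction to find the maximum (for numerical stability).
    m = x.max()
    # Pass 2: exponentiate shifted values and reduce to the denominator.
    e = np.exp(x - m)
    s = e.sum()
    # Pass 3: elementwise normalization.
    return e / s
```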
Pattern Fusion
Some fusion is not local producer-consumer fusion. It is pattern-based.
A compiler may recognize:
exp
reduce_sum
divide
as softmax.
It may replace the subgraph with a specialized softmax kernel.
Similarly, it may recognize:
matmul
add bias
activation
as a fused dense layer.
Pattern fusion matters for AD because derivative graphs often contain recognizable patterns too:
softmax
cross_entropy
has a simpler and more stable combined derivative than differentiating the two operations independently.
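For a single example with a unit upstream gradient, the combined derivative with respect to the logits collapses to softmax(logits) minus the one-hot label, which a fused kernel can compute directly instead of backpropagating through the division and the reduction:

```python
import numpy as np

def softmax_xent_grad(logits, label):
    # Combined, numerically stable backward rule:
    # d/d logits of cross_entropy(softmax(logits), label) = p - one_hot(label)
    p = np.exp(logits - logits.max())
    p /= p.sum()
    grad = p.copy()
    grad[label] -= 1.0
    return grad
```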
Fusion and Numerical Stability
Fusion can change floating-point behavior.
The expression:
(a + b) + c
may produce a different result from:
a + (b + c)
because floating-point addition is not associative.
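A concrete instance in double precision:

```python
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0, because b + c rounds back to -1e16
```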
Fusion may also change when intermediate values are rounded, whether operations use fused multiply-add, and whether temporary values are stored in lower precision.
For most ML workloads, these differences are acceptable. For scientific computing, finance, or reproducibility-sensitive workloads, the compiler may need strict modes that limit fusion.
Fusion and Saved Values
Reverse mode needs some primal values during backward.
Suppose the forward graph contains:
a = exp(x)
b = a + 1
y = log(b)
The backward pass may need a or b.
If fusion removes a and b as materialized tensors, the system must choose one of three strategies:
save selected intermediates anyway
recompute them during backward
rewrite backward rules to use available outputs
For example, if y = log(b), then b = exp(y). A backward rule may use y to recover b, but this may be less stable or more expensive.
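A sketch of the third strategy for this chain, recovering the needed primal values from the output y instead of saving a and b (a model of what a fused backward kernel might do, not a recommendation):

```python
import numpy as np

def backward_from_output(y, bar_y):
    # Forward was: a = exp(x); b = a + 1; y = log(b).
    # Recover b from the output rather than saving it.
    b = np.exp(y)            # b = exp(log(b))
    bar_b = bar_y / b        # d log(b) / db = 1 / b
    bar_a = bar_b            # d(a + 1) / da = 1
    bar_x = bar_a * (b - 1)  # a = exp(x) = b - 1, and d exp(x) / dx = a
    return bar_x
```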
Fusion therefore cannot be designed independently from AD. It must know which values are needed later.
Fusion and Checkpointing
Fusion changes checkpoint boundaries.
If a long chain is fused into one kernel, the compiler may save only the input and output of the fused region. During backward, it may recompute internal intermediates inside the fused backward kernel.
This can reduce memory. It can also increase compute.
For large models, the planner may choose:
fuse aggressively inside checkpoint segments
materialize values at checkpoint boundaries
recompute fused internals during backward
This creates a three-way tradeoff among memory, arithmetic, and compile complexity.
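JAX exposes one version of this choice through jax.checkpoint (also called jax.remat): intermediates inside the wrapped function are not saved for the backward pass and are recomputed instead. A minimal sketch:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint
def block(x):
    # Internals of this block are recomputed during backward
    # instead of being materialized in the forward pass.
    return jnp.log(jnp.exp(x) + 1.0)

grad_fn = jax.grad(lambda x: block(x).sum())
g = grad_fn(jnp.linspace(-1.0, 1.0, 8))
```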
Fusion and Custom Kernels
A fused region often becomes a generated custom kernel.
The compiler emits code such as:
for each element i:
tmp1 = x[i] * scale
tmp2 = tmp1 + bias[i]
out[i] = max(tmp2, 0)
For GPUs, this becomes a CUDA, HIP, Triton, MLIR, XLA, or vendor-specific kernel.
The compiler must choose:
thread layout
vectorization strategy
memory coalescing
shared memory use
register allocation
tiling
unrolling
Fusion reduces memory traffic, but large fused kernels may use too many registers or reduce occupancy. Bigger kernels are not always faster.
When Fusion Hurts
Fusion can hurt performance when it creates a kernel that is too large or poorly balanced.
Possible problems include:
| Problem | Effect |
|---|---|
| Register pressure | Fewer active threads |
| Code size growth | Longer compile time |
| Duplicated computation | More arithmetic |
| Reduced library use | Loses tuned vendor kernels |
| Poor cache locality | Worse memory behavior |
| Harder scheduling | Less overlap between operations |
| Numerical changes | Different floating-point results |
A practical compiler uses cost models and heuristics rather than fusing everything.
Fusion Boundaries
Common fusion boundaries include:
matrix multiplication
convolution
large reductions
sort
top-k
gather and scatter
random sampling
I/O
communication
device transfer
opaque custom calls
effectful operations
These operations may still participate in specialized fusion, but generic fusion usually stops at them.
For example, an all-reduce in distributed training is a communication boundary. The compiler may overlap computation with communication, but it cannot simply merge the communication into an arbitrary elementwise kernel.
Fusion in Graph IRs
Graph IRs make fusion natural.
The compiler can inspect a subgraph:
x -> mul -> add -> relu -> y
and replace it with a fused node:
x -> fused_mul_add_relu -> y
The fused node may later lower to a generated kernel.
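A toy sketch of such a rewrite over a flat node list (the Node type and the pass are invented for illustration; a real pass would work on a proper graph and check that each intermediate has no other users):

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    op: str
    inputs: list = field(default_factory=list)

def fuse_mul_add_relu(nodes):
    # Replace a mul -> add -> relu producer-consumer chain
    # with a single fused node.
    fused = []
    i = 0
    while i < len(nodes):
        if (i + 2 < len(nodes)
                and nodes[i].op == "mul"
                and nodes[i + 1].op == "add"
                and nodes[i + 2].op == "relu"
                and nodes[i + 1].inputs[0] is nodes[i]
                and nodes[i + 2].inputs[0] is nodes[i + 1]):
            fused.append(Node("fused_mul_add_relu",
                              nodes[i].inputs + nodes[i + 1].inputs[1:]))
            i += 3
        else:
            fused.append(nodes[i])
            i += 1
    return fused
```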
After AD transformation, the graph may contain both primal and derivative nodes. Fusion may operate separately on forward and backward graphs, or jointly on the combined training graph.
Fusion in SSA IRs
SSA IRs support fusion through loop transformations.
Example:
for i:
t[i] = x[i] * scale
for i:
y[i] = t[i] + bias[i]
Loop fusion combines the loops:
for i:
t = x[i] * scale
y[i] = t + bias[i]
The temporary array t disappears.
At lower levels, kernel fusion is often loop fusion plus memory planning plus code generation.
AD-Aware Fusion
An AD-aware fusion pass knows about derivative structure.
It can fuse:
primal operation
saved residual computation
local backward rule
adjoint accumulation
when profitable.
For example, a fused activation kernel may produce both:
output activation
compact mask for backward
Instead of saving the full input tensor, it may save a bitmask.
For ReLU, the forward pass can store one bit per element indicating whether the input was positive. The backward pass uses that mask. This trades a full activation tensor for compact metadata.
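A NumPy sketch of that trade, packing the mask into one bit per element (assuming 1-D tensors for simplicity):

```python
import numpy as np

def relu_forward(x):
    mask = x > 0
    out = np.where(mask, x, 0.0)
    # Save 1 bit per element instead of the full input tensor.
    packed = np.packbits(mask)
    return out, packed

def relu_backward(bar_out, packed, n):
    # Unpack the bitmask and apply it to the incoming gradient.
    mask = np.unpackbits(packed, count=n).astype(bool)
    return np.where(mask, bar_out, 0.0)
```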
Fusion and Layout
Fusion also interacts with memory layout.
If one operation prefers row-major layout and another prefers tiled layout, fusing them may require a layout decision.
A compiler may propagate layout through a fused region to avoid transposes.
Example:
transpose
matmul
add
relu
The compiler may fold the transpose into the matmul layout choice instead of materializing a transposed tensor.
Layout-aware fusion can have large performance effects.
Fusion and Batching
Batching transformations, such as vectorization over examples, can expose new fusion opportunities.
A scalar expression applied over a batch becomes an elementwise tensor expression. This can be fused into a single batched kernel.
Likewise, AD plus batching may produce large repeated derivative computations. Fusion helps collapse those repetitions into efficient kernels.
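For example, in JAX, jax.vmap turns a scalar expression into a batched elementwise expression, and jax.jit can then compile and fuse the batched derivative into a small number of kernels:

```python
import jax
import jax.numpy as jnp

def scalar_loss(w, x):
    return jnp.tanh(w * x) ** 2

# Batch over examples, differentiate, and compile: the expanded
# elementwise derivative graph is a natural target for fusion.
batched_grad = jax.jit(jax.vmap(jax.grad(scalar_loss), in_axes=(None, 0)))
g = batched_grad(0.3, jnp.linspace(-1.0, 1.0, 16))
```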
The ordering of transformations matters:
fuse after batching
fuse after AD
fuse after layout selection
Each order exposes different opportunities.
Fusion and Compilation Time
Fusion increases compiler work.
The compiler must find candidate regions, check legality, estimate cost, generate code, and optimize the generated kernel.
For interactive workloads, compilation latency matters. A system may choose conservative fusion to keep compile time low.
For long-running training jobs or simulations, aggressive fusion may pay off.
This creates another staging decision:
fast compile, acceptable runtime
or
slow compile, better runtime
Design Questions
A fusion system for AD should answer:
Which operations are fusible?
Which values must remain materialized for backward?
Can recomputation replace saved intermediates?
How are reductions handled?
How are effectful operations handled?
How does fusion affect numerical reproducibility?
What cost model prevents harmful fusion?
How does fusion interact with layout and device placement?
How are fused derivative kernels generated?
These questions belong in the compiler design, not in user code.
Summary
Kernel fusion combines multiple operations into larger execution units. In AD systems, it is especially important because differentiation expands programs with many additional primitive operations.
Fusion reduces memory traffic, temporary allocation, and kernel launch overhead. It enables derivative code to run closer to the cost of the original primal computation.
The main challenge is legality and profitability. Fusion must respect control flow, effects, saved values, layout, numerical behavior, and backend constraints. A good AD compiler fuses aggressively where memory traffic dominates, but preserves boundaries where materialization, stability, or specialized kernels are better.