
Tensor Operations

Tensor operations generalize scalar, vector, and matrix operations to arrays with arbitrary rank. In automatic differentiation, a tensor is usually treated as a typed array with shape, dtype, layout, and device. The mathematical object is an element of a finite-dimensional vector space. The systems object is a strided memory view with rules for indexing and mutation.

A tensor $X$ with shape

$$(d_1, d_2, \ldots, d_k)$$

has entries

$$X_{i_1,i_2,\ldots,i_k}.$$

The number $k$ is the rank of the tensor. The total number of elements is

$$|X| = \prod_{r=1}^{k} d_r.$$

A tensor operation is a function between shaped spaces:

$$f : \mathbb{R}^{d_1 \times \cdots \times d_k} \to \mathbb{R}^{e_1 \times \cdots \times e_l}.$$

Its derivative is still a linear map. As with matrices, AD systems rarely form the full derivative tensor. They compute how perturbations move forward and how adjoints move backward.

Elementwise Operations

An elementwise operation applies the same scalar function to every entry:

$$Y = \phi(X), \qquad Y_i = \phi(X_i).$$

The forward differential is

$$dY_i = \phi'(X_i)\,dX_i.$$

So

$$dY = \phi'(X) \odot dX.$$

Here $\odot$ denotes elementwise multiplication.

The reverse rule is

$$\bar{X} \mathrel{+}= \bar{Y} \odot \phi'(X).$$

Examples:

| Operation | Forward Rule | Reverse Rule |
| --- | --- | --- |
| $Y = \sin X$ | $dY = \cos(X) \odot dX$ | $\bar{X} \mathrel{+}= \bar{Y} \odot \cos(X)$ |
| $Y = \exp X$ | $dY = \exp(X) \odot dX$ | $\bar{X} \mathrel{+}= \bar{Y} \odot \exp(X)$ |
| $Y = X^2$ | $dY = 2X \odot dX$ | $\bar{X} \mathrel{+}= 2X \odot \bar{Y}$ |
| $Y = \log X$ | $dY = dX \oslash X$ | $\bar{X} \mathrel{+}= \bar{Y} \oslash X$ |

Elementwise rules are local and shape-preserving. They are among the simplest AD primitives.
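As a minimal sketch (using NumPy for illustration; the function names are hypothetical), an elementwise primitive and its reverse rule can be written directly from the formulas above and checked against a finite difference:

```python
import numpy as np

def sin_forward(X):
    return np.sin(X)

def sin_vjp(X, Y_bar):
    # X_bar += Y_bar * cos(X): local and shape-preserving
    return Y_bar * np.cos(X)

X = np.array([[0.1, 0.5], [1.0, 2.0]])
Y_bar = np.ones_like(X)
X_bar = sin_vjp(X, Y_bar)

# Finite-difference check of one entry
eps = 1e-6
fd = (np.sin(X[0, 0] + eps) - np.sin(X[0, 0] - eps)) / (2 * eps)
```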

Reductions

A reduction removes one or more axes.

For example, if

$$Y = \operatorname{sum}(X, \text{axis}=r),$$

then $Y$ has the same shape as $X$, except axis $r$ is removed or kept with size $1$, depending on the API.

For a full sum,

$$y = \sum_i X_i.$$

The forward differential is

$$dy = \sum_i dX_i.$$

The reverse rule broadcasts the scalar adjoint back to every input element:

$$\bar{X}_i \mathrel{+}= \bar{y}.$$

So reduction and broadcasting are adjoints of each other.

For a mean reduction over nn elements,

$$y = \frac{1}{n}\sum_i X_i.$$

The reverse rule is

$$\bar{X}_i \mathrel{+}= \frac{\bar{y}}{n}.$$

For a maximum reduction, the rule is piecewise:

$$y = \max_i X_i.$$

If a unique index $j$ attains the maximum, then

$$\bar{X}_j \mathrel{+}= \bar{y}, \qquad \bar{X}_i \mathrel{+}= 0 \quad \text{for } i \ne j.$$

If several entries tie for the maximum, the derivative is not unique. Libraries choose a subgradient convention, often sending the adjoint to one selected maximum or splitting it across all maxima.
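The three reduction rules above can be sketched as follows (NumPy for illustration; function names are hypothetical, and the max rule uses the split-across-ties subgradient convention):

```python
import numpy as np

def sum_vjp(X, y_bar):
    return np.full_like(X, y_bar)           # broadcast the scalar adjoint

def mean_vjp(X, y_bar):
    return np.full_like(X, y_bar / X.size)  # divide by element count

def max_vjp(X, y_bar):
    # Subgradient convention: split the adjoint across all tied maxima.
    mask = (X == X.max()).astype(X.dtype)
    return y_bar * mask / mask.sum()

X = np.array([1.0, 3.0, 3.0, 2.0])
```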

Reshape and View Operations

Reshape changes how the same elements are indexed. It does not change values.

$$Y = \operatorname{reshape}(X).$$

The forward rule is

$$dY = \operatorname{reshape}(dX).$$

The reverse rule is

$$\bar{X} \mathrel{+}= \operatorname{reshape}^{-1}(\bar{Y}).$$

Transpose and permutation are similar. If

$$Y = \operatorname{permute}(X, \pi),$$

then

$$dY = \operatorname{permute}(dX, \pi),$$

and the reverse rule uses the inverse permutation:

$$\bar{X} \mathrel{+}= \operatorname{permute}(\bar{Y}, \pi^{-1}).$$

These operations are mathematically cheap. In a runtime, they may create views with new strides rather than copy data. AD must respect aliasing: two different views may refer to the same storage.
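As a sketch (NumPy for illustration; `permute_vjp` is a hypothetical helper), the inverse permutation can be obtained with `np.argsort` on the axes tuple, and the reshape adjoint is just a reshape back to the input shape:

```python
import numpy as np

def permute_vjp(Y_bar, axes):
    inverse = np.argsort(axes)              # inverse permutation
    return np.transpose(Y_bar, inverse)

X = np.arange(24).reshape(2, 3, 4).astype(float)
axes = (2, 0, 1)
Y = np.transpose(X, axes)                   # forward permute
X_bar = permute_vjp(np.ones_like(Y), axes)  # adjoint has X's shape

# Reshape adjoint: restore the original shape
Z = X.reshape(6, 4)
X_bar2 = np.ones_like(Z).reshape(X.shape)
```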

Slicing and Indexing

Slicing selects part of a tensor:

$$Y = X[I].$$

The forward rule is

$$dY = dX[I].$$

The reverse rule scatters the output adjoint back into the selected region:

$$\bar{X}[I] \mathrel{+}= \bar{Y}.$$

For repeated indices, adjoints must accumulate. For example,

$$Y = X[[0, 0, 2]]$$

means the first input element is used twice. The reverse pass must add both contributions to $\bar{X}_0$.

Indexing has a sharp distinction:

| Quantity | Differentiability |
| --- | --- |
| Indexed values | Differentiable |
| Integer indices | Usually not differentiable |
| Soft indices or interpolation weights | Differentiable |

Integer index selection is discrete. AD can propagate gradients through the selected values, but not through the selection decision itself without relaxation or specialized estimators.
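The gather/scatter-add pair with repeated indices can be sketched in NumPy. Note that `np.add.at` accumulates for duplicate indices, whereas a plain `X_bar[I] += Y_bar` would silently keep only one contribution:

```python
import numpy as np

X = np.array([1.0, 2.0, 3.0])
I = np.array([0, 0, 2])

Y = X[I]                        # forward: gather

Y_bar = np.array([10.0, 20.0, 30.0])
X_bar = np.zeros_like(X)
np.add.at(X_bar, I, Y_bar)      # reverse: scatter-add, accumulating duplicates
```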

Concatenation and Split

Concatenation joins tensors along an axis:

$$Y = \operatorname{concat}(X_1, X_2, \ldots, X_n).$$

The forward differential is

$$dY = \operatorname{concat}(dX_1, dX_2, \ldots, dX_n).$$

The reverse rule splits the output adjoint into matching pieces:

$$\bar{X}_r \mathrel{+}= \operatorname{slice}_r(\bar{Y}).$$

Split is the adjoint pattern in reverse: the forward pass slices, and the reverse pass concatenates or scatters.
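A minimal sketch of the concat/split adjoint pair (NumPy for illustration), splitting the output adjoint at the recorded input sizes:

```python
import numpy as np

X1 = np.ones((2, 3))
X2 = np.ones((4, 3))
Y = np.concatenate([X1, X2], axis=0)        # forward: concat along axis 0

Y_bar = np.arange(18, dtype=float).reshape(6, 3)
sizes = [X1.shape[0], X2.shape[0]]          # recorded during the forward pass
X1_bar, X2_bar = np.split(Y_bar, np.cumsum(sizes)[:-1], axis=0)
```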

Broadcasting

Broadcasting expands a smaller tensor across one or more axes.

For example,

$$Y_{ij} = X_i$$

broadcasts $X \in \mathbb{R}^m$ to $Y \in \mathbb{R}^{m \times n}$.

The forward differential is

$$dY_{ij} = dX_i.$$

The reverse rule sums over the broadcasted axis:

$$\bar{X}_i \mathrel{+}= \sum_{j=1}^{n} \bar{Y}_{ij}.$$

This is one of the most common places where hand-written gradients are wrong. The reverse of broadcasting is reduction over exactly the axes introduced by broadcasting.

More generally:

$$\bar{X} \mathrel{+}= \operatorname{reduce\_sum}(\bar{Y}, \text{broadcasted axes}).$$
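A generic un-broadcast can be sketched as follows (NumPy-style broadcasting assumed; the name `unbroadcast` is hypothetical): sum over leading axes added by broadcasting, then over axes that were size 1 in the input:

```python
import numpy as np

def unbroadcast(Y_bar, x_shape):
    # Sum over leading axes that broadcasting prepended.
    while Y_bar.ndim > len(x_shape):
        Y_bar = Y_bar.sum(axis=0)
    # Sum over axes that were size 1 in the input but expanded in the output.
    for axis, d in enumerate(x_shape):
        if d == 1 and Y_bar.shape[axis] != 1:
            Y_bar = Y_bar.sum(axis=axis, keepdims=True)
    return Y_bar

X = np.ones((3, 1))
Y_bar = np.ones((2, 3, 4))              # as if X was broadcast to (2, 3, 4)
X_bar = unbroadcast(Y_bar, X.shape)     # shape (3, 1), each entry 2 * 4 = 8
```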

Tensor Contraction

Tensor contraction generalizes matrix multiplication. It multiplies tensors and sums over shared axes.

Matrix multiplication is

$$C_{ij} = \sum_k A_{ik}B_{kj}.$$

A general contraction has the form

$$Y_{\alpha,\gamma} = \sum_{\beta} A_{\alpha,\beta}B_{\beta,\gamma}.$$

Here $\alpha$, $\beta$, and $\gamma$ may each represent multiple axes.

The forward differential is

$$dY_{\alpha,\gamma} = \sum_{\beta} dA_{\alpha,\beta}B_{\beta,\gamma} + \sum_{\beta} A_{\alpha,\beta}\,dB_{\beta,\gamma}.$$

The reverse rules are contractions with the complementary tensor:

$$\bar{A}_{\alpha,\beta} \mathrel{+}= \sum_{\gamma} \bar{Y}_{\alpha,\gamma}B_{\beta,\gamma}, \qquad \bar{B}_{\beta,\gamma} \mathrel{+}= \sum_{\alpha} A_{\alpha,\beta}\bar{Y}_{\alpha,\gamma}.$$

This is the abstract rule behind matrix multiplication, batched matrix multiplication, attention score computation, convolution lowered to matrix multiplication, and many einsum expressions.
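In the matrix special case, the two reverse contractions reduce to the familiar matmul VJPs $\bar{A} = \bar{C}B^\top$ and $\bar{B} = A^\top\bar{C}$. A sketch (NumPy for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
C_bar = rng.standard_normal((3, 5))   # output adjoint for C = A @ B

A_bar = C_bar @ B.T     # sum over gamma: Y_bar[a, g] * B[b, g]
B_bar = A.T @ C_bar     # sum over alpha: A[a, b] * Y_bar[a, g]
```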

Einsum Notation

Einstein summation notation gives a compact language for tensor contractions.

For example,

$$C_{ij} = A_{ik}B_{kj}$$

corresponds to matrix multiplication.

A batched matrix multiplication can be written as

$$C_{bij} = A_{bik}B_{bkj}.$$

The repeated index $k$ is summed. Indices appearing in the output are preserved.

For an einsum expression, reverse-mode rules can be derived by replacing one input with the output adjoint and solving for the missing input index pattern.

If

$$Y_{ij} = A_{ik}B_{kj},$$

then

$$\bar{A}_{ik} = \sum_j \bar{Y}_{ij}B_{kj}, \qquad \bar{B}_{kj} = \sum_i A_{ik}\bar{Y}_{ij}.$$

Einsum is valuable because it makes shape relationships explicit. It also gives compiler backends a compact representation for optimization.
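The index-swapping recipe can be sketched directly with `np.einsum`: substitute the output adjoint for one input and rearrange the subscript string so the missing input's indices appear on the right:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))

Y = np.einsum('ik,kj->ij', A, B)      # forward
Y_bar = np.ones_like(Y)

# A_bar[i,k] = sum_j Y_bar[i,j] * B[k,j]
A_bar = np.einsum('ij,kj->ik', Y_bar, B)
# B_bar[k,j] = sum_i A[i,k] * Y_bar[i,j]
B_bar = np.einsum('ik,ij->kj', A, Y_bar)
```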

Norms

Norms reduce tensors to scalars or lower-rank tensors.

For the squared Euclidean norm,

$$y = \|X\|_F^2 = \sum_i X_i^2,$$

the differential is

$$dy = 2\sum_i X_i\,dX_i.$$

So

$$\bar{X} \mathrel{+}= 2\bar{y}X.$$

For the Euclidean norm,

$$y = \|X\|_F = \sqrt{\sum_i X_i^2},$$

when $X \ne 0$,

$$\bar{X} \mathrel{+}= \bar{y}\,\frac{X}{\|X\|_F}.$$

At $X = 0$, the norm is not differentiable. Implementations either return a chosen subgradient, return zero, or rely on user-added numerical stabilization such as

$$\sqrt{\epsilon + \sum_i X_i^2}.$$
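A sketch of the stabilized variant and its gradient (NumPy for illustration; function names are hypothetical). With the epsilon inside the square root, the gradient stays finite at $X = 0$:

```python
import numpy as np

def stable_norm(X, eps=1e-12):
    return np.sqrt(eps + np.sum(X ** 2))

def stable_norm_vjp(X, y_bar, eps=1e-12):
    # y_bar * X / sqrt(eps + sum X^2): finite everywhere, including X = 0
    return y_bar * X / stable_norm(X, eps)

X = np.zeros((2, 2))
X_bar = stable_norm_vjp(X, 1.0)   # zero, not NaN
```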

Softmax

Softmax is a tensor operation usually applied along one axis. For a vector $x$,

$$y_i = \frac{e^{x_i}}{\sum_j e^{x_j}}.$$

The differential is

$$dy_i = y_i\left(dx_i - \sum_j y_j\,dx_j\right).$$

Given an output adjoint $\bar{y}$, the reverse rule is

$$\bar{x}_i \mathrel{+}= y_i \left( \bar{y}_i - \sum_j \bar{y}_j y_j \right).$$

In vector form:

$$\bar{x} \mathrel{+}= y \odot \left( \bar{y} - \langle \bar{y}, y\rangle \mathbf{1} \right).$$

Softmax is rarely implemented as a naive primitive. Production systems use numerically stable forms, usually subtracting the maximum value before exponentiation. For cross-entropy loss, systems often fuse softmax and log-loss to avoid unstable intermediate values.
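A sketch of the stable softmax and its VJP along the last axis (NumPy for illustration), implementing $\bar{x} = y \odot (\bar{y} - \langle \bar{y}, y\rangle\mathbf{1})$ with the max subtracted before exponentiation:

```python
import numpy as np

def softmax(x):
    # Subtracting the max leaves the result unchanged but avoids overflow.
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def softmax_vjp(y, y_bar):
    inner = np.sum(y_bar * y, axis=-1, keepdims=True)
    return y * (y_bar - inner)

x = np.array([1.0, 2.0, 3.0])
y = softmax(x)
x_bar = softmax_vjp(y, np.array([1.0, 0.0, 0.0]))
```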

Tensor Operations and Memory Layout

A tensor has more than a shape. It also has a layout. A strided tensor view is described by:

$$\text{offset}, \qquad \text{shape}, \qquad \text{strides}.$$

The address of an element is computed as

$$\text{addr}(i_1,\ldots,i_k) = \text{base} + \text{offset} + \sum_{r=1}^{k} i_r s_r.$$

Here $s_r$ is the stride for axis $r$.

Two tensors may have the same shape but different memory layouts. Transpose can often be represented by changing strides. Slice can often be represented by changing offset and shape. Expand can sometimes be represented using a stride of zero.

This matters for AD because the reverse pass must accumulate into storage correctly. If multiple output elements refer to the same input storage location, adjoints must be summed, not overwritten.
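The addressing formula can be checked against NumPy's own stride metadata. Here `element_index` is a hypothetical helper computing $\text{offset} + \sum_r i_r s_r$ in element (not byte) units; a transpose changes the strides while the storage stays shared:

```python
import numpy as np

def element_index(offset, strides, idx):
    # offset + sum_r i_r * s_r, with strides measured in elements
    return offset + sum(i * s for i, s in zip(idx, strides))

X = np.arange(12).reshape(3, 4)
T = X.T                                         # same storage, new strides
elem_strides = tuple(s // X.itemsize for s in T.strides)

# element_index(0, elem_strides, (j, i)) locates T[j, i] in X's flat storage
```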

Tensor Primitive Rules

| Primitive | Forward Differential | Reverse Rule |
| --- | --- | --- |
| $Y = \phi(X)$ | $dY = \phi'(X)\odot dX$ | $\bar{X} \mathrel{+}= \bar{Y}\odot \phi'(X)$ |
| $Y = X + Z$ | $dY = dX + dZ$ | $\bar{X} \mathrel{+}= \bar{Y},\ \bar{Z} \mathrel{+}= \bar{Y}$ |
| $Y = X \odot Z$ | $dY = dX\odot Z + X\odot dZ$ | $\bar{X} \mathrel{+}= \bar{Y}\odot Z,\ \bar{Z} \mathrel{+}= \bar{Y}\odot X$ |
| $Y = \operatorname{sum}(X)$ | $dY = \operatorname{sum}(dX)$ | $\bar{X} \mathrel{+}= \operatorname{broadcast}(\bar{Y})$ |
| $Y = \operatorname{reshape}(X)$ | $dY = \operatorname{reshape}(dX)$ | $\bar{X} \mathrel{+}= \operatorname{reshape}^{-1}(\bar{Y})$ |
| $Y = X[I]$ | $dY = dX[I]$ | $\bar{X}[I] \mathrel{+}= \bar{Y}$ |
| $Y = \operatorname{concat}(X_1,\ldots,X_n)$ | $dY = \operatorname{concat}(dX_1,\ldots,dX_n)$ | $\bar{X}_r \mathrel{+}= \operatorname{slice}_r(\bar{Y})$ |
| $Y = \operatorname{broadcast}(X)$ | $dY = \operatorname{broadcast}(dX)$ | $\bar{X} \mathrel{+}= \operatorname{reduce\_sum}(\bar{Y})$ |
| $Y = \operatorname{einsum}(\cdots)$ | Contract input differentials | Contract output adjoint with complementary inputs |

Implementation View

A tensor AD engine typically represents each primitive with:

op:
  inputs: Tensor[]
  output: Tensor
  forward: compute output values
  jvp: compute output tangent from input tangents
  vjp: compute input adjoints from output adjoint

The reverse rule for a primitive must handle:

shape
dtype
layout
broadcast axes
reduction axes
aliasing
device placement
accumulation

For example, the reverse rule for broadcast addition must know which axes were broadcast during the forward pass. Without that metadata, it cannot reduce the adjoint correctly.
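A minimal sketch of such a primitive record (plain Python with NumPy; the class and method names are hypothetical, not any particular library's API): the forward pass stores the input shapes so the vjp can reduce each adjoint to the right shape:

```python
import numpy as np

class BroadcastAdd:
    def forward(self, X, Z):
        # Record metadata needed by the reverse rule.
        self.x_shape, self.z_shape = X.shape, Z.shape
        return X + Z

    def vjp(self, Y_bar):
        return (self._reduce(Y_bar, self.x_shape),
                self._reduce(Y_bar, self.z_shape))

    @staticmethod
    def _reduce(Y_bar, shape):
        # Sum the adjoint over every axis that broadcasting expanded.
        while Y_bar.ndim > len(shape):
            Y_bar = Y_bar.sum(axis=0)
        for axis, d in enumerate(shape):
            if d == 1:
                Y_bar = Y_bar.sum(axis=axis, keepdims=True)
        return Y_bar

op = BroadcastAdd()
Y = op.forward(np.ones((2, 3)), np.ones((1, 3)))
X_bar, Z_bar = op.vjp(np.ones((2, 3)))
```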

Practical Rule

For tensor operations, the safest derivation process is:

  1. Write the operation with explicit indices.
  2. Differentiate the indexed equation.
  3. Attach an output adjoint.
  4. Sum over output indices.
  5. Rearrange terms so each input differential is isolated.
  6. Read off the reverse rule.
  7. Check shapes and broadcast axes.

Tensor AD is mostly disciplined bookkeeping. The calculus is local. The difficulty is preserving the exact semantics of shape, indexing, layout, and accumulation.