Indexing and slicing select parts of a tensor. These operations are used constantly in PyTorch: selecting batches, cropping images, extracting token positions, applying masks, gathering logits, and rearranging model outputs.
A tensor operation may either create a view or a copy. A view shares storage with the original tensor. A copy owns separate storage. This distinction matters for memory use, performance, and mutation.
Basic Indexing
A tensor entry is selected by giving one index per axis.
import torch
X = torch.tensor([
[10, 11, 12],
[20, 21, 22],
[30, 31, 32],
])
print(X[0, 0]) # tensor(10)
print(X[1, 2]) # tensor(22)

PyTorch uses zero-based indexing. The first row has index 0. The second row has index 1.
For a matrix $X$, the entry at row $i$, column $j$ is $X_{ij}$.
In PyTorch this is written as:
X[i, j]

Indexing Rows and Columns
Selecting one row removes the row axis.
X = torch.tensor([
[10, 11, 12],
[20, 21, 22],
[30, 31, 32],
])
row = X[1]
print(row)
print(row.shape)

Output:
tensor([20, 21, 22])
torch.Size([3])

Selecting one column uses : to keep all rows:
col = X[:, 1]
print(col)
print(col.shape)

Output:
tensor([11, 21, 31])
torch.Size([3])

The colon means “select everything along this axis.”
Slicing Ranges
A slice selects a range of indices.
X = torch.arange(10)
print(X[2:7])

Output:
tensor([2, 3, 4, 5, 6])

The start index is included. The stop index is excluded.
General form:
start:stop:step

Examples:
x = torch.arange(10)
print(x[:5]) # first five entries
print(x[5:]) # entries from index 5 onward
print(x[::2]) # every second entry
# print(x[::-1]) # raises an error: PyTorch does not support negative slice steps

For reverse order, use torch.flip:
x = torch.arange(10)
rev = torch.flip(x, dims=[0])
print(rev)

Slicing Higher-Rank Tensors
For a 4D image batch, the PyTorch shape is:

[B, C, H, W]

Example:
X = torch.randn(32, 3, 224, 224)

Select the first image:
img = X[0]
print(img.shape) # torch.Size([3, 224, 224])

Select all images, first channel:
red = X[:, 0, :, :]
print(red.shape) # torch.Size([32, 224, 224])

Crop the center region:
crop = X[:, :, 56:168, 56:168]
print(crop.shape) # torch.Size([32, 3, 112, 112])

Slicing is the natural way to express spatial crops, token windows, and feature subsets.
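For instance, a minimal sketch of a token window; the tokens tensor and window bounds here are made up for illustration:

# hypothetical batch of token ids with shape [batch, sequence_length]
tokens = torch.randint(0, 1000, (4, 128))

# a 16-token window starting at position 32, for every sequence
window = tokens[:, 32:48]
print(window.shape)  # torch.Size([4, 16])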
Keeping Dimensions
Integer indexing removes an axis. Slicing with a range preserves it.
X = torch.randn(32, 3, 224, 224)
a = X[0]
b = X[0:1]
print(a.shape) # torch.Size([3, 224, 224])
print(b.shape) # torch.Size([1, 3, 224, 224])

X[0] selects one image and removes the batch axis. X[0:1] selects a batch containing one image and keeps the batch axis.
This distinction matters because neural network layers usually expect a batch axis.
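As a small illustration (the convolution layer here is arbitrary), keeping the batch axis with X[0:1] lets a single image pass through code written for [B, C, H, W] inputs:

import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
X = torch.randn(32, 3, 224, 224)

one = X[0:1]          # [1, 3, 224, 224], batch axis kept
out = conv(one)
print(out.shape)      # torch.Size([1, 8, 222, 222])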
Ellipsis Indexing
The ellipsis ... means “all omitted axes.”
X = torch.randn(32, 3, 224, 224)
last_col = X[..., -1]
print(last_col.shape)

Output:
torch.Size([32, 3, 224])

This is equivalent to:
X[:, :, :, -1]

Ellipsis is useful when the number of leading axes may vary.
Example:
def last_feature(x):
    return x[..., -1]

This function works for [B, D], [B, T, D], or [B, H, W, D].
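A quick check, with arbitrarily chosen shapes, that the same function handles different numbers of leading axes:

a = torch.randn(4, 16)           # [B, D]
b = torch.randn(4, 10, 16)       # [B, T, D]
c = torch.randn(4, 7, 7, 16)     # [B, H, W, D]

print(last_feature(a).shape)     # torch.Size([4])
print(last_feature(b).shape)     # torch.Size([4, 10])
print(last_feature(c).shape)     # torch.Size([4, 7, 7])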
Boolean Masks
Comparison operations produce Boolean tensors.
x = torch.tensor([-2.0, 0.5, 3.0, -1.0])
mask = x > 0
print(mask)

Output:
tensor([False, True, True, False])

A Boolean mask can select matching entries:
positive = x[mask]
print(positive)Output:
tensor([0.5000, 3.0000])

Boolean indexing returns a flattened selection when applied this way. It creates a new tensor containing only selected entries.
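For example, masking a matrix returns a 1D tensor of the selected entries in row-major order, not a smaller matrix (a short sketch reusing the earlier X):

X = torch.tensor([
[10, 11, 12],
[20, 21, 22],
[30, 31, 32],
])
selected = X[X > 21]
print(selected)  # tensor([22, 30, 31, 32]), flattened to 1D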
Masks are common in deep learning. For example, in sequence models, a padding mask identifies which tokens should be ignored.
tokens = torch.tensor([
[101, 2054, 2003, 102, 0, 0],
[101, 2129, 2024, 2017, 102, 0],
])
pad_id = 0
padding_mask = tokens == pad_id
print(padding_mask)
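One hedged sketch of how such a mask might be used, here just to count the non-padding tokens per sequence:

real_tokens = (~padding_mask).sum(dim=1)
print(real_tokens)  # tensor([4, 5]) for the example above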
Masked Assignment

Masks can also modify selected entries.
x = torch.tensor([-2.0, 0.5, 3.0, -1.0])
x[x < 0] = 0.0
print(x)

Output:
tensor([0.0000, 0.5000, 3.0000, 0.0000])

This operation behaves like an in-place ReLU.
For differentiable model code, prefer functional operations such as torch.where unless mutation is clearly intended:
x = torch.tensor([-2.0, 0.5, 3.0, -1.0])
y = torch.where(x < 0, torch.zeros_like(x), x)
print(y)

Advanced Integer Indexing
A tensor of indices can select multiple positions.
x = torch.tensor([10, 20, 30, 40, 50])
idx = torch.tensor([0, 2, 4])
print(x[idx])

Output:
tensor([10, 30, 50])

For a matrix:
X = torch.tensor([
[10, 11, 12],
[20, 21, 22],
[30, 31, 32],
])
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 2])
print(X[rows, cols])

Output:
tensor([11, 32])

This selects the pairs (0, 1) and (2, 2), that is, X[0, 1] and X[2, 2]. It does not select the rectangular submatrix formed by rows [0, 2] and columns [1, 2].
To select a rectangular submatrix, broadcast the index tensors against each other (the analogue of NumPy's ix_):
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 2])
sub = X[rows[:, None], cols]
print(sub)

Output:
tensor([[11, 12],
        [31, 32]])

Gathering Values
torch.gather selects values along a specified axis using an index tensor.
A common use is selecting the logit for the correct class.
logits = torch.tensor([
[2.0, 0.1, -1.0],
[0.3, 1.5, 0.2],
])
labels = torch.tensor([0, 1])

The correct class scores are:
selected = logits[torch.arange(logits.shape[0]), labels]
print(selected)

Output:
tensor([2.0000, 1.5000])

The same idea with gather:
selected = logits.gather(dim=1, index=labels[:, None])
print(selected)

Output:
tensor([[2.0000],
        [1.5000]])

gather is useful when writing vectorized code for classification, sequence losses, and beam search.
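As a hedged sketch of the classification case, a hand-rolled negative log-likelihood can be built from gather (in practice torch.nn.functional.cross_entropy covers this):

log_probs = torch.log_softmax(logits, dim=1)
nll = -log_probs.gather(dim=1, index=labels[:, None]).squeeze(1)
loss = nll.mean()
print(loss)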
Scatter Operations
Scatter operations write values into a tensor at indexed positions.
out = torch.zeros(3, 5)
index = torch.tensor([
[0],
[2],
[4],
])
src = torch.tensor([
[1.0],
[1.0],
[1.0],
])
out.scatter_(dim=1, index=index, src=src)
print(out)

Output:
tensor([[1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.]])

This pattern creates one-hot encodings. PyTorch also provides torch.nn.functional.one_hot.
import torch.nn.functional as F
labels = torch.tensor([0, 2, 4])
one_hot = F.one_hot(labels, num_classes=5)
print(one_hot)

Views and Shared Storage
Basic slicing usually creates a view. A view shares memory with the original tensor.
x = torch.arange(10)
y = x[2:7]
y[0] = -1
print(x)

Output:
tensor([ 0, 1, -1, 3, 4, 5, 6, 7, 8, 9])

Changing y changed x.
This behavior avoids unnecessary memory allocation. It also means that mutation through a view can affect the original tensor.
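When independent data is needed, the slice can be cloned first (a small sketch):

x = torch.arange(10)
y = x[2:7].clone()  # separate storage
y[0] = -1
print(x)            # x is unchanged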
Copies from Advanced Indexing
Advanced indexing usually creates a copy.
x = torch.arange(10)
idx = torch.tensor([2, 3, 4])
y = x[idx]
y[0] = -1
print(x)
print(y)

Output:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([-1, 3, 4])

Changing y did not change x because y owns separate storage.
A practical rule:
| Operation type | Usually returns |
|---|---|
| Basic slicing | View |
| transpose, permute | View with changed strides |
| Boolean indexing | Copy |
| Integer array indexing | Copy |
| clone() | Copy |
| contiguous() | Copy if needed |
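One way to check which case applies, assuming a recent PyTorch where untyped_storage() is available, is to compare storage pointers:

def shares_storage(a, b):
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

x = torch.arange(10)
view = x[2:7]                       # basic slicing: view
copy = x[torch.tensor([2, 3, 4])]   # integer indexing: copy

print(shares_storage(x, view))      # True
print(shares_storage(x, copy))      # False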
view, reshape, and clone
view() reshapes a tensor only when the existing memory layout allows it.
x = torch.arange(12)
y = x.view(3, 4)
print(y)

reshape() is more flexible. It returns a view when possible and a copy when necessary.
x = torch.arange(12)
y = x.reshape(3, 4)

clone() explicitly creates a copy.
x = torch.arange(5)
y = x.clone()
y[0] = -1
print(x)
print(y)

Use clone() when you need independent storage.
Contiguity After Permutation
A permuted tensor often becomes noncontiguous.
X = torch.randn(2, 3, 4)
Y = X.permute(0, 2, 1)
print(Y.shape)
print(Y.is_contiguous())

Output:
torch.Size([2, 4, 3])
False

The tensor Y has a valid shape, but its memory order differs from a standard contiguous layout.
Some operations require contiguous memory. Use:
Yc = Y.contiguous()

This creates a contiguous copy (contiguous() returns the tensor unchanged if it is already contiguous).
A common pattern:
Y = X.permute(0, 2, 1).contiguous()
Y = Y.view(2, 12)

Without contiguous(), view() may fail.
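Alternatively, reshape() makes the copy automatically when the layout does not allow a view (a brief sketch):

Y = X.permute(0, 2, 1).reshape(2, 12)  # copies only if a view is impossible
print(Y.shape)                         # torch.Size([2, 12])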
Indexing in Sequence Models
Suppose a language model returns logits with shape:

[B, T, V]

where B is batch size, T is sequence length, and V is vocabulary size.
B, T, V = 4, 8, 10000
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

To select the logit assigned to each target token:
target_logits = logits.gather(dim=2, index=targets.unsqueeze(-1))
print(target_logits.shape) # torch.Size([4, 8, 1])

After removing the final singleton dimension:
target_logits = target_logits.squeeze(-1)
print(target_logits.shape) # torch.Size([4, 8])

This is the same indexing principle used in classification, applied at every token position.
Indexing in Attention Masks
Attention masks often have shape:

[B, T]

A transformer attention layer may need shape:

[B, 1, 1, T]

so the mask can broadcast across attention heads and query positions.
B, T = 4, 8
padding_mask = torch.randint(0, 2, (B, T)).bool()
attention_mask = padding_mask[:, None, None, :]
print(attention_mask.shape) # torch.Size([4, 1, 1, 8])

The None entries insert singleton axes. This is equivalent to unsqueeze.
attention_mask = padding_mask.unsqueeze(1).unsqueeze(2)

Indexing is therefore part of shape engineering in transformer implementations.
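As one hedged sketch of where the broadcast mask is applied, padded key positions can be set to -inf before the softmax over attention scores (the head count H and the random scores are made up for illustration):

H = 2                                          # hypothetical number of heads
scores = torch.randn(B, H, T, T)               # raw attention scores
scores = scores.masked_fill(attention_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights.shape)                           # torch.Size([4, 2, 8, 8])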
Common Mistakes
The most common indexing mistakes are:
| Mistake | Example | Problem |
|---|---|---|
| Removing a needed batch axis | x[0] | Layer may expect [B, D], but receives [D] |
| Using paired indexing accidentally | X[rows, cols] | Selects pairs, not submatrix |
| Confusing view and copy | y = x[idx] | Mutation may not affect original |
| Forgetting .contiguous() | x.permute(...).view(...) | Memory layout may be invalid |
| Mask shape mismatch | [B, T] mask for [B, H, T, T] scores | Needs singleton axes for broadcasting |
Good tensor code makes shape changes explicit.
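One lightweight habit, sketched here with a made-up helper, is to state the expected shape in a comment and assert it at function boundaries:

def select_cls_token(hidden):
    # hidden: [B, T, D] -> returns the first token's features, [B, D]
    assert hidden.dim() == 3, f"expected [B, T, D], got {tuple(hidden.shape)}"
    return hidden[:, 0, :]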
Summary
Indexing selects individual entries. Slicing selects ranges. Boolean masks select entries based on conditions. Integer index tensors allow advanced selection. gather and scatter provide vectorized indexed reading and writing.
Basic slicing usually returns a view that shares storage with the original tensor. Advanced indexing usually returns a copy. reshape, view, permute, contiguous, and clone control how tensor data is interpreted or copied.
Correct indexing is central to PyTorch programming because neural networks often need precise selection across batch, channel, sequence, feature, and vocabulary axes.