A saliency map is a visualization that assigns an importance score to each part of an input. For an image model, the saliency map usually assigns a score to each pixel or image region. For a text model, it may assign a score to each token. The goal is to estimate which parts of the input most influenced the model’s prediction.
Saliency methods are often used for model inspection. They can reveal whether a classifier looks at the object, the background, a watermark, a border artifact, or another unintended feature. They do not prove that a model reasons correctly. They are diagnostic tools, not complete explanations.
Basic Idea
Let a model $f$ produce a score for class $c$:

$$f_c(x) \in \mathbb{R}$$

A saliency method asks how sensitive this score is to each input component. The simplest form uses the gradient of the class score with respect to the input:

$$S(x) = \frac{\partial f_c(x)}{\partial x}$$

Each entry of $S(x)$ measures how much a small change in the corresponding input component would change the class score.
For images, $x$ may have shape `[B, C, H, W]` and the saliency tensor has the same shape. It is often reduced across color channels to produce a heatmap of shape `[B, H, W]`.

Vanilla Gradient Saliency
Vanilla gradient saliency is the simplest saliency method.
Given an input image $x$, a target class $c$, and a model $f$, compute the gradient of the class score with respect to the input:

$$S(x) = \left| \frac{\partial f_c(x)}{\partial x} \right|$$
In PyTorch:
```python
import torch


def vanilla_saliency(model, x, target_class=None):
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    if target_class is None:
        target_class = logits.argmax(dim=1)
    scores = logits.gather(1, target_class.view(-1, 1)).sum()
    model.zero_grad(set_to_none=True)
    scores.backward()
    saliency = x.grad.abs()
    # Reduce RGB channels to one heatmap per image.
    saliency = saliency.max(dim=1).values
    return saliency.detach()
```

This function returns a heatmap with shape `[B, H, W]`.
Vanilla saliency is cheap because it requires only one backward pass. Its weakness is noise. Gradients may highlight small high-frequency patterns rather than semantically meaningful regions.
Class Scores Versus Loss Gradients
A saliency map should be clear about which quantity is being differentiated.
One option is to differentiate the class score:

$$\frac{\partial f_c(x)}{\partial x}$$

This asks: which input components increase or decrease the score for class $c$?

Another option is to differentiate the loss for label $y$:

$$\frac{\partial \mathcal{L}(f(x), y)}{\partial x}$$

This asks: which input components would most change the training objective?
For interpretation, class-score gradients are usually more direct. For adversarial attacks, loss gradients are usually more useful. The same gradient machinery is used in both cases, but the objective is different.
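The distinction can be made concrete with a toy linear classifier (the model, inputs, and labels below are purely illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy classifier: 4 features -> 3 classes (illustrative only).
model = torch.nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)
y = torch.tensor([0, 2])

# Option 1: differentiate the class score (logit) itself.
logits = model(x)
score = logits.gather(1, y.view(-1, 1)).sum()
score_grad = torch.autograd.grad(score, x, retain_graph=True)[0]

# Option 2: differentiate the cross-entropy loss.
loss = F.cross_entropy(logits, y)
loss_grad = torch.autograd.grad(loss, x)[0]

# Same gradient machinery, different objective: the two maps generally differ.
print(score_grad.shape, loss_grad.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```

Both options cost one backward pass; the choice only changes which quantity sits at the top of the computation graph.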
Positive and Negative Evidence
Taking the absolute value of the gradient removes the sign. This shows sensitivity, but it loses direction.
A signed gradient keeps direction:

$$S(x) = \frac{\partial f_c(x)}{\partial x}$$
Positive values indicate that increasing the input component would increase the class score. Negative values indicate that increasing it would decrease the class score.
For images, signed saliency can be harder to interpret because pixels have channels, normalization, and nonlinear preprocessing. For tabular data, signed gradients may be more meaningful.
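A minimal sketch of signed saliency on toy tabular data (the `signed_saliency` helper and the linear model are illustrative; for a linear model, the signed map per example is exactly the weight row of its target class):

```python
import torch

torch.manual_seed(0)

def signed_saliency(model, x, target_class):
    # Keep the gradient sign instead of taking abs().
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    scores = logits.gather(1, target_class.view(-1, 1)).sum()
    grad, = torch.autograd.grad(scores, x)
    # Positive: increasing the feature raises the class score; negative: lowers it.
    return grad

model = torch.nn.Linear(5, 2)  # toy tabular classifier (illustrative)
x = torch.randn(3, 5)
sal = signed_saliency(model, x, torch.tensor([1, 0, 1]))
print(sal.shape)  # torch.Size([3, 5])
```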
SmoothGrad
SmoothGrad reduces noise by averaging saliency maps over noisy copies of the same input.
Let $\epsilon_1, \dots, \epsilon_n$ be random noise samples, for example $\epsilon_i \sim \mathcal{N}(0, \sigma^2 I)$. SmoothGrad computes:

$$\hat{S}(x) = \frac{1}{n} \sum_{i=1}^{n} S(x + \epsilon_i)$$
The intuition is that meaningful regions should remain important under small input noise, while random gradient noise should average out.
In PyTorch:
```python
def smoothgrad_saliency(model, x, target_class=None, samples=32, noise_std=0.1):
    model.eval()
    if target_class is None:
        # Fix the target class on the clean input so every noisy sample
        # explains the same prediction.
        with torch.no_grad():
            target_class = model(x).argmax(dim=1)
    saliency_sum = torch.zeros(
        x.shape[0], x.shape[2], x.shape[3],
        device=x.device,
    )
    for _ in range(samples):
        noise = torch.randn_like(x) * noise_std
        x_noisy = torch.clamp(x + noise, 0.0, 1.0)
        saliency = vanilla_saliency(model, x_noisy, target_class)
        saliency_sum += saliency
    return saliency_sum / samples
```

SmoothGrad is more expensive than vanilla saliency because it requires one backward pass per noisy sample. It often produces cleaner visualizations.
Integrated Gradients
Integrated gradients compare the input with a baseline input. The baseline is usually a black image, zero vector, padding token, or another neutral reference.
For input $x$, baseline $x'$, and class score $f_c$, integrated gradients are defined as:

$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial f_c\big(x' + \alpha (x - x')\big)}{\partial x_i} \, d\alpha$$

The method accumulates gradients along the straight path from the baseline to the real input.

In practice, the integral is approximated with a finite sum over $m$ steps:

$$\mathrm{IG}_i(x) \approx (x_i - x'_i) \cdot \frac{1}{m} \sum_{k=1}^{m} \frac{\partial f_c\big(x' + \tfrac{k}{m} (x - x')\big)}{\partial x_i}$$
PyTorch implementation:
```python
def integrated_gradients(model, x, target_class=None, baseline=None, steps=50):
    model.eval()
    if baseline is None:
        baseline = torch.zeros_like(x)
    if target_class is None:
        with torch.no_grad():
            logits = model(x)
        target_class = logits.argmax(dim=1)
    total_grad = torch.zeros_like(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        x_interp = baseline + alpha * (x - baseline)
        x_interp = x_interp.clone().detach().requires_grad_(True)
        logits = model(x_interp)
        scores = logits.gather(1, target_class.view(-1, 1)).sum()
        model.zero_grad(set_to_none=True)
        scores.backward()
        total_grad += x_interp.grad
    avg_grad = total_grad / steps
    attribution = (x - baseline) * avg_grad
    # For images, reduce channels.
    heatmap = attribution.abs().sum(dim=1)
    return heatmap.detach()
```

Integrated gradients are often more stable than vanilla gradients. The choice of baseline matters. A poor baseline can produce misleading attributions.
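One way to sanity-check an implementation is the completeness property: the raw attributions (before any channel reduction) should sum approximately to $f_c(x) - f_c(x')$. A minimal sketch on a toy network, using a midpoint approximation of the path integral (the model and sizes are illustrative):

```python
import torch

torch.manual_seed(0)

# Toy network (illustrative): 8 features -> 3 classes.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.Tanh(), torch.nn.Linear(16, 3),
).eval()
x = torch.randn(1, 8)
baseline = torch.zeros_like(x)
target, steps = 2, 200

total_grad = torch.zeros_like(x)
for k in range(1, steps + 1):
    alpha = (k - 0.5) / steps  # midpoint rule for the path integral
    xi = (baseline + alpha * (x - baseline)).requires_grad_(True)
    score = model(xi)[0, target]
    grad, = torch.autograd.grad(score, xi)
    total_grad += grad
attribution = (x - baseline) * (total_grad / steps)

# Completeness: attributions sum to f_c(x) - f_c(baseline).
with torch.no_grad():
    delta = model(x)[0, target] - model(baseline)[0, target]
print(float(attribution.sum()), float(delta))  # the two numbers should nearly match
```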
Grad-CAM
Gradient-weighted Class Activation Mapping, or Grad-CAM, is commonly used for convolutional networks. It produces a coarse heatmap over spatial regions of a feature map.
Let $A^k$ be feature map $k$ in a convolutional layer. Let $y^c$ be the score for class $c$. Grad-CAM computes channel weights:

$$\alpha_k^c = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^c}{\partial A_{ij}^k}$$

Then it forms a weighted sum of feature maps:

$$L^c = \mathrm{ReLU}\left( \sum_k \alpha_k^c A^k \right)$$
The result is resized to the input image size.
Grad-CAM is useful because it highlights larger semantic regions rather than individual pixels. It is less precise spatially, but often easier to interpret.
A minimal PyTorch implementation requires hooks:
```python
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.forward_handle = target_layer.register_forward_hook(
            self._save_activation
        )
        self.backward_handle = target_layer.register_full_backward_hook(
            self._save_gradient
        )

    def _save_activation(self, module, inputs, output):
        self.activations = output

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove_hooks(self):
        self.forward_handle.remove()
        self.backward_handle.remove()

    def __call__(self, x, target_class=None):
        logits = self.model(x)
        if target_class is None:
            target_class = logits.argmax(dim=1)
        score = logits.gather(1, target_class.view(-1, 1)).sum()
        self.model.zero_grad(set_to_none=True)
        score.backward()
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1)
        cam = torch.relu(cam)
        # Normalize each map to [0, 1] for visualization.
        cam_min = cam.amin(dim=(1, 2), keepdim=True)
        cam_max = cam.amax(dim=(1, 2), keepdim=True)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)
        return cam.detach()
```

For a ResNet-like model, the target layer is often the last convolutional block.
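Because a hook-based implementation needs a concrete target layer, it can help to see the same computation written functionally on a toy CNN (the model below is illustrative, and the slicing trick assumes an `nn.Sequential` model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy CNN (illustrative): two conv blocks, global pooling, linear classifier.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 4),
).eval()

x = torch.randn(2, 3, 32, 32)

# Functional Grad-CAM on the last conv layer, without hooks:
# run the feature extractor, keep its output, finish the forward pass.
features = model[:4](x)                      # [B, 16, 32, 32]
features.retain_grad()
logits = model[4:](features)
target = logits.argmax(dim=1)
score = logits.gather(1, target.view(-1, 1)).sum()
score.backward()

weights = features.grad.mean(dim=(2, 3), keepdim=True)  # channel weights
cam = torch.relu((weights * features).sum(dim=1))       # [B, 32, 32]
print(cam.shape)  # torch.Size([2, 32, 32])
```

The hook-based class does the same thing but works for models whose layers cannot be sliced this way.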
Occlusion Sensitivity
Occlusion sensitivity is a perturbation-based saliency method. Instead of using gradients, it masks parts of the input and observes how the model output changes.
For an image, we slide a patch over the image and replace that patch with a baseline value. If masking a region greatly reduces the score for class $c$, that region is considered important.
The method is easy to understand but computationally expensive. It requires many forward passes.
```python
def occlusion_saliency(model, x, target_class=None, patch_size=16, stride=8):
    model.eval()
    with torch.no_grad():
        logits = model(x)
    if target_class is None:
        target_class = logits.argmax(dim=1)
    base_scores = logits.gather(1, target_class.view(-1, 1)).squeeze(1)
    B, C, H, W = x.shape
    heatmap = torch.zeros(B, H, W, device=x.device)
    counts = torch.zeros(B, H, W, device=x.device)
    for top in range(0, H, stride):
        for left in range(0, W, stride):
            bottom = min(top + patch_size, H)
            right = min(left + patch_size, W)
            x_occ = x.clone()
            x_occ[:, :, top:bottom, left:right] = 0.0
            with torch.no_grad():
                logits_occ = model(x_occ)
            scores_occ = logits_occ.gather(1, target_class.view(-1, 1)).squeeze(1)
            drop = base_scores - scores_occ
            heatmap[:, top:bottom, left:right] += drop.view(B, 1, 1)
            counts[:, top:bottom, left:right] += 1
    return heatmap / counts.clamp_min(1)
```

Occlusion methods can capture nonlinear effects better than local gradients. Their cost grows quickly with image size and patch resolution.
Saliency for Text Models
For text, the input tokens are discrete. We cannot usually take meaningful gradients with respect to token IDs. Instead, saliency is computed with respect to token embeddings.
Suppose token IDs have shape `[B, T]` and embeddings have shape `[B, T, D]`. A saliency score per token can be computed by differentiating the class score with respect to the embedding vectors $e_t$, then reducing over the embedding dimension, for example:

$$s_t = \left\lVert \frac{\partial f_c}{\partial e_t} \right\rVert_2$$

This gives one score per token.
The exact implementation depends on the model architecture. In transformer libraries, it is often easier to use embedding hooks or to pass `inputs_embeds` directly if the model supports it.
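A minimal sketch on a toy embed-and-pool classifier (the vocabulary size, embedding dimension, and mean pooling are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy text classifier (illustrative): embed, mean-pool, classify.
vocab, dim, n_classes = 100, 16, 2
embedding = nn.Embedding(vocab, dim)
head = nn.Linear(dim, n_classes)

token_ids = torch.randint(0, vocab, (1, 7))                   # [B, T]
embeds = embedding(token_ids).detach().requires_grad_(True)   # [B, T, D]

logits = head(embeds.mean(dim=1))
target = logits.argmax(dim=1)
score = logits.gather(1, target.view(-1, 1)).sum()
score.backward()

# One score per token: L2 norm of the gradient over the embedding dim.
token_saliency = embeds.grad.norm(dim=-1)                     # [B, T]
print(token_saliency.shape)  # torch.Size([1, 7])
```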
Failure Modes of Saliency Maps
Saliency maps are easy to misuse. They are visual and intuitive, but their meaning is limited.
Common failure modes include:
| Failure mode | Description |
|---|---|
| Gradient noise | Heatmaps may emphasize high-frequency artifacts |
| Saturation | Gradients may be small even for important features |
| Baseline dependence | Integrated gradients depend on reference input |
| Layer dependence | Grad-CAM depends on the chosen layer |
| Visual overinterpretation | A plausible heatmap may not imply correct reasoning |
| Model insensitivity | Some saliency methods change little across different models |
A saliency map answers a narrow question about sensitivity or attribution under a chosen method. It does not reveal the full causal mechanism of the model.
Sanity Checks
A saliency method should pass basic sanity checks.
One test is parameter randomization. If the model weights are randomized, the saliency map should change substantially. If the heatmap remains similar, the method may mostly reflect the input image rather than the learned model.
Another test is label randomization. A model trained on random labels should produce different explanations from a model trained on real labels.
A third test is input perturbation. Removing highly salient regions should reduce the target score more than removing low-saliency regions. If it does not, the saliency map may not be faithful.
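The parameter randomization test can be sketched as follows (the toy model and the inline `saliency` helper are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def saliency(model, x):
    # Vanilla gradient saliency for the predicted class.
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    score = logits.gather(1, logits.argmax(dim=1, keepdim=True)).sum()
    grad, = torch.autograd.grad(score, x)
    return grad.abs()

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)).eval()
x = torch.randn(4, 10)
before = saliency(model, x)

# Randomize the weights and recompute: the map should change substantially.
with torch.no_grad():
    for p in model.parameters():
        p.normal_()
after = saliency(model, x)

# Compare flattened maps; high similarity would fail the sanity check.
sim = torch.nn.functional.cosine_similarity(
    before.flatten(), after.flatten(), dim=0
)
print(float(sim))
```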
Practical Use
Saliency maps are useful when treated as one signal among several. They can help find dataset artifacts, preprocessing bugs, shortcut features, and suspicious model behavior.
For image classifiers, inspect saliency maps across many examples, not only cherry-picked cases. Compare correct and incorrect predictions. Compare clean and shifted data. Compare different model architectures.
For text models, token saliency can show whether a classifier relies on names, punctuation, demographic tokens, prompt artifacts, or label words. This is especially useful for auditing spurious correlations.
For safety-critical systems, saliency should support investigation, not certification. A model with plausible saliency can still fail under distribution shift, adversarial attack, or causal intervention.
Summary
Saliency maps assign importance scores to input components. Vanilla gradients measure local sensitivity. SmoothGrad averages noisy gradients. Integrated gradients attribute predictions relative to a baseline. Grad-CAM highlights spatial regions using convolutional feature maps. Occlusion sensitivity measures the effect of masking input regions.
These methods are practical and widely used, but they are approximate. Their results depend on the chosen objective, layer, baseline, and visualization method. A saliency map should be read as a diagnostic artifact, not as proof that the model understands the task.